aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nvim/lua/plugins/treesitter.lua38
-rw-r--r--nvim/lua/shared.lua46
2 files changed, 76 insertions, 8 deletions
diff --git a/nvim/lua/plugins/treesitter.lua b/nvim/lua/plugins/treesitter.lua
index a218e13..2f0f654 100644
--- a/nvim/lua/plugins/treesitter.lua
+++ b/nvim/lua/plugins/treesitter.lua
@@ -1,12 +1,34 @@
return {
'nvim-treesitter/nvim-treesitter',
- build = ':TSUpdate',
- opts = {
- ensure_installed = { "help", "lua", "python", "javascript", "typescript", "c", "latex", "vim", "go" },
- auto_install = true,
- highlight = {
- enable = true,
- additional_vim_regex_highlighting = false
+ lazy = false,
+ build = function()
+ local nvim_treesitter = require("nvim-treesitter")
+ local util = require("nvim.lua.shared")
+
+ local installed = util.to_set(nvim_treesitter.get_installed())
+ local ensure_installed = util.to_set {
+ "markdown", "markdown_inline",
+ "vimdoc",
+ "c",
+ "make",
+ "rust",
+ "go", "gowork", "gotmpl",
+ "json", "yaml", "toml",
+ "python",
+ "html",
+ "javascript",
+ "julia", "r",
+ "bash",
+ "latex",
+ "sql",
}
- }
+
+ local diff = util.set_difference(ensure_installed, installed)
+ if #diff > 0 then
+ -- NOTE: potential race condition because install is async
+ nvim_treesitter.install(util.keys(diff))
+ end
+
+ nvim_treesitter.update({ summary = true })
+ end,
}
diff --git a/nvim/lua/shared.lua b/nvim/lua/shared.lua
index 907159e..a1cf289 100644
--- a/nvim/lua/shared.lua
+++ b/nvim/lua/shared.lua
@@ -100,4 +100,50 @@ function M.vert_shift_selection(amount)
vim.cmd("normal! gv")
end
+--- Build a set from a list
+--- @generic T
+--- @param list T[]
+--- @return table<T, boolean>
+function M.to_set(list)
+ --- @type table<T, boolean>
+ local set = {}
+
+ for _, v in ipairs(list) do
+ set[v] = true
+ end
+
+ return set
+end
+
+--- Computes A \ B
+--- @generic T
+--- @param A table<T, boolean>
+--- @param B table<T, boolean>
+--- @return table<T, boolean>
+function M.set_difference(A, B)
+ --- @type table<T, boolean>
+ local diff = {}
+
+ for k, v in pairs(A) do
+ if v and not B[k] then
+ diff[k] = true
+ end
+ end
+
+ return diff
+end
+
+--- Get keys of table
+--- @generic T
+--- @param tbl table<T, any>
+--- @return T[]
+function M.keys(tbl)
+ local keys = {}
+ for k, _ in pairs(tbl) do
+ table.insert(keys, k)
+ end
+
+ return keys
+end
+
return M