aboutsummaryrefslogtreecommitdiffstats
path: root/nvim/lua/shared.lua
blob: a1cf2894a02862377123157488b4c8e802ea2bc6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
-- Module table collecting the helper functions defined below; it is returned
-- at the end of the file.
local M = {}

--- Run an Ex command in protected mode, reporting a failure with nvim_echo
--- (as an error message, added to message history) instead of raising.
--- If the error text contains a Vim error code (e.g. "E492: ..."), only the
--- part starting at that code is shown; otherwise the full text is shown.
--- @param command string the Ex command to execute
function M.cmd_pcall(command)
    local ok, err = pcall(function() vim.cmd(command) end)
    if not ok then
        local text = tostring(err)
        -- fall back to the whole message when no E-code is present, so the
        -- echoed chunk is never empty and the error is never silently dropped
        local msg = text:match("E%d+.*") or text
        vim.api.nvim_echo({ { msg } }, true, { err = true })
    end
end

--- Get the positions of the current visual selection using (1,0) indexing.
--- Feeds <Esc> (processed immediately via the 'x' flag) so that the '< and '>
--- marks reflect the current selection, then reads those marks.
--- @param opts? {reset_mark: boolean?, reselect_text: boolean?} options; may be omitted
---   reset_mark: write the '< and '> marks that existed before the call back
---   reselect_text: re-enter visual mode on the selection with `normal! gv`
--- @return { start: integer[], stop: integer[] } {row, col} pairs from nvim_buf_get_mark
M.get_visual_range = function(opts)
    -- allow calling without an options table
    opts = opts or {}

    local old_start = {}
    local old_stop = {}
    if opts.reset_mark then
        old_start = vim.api.nvim_buf_get_mark(0, "<")
        old_stop = vim.api.nvim_buf_get_mark(0, ">")
    end
    -- return to normal mode to correctly set '< and '> marks
    vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("<ESC>", true, false, true), 'x', false)
    local start_sel = vim.api.nvim_buf_get_mark(0, "<")
    local stop_sel = vim.api.nvim_buf_get_mark(0, ">")

    -- gv must run before the marks are restored so it reselects the
    -- selection that was just read
    if opts.reselect_text then
        vim.cmd("normal! gv")
    end

    if opts.reset_mark then
        vim.api.nvim_buf_set_mark(0, "<", old_start[1], old_start[2], {})
        vim.api.nvim_buf_set_mark(0, ">", old_stop[1], old_stop[2], {})
    end

    return {
        start = start_sel,
        stop = stop_sel,
    }
end

--- Get the lines covered by the current visual selection.
--- Calls M.get_visual_range with no options, so visual mode is left and the
--- selection is not reselected. Whole lines are returned, regardless of the
--- selection's columns.
--- @return { lines: string[], range: { start: integer[], stop: integer[] } }
---   lines: buffer lines from the start row through the stop row (inclusive)
---   range: {row, col} mark positions as returned by M.get_visual_range
M.get_visual = function()
    local range = M.get_visual_range({})
    -- mark rows are 1-indexed while nvim_buf_get_lines takes a 0-indexed,
    -- end-exclusive range: (start - 1, stop) covers both endpoint rows
    local lines = vim.api.nvim_buf_get_lines(0, range.start[1] - 1, range.stop[1], false)

    return { lines = lines, range = range }
end

--- Replace the quickfix list with one entry per non-blank line.
--- Every entry refers to the current buffer; lines consisting only of
--- whitespace (or empty) are skipped.
--- @param lines string[] text of each candidate line
--- @param line_nums integer[] line number for the entry at the same index in `lines`
function M.send_to_qf(lines, line_nums)
    local buf = vim.api.nvim_get_current_buf()
    local entries = {}

    for idx, text in ipairs(lines) do
        local has_content = text:find("%S") ~= nil
        if has_content then
            entries[#entries + 1] = { bufnr = buf, lnum = line_nums[idx], text = text }
        end
    end

    vim.fn.setqflist(entries, 'r')
end

--- Move the lines of the visual selection up or down and reselect them,
--- clamping the shift so the selection stays inside the buffer.
--- A negative amount moves up, a positive amount moves down. If the selection
--- already touches the relevant buffer edge, the lines stay put and are just
--- reselected.
--- @param amount integer requested number of lines to shift (negative = up)
function M.vert_shift_selection(amount)
    local init = M.get_visual_range({})
    local height = init.stop[1] - init.start[1]
    local linecount = vim.api.nvim_buf_line_count(0)

    -- clamp the new first row to [1, linecount - height] so the whole
    -- selection fits; the effective shift may be smaller than requested
    local new_start = math.max(1, math.min(init.start[1] + amount, linecount - height))
    local shift = new_start - init.start[1]

    -- shift lines by the clamped amount so :move never receives an
    -- out-of-range destination (which would raise E16)
    if shift < 0 then
        -- put the block below row ('< + shift - 1), so it starts at '< + shift
        vim.cmd("'<,'>mo '<" .. (shift - 1))
    elseif shift > 0 then
        -- put the block below row ('> + shift)
        vim.cmd("'<,'>mo '>+" .. shift)
    end

    -- update marks so gv reselects the moved lines
    local bufnr = vim.api.nvim_get_current_buf()
    vim.api.nvim_buf_set_mark(bufnr, '<', new_start, init.start[2], {})
    vim.api.nvim_buf_set_mark(bufnr, '>', new_start + height, init.stop[2], {})

    vim.cmd("normal! gv")
end

--- Convert a sequence into a set: each element becomes a key mapped to true.
--- Iteration follows ipairs, so it stops at the first nil in `list`.
--- @generic T
--- @param list T[]
--- @return table<T, boolean>
function M.to_set(list)
    --- @type table<T, boolean>
    local members = {}
    for _, item in ipairs(list) do members[item] = true end
    return members
end

--- Compute the set difference A \ B.
--- A key is kept when its value in A is truthy and its value in B is nil or
--- false; every kept key maps to true in the result.
--- @generic T
--- @param A table<T, boolean>
--- @param B table<T, boolean>
--- @return table<T, boolean>
function M.set_difference(A, B)
    --- @type table<T, boolean>
    local result = {}
    for key, in_a in pairs(A) do
        local keep = in_a and not B[key]
        if keep then
            result[key] = true
        end
    end
    return result
end

--- Collect the keys of a table into a list, in `pairs` traversal order
--- (which is unspecified).
--- @generic T
--- @param tbl table<T, any>
--- @return T[]
function M.keys(tbl)
    local result, count = {}, 0
    for key in pairs(tbl) do
        count = count + 1
        result[count] = key
    end
    return result
end

-- Export the helper table as this module's value
return M