refactor(diffsplit): clarify diff side API

This commit is contained in:
2026-05-29 14:22:24 +02:00
parent f0ae3fc656
commit 193616e87d
6 changed files with 262 additions and 55 deletions
+116 -30
View File
@@ -6,13 +6,15 @@ local util = require("git.core.util")
local M = {}
---@class ow.Git.Diffsplit.OpenOpts
---@field target string?
---@field other string?
---@field layout ("vertical"|"horizontal")?
---@field mods vim.api.keyset.cmd.mods?
---@field focus ("current"|"other")?
---@param cur_buf integer
---@return string? target
---@return string? other
---@return string? err
local function infer_target(cur_buf)
local function infer_other(cur_buf)
local cur_name = vim.api.nvim_buf_get_name(cur_buf)
local cur_rev = object.parse_uri(cur_name)
if cur_rev then
@@ -48,16 +50,16 @@ local function infer_target(cur_buf)
return object.format_uri(Revision.new({ stage = 0, path = rel })), nil
end
---@param target string
---@param other string
---@param cur_buf integer
---@return string? resolved
---@return string? err
local function resolve_target(target, cur_buf)
if vim.startswith(target, object.URI_PREFIX) then
return target, nil
local function resolve_other(other, cur_buf)
if vim.startswith(other, object.URI_PREFIX) then
return other, nil
end
if vim.fn.filereadable(target) == 1 then
return target, nil
if vim.fn.filereadable(other) == 1 then
return other, nil
end
local cur_name = vim.api.nvim_buf_get_name(cur_buf)
local cur_rev = object.parse_uri(cur_name)
@@ -78,57 +80,141 @@ local function resolve_target(target, cur_buf)
if not rel then
return nil, "current buffer has no path"
end
if not r:rev_parse(target, true) then
return nil, "invalid rev: " .. target
if not r:rev_parse(other, true) then
return nil, "invalid rev: " .. other
end
return object.format_uri(Revision.new({ base = target, path = rel })), nil
return object.format_uri(Revision.new({ base = other, path = rel })), nil
end
---@param cur_buf integer
---@param target string
---@param other string
---@return 'aboveleft'|'belowright'|nil
local function default_split(cur_buf, target)
local function default_split(cur_buf, other)
local cur_rev = object.parse_uri(vim.api.nvim_buf_get_name(cur_buf))
local target_rev = object.parse_uri(target)
if not cur_rev and target_rev then
local other_rev = object.parse_uri(other)
if not cur_rev and other_rev then
return "aboveleft"
end
if cur_rev and not target_rev then
if cur_rev and not other_rev then
return "belowright"
end
if cur_rev and target_rev then
if cur_rev.stage == 0 and target_rev.base then
if cur_rev and other_rev then
if cur_rev.stage == 0 and other_rev.base then
return "aboveleft"
end
if cur_rev.base and target_rev.stage == 0 then
if cur_rev.base and other_rev.stage == 0 then
return "belowright"
end
end
return nil
end
---@alias ow.Git.Diffsplit.Side string|integer
---@class ow.Git.Diffsplit.OpenPairOpts
---@field layout ("vertical"|"horizontal")?
---@field mods vim.api.keyset.cmd.mods?
---@field focus ("old"|"new")?
---@param mods vim.api.keyset.cmd.mods?
---@param layout ("vertical"|"horizontal")?
---@return vim.api.keyset.cmd.mods
local function layout_mods(mods, layout)
mods = vim.tbl_extend("force", {}, mods or {})
if mods.vertical == nil then
mods.vertical = layout ~= "horizontal"
end
return mods
end
---@param side ow.Git.Diffsplit.Side
---@param cur_buf integer
---@return string? name
---@return integer? buf
---@return string? err
local function resolve_side(side, cur_buf)
if type(side) == "number" then
local name = vim.api.nvim_buf_get_name(side)
if name == "" then
return nil, nil, "diff side buffer has no name"
end
return name, side, nil
end
local name, err = resolve_other(side, cur_buf)
return name, nil, err
end
---@param side ow.Git.Diffsplit.Side
---@param cur_buf integer
---@return integer? buf
---@return string? err
local function buf_for_side(side, cur_buf)
local name, buf, err = resolve_side(side, cur_buf)
if not name then
return nil, err
end
if buf then
return buf, nil
end
buf = vim.fn.bufadd(name)
vim.fn.bufload(buf)
return buf, nil
end
---@param opts? ow.Git.Diffsplit.OpenOpts
function M.open(opts)
opts = opts or {}
local cur_buf = vim.api.nvim_get_current_buf()
local target, err
if opts.target then
target, err = resolve_target(opts.target, cur_buf)
local other, err
if opts.other then
other, err = resolve_other(opts.other, cur_buf)
else
target, err = infer_target(cur_buf)
other, err = infer_other(cur_buf)
end
if not target then
util.error("%s", err or "no diff target")
if not other then
util.error("%s", err or "no diff side")
return
end
local mods = opts.mods
if not mods or mods.split == nil then
local placement = default_split(cur_buf, target)
local mods = layout_mods(opts.mods, opts.layout)
if mods.split == nil then
local placement = default_split(cur_buf, other)
if placement then
mods = vim.tbl_extend("force", mods or {}, { split = placement })
end
end
vim.cmd.diffsplit({ args = { target }, mods = mods })
local cur_win = vim.api.nvim_get_current_win()
vim.cmd.diffsplit({ args = { other }, mods = mods })
if opts.focus == "current" and vim.api.nvim_win_is_valid(cur_win) then
vim.api.nvim_set_current_win(cur_win)
end
end
---@param old ow.Git.Diffsplit.Side
---@param new ow.Git.Diffsplit.Side
---@param opts? ow.Git.Diffsplit.OpenPairOpts
function M.open_pair(old, new, opts)
opts = opts or {}
local cur_buf = vim.api.nvim_get_current_buf()
local new_buf, err = buf_for_side(new, cur_buf)
if not new_buf then
util.error("%s", err or "no new diff side")
return
end
local old_name, _, old_err = resolve_side(old, cur_buf)
if not old_name then
util.error("%s", old_err or "no old diff side")
return
end
vim.cmd.normal({ "m'", bang = true })
vim.api.nvim_set_current_buf(new_buf)
local mods = layout_mods(opts.mods, opts.layout)
mods.split = mods.split or "aboveleft"
vim.cmd.diffsplit({ args = { old_name }, mods = mods })
if opts.focus ~= "old" then
vim.cmd.wincmd("p")
end
end
return M