diff --git a/lua/dante/callbacks.lua b/lua/dante/callbacks.lua new file mode 100644 index 0000000..114c0ba --- /dev/null +++ b/lua/dante/callbacks.lua @@ -0,0 +1,63 @@ +local M = {} + +---Callback function to handle the completion of a chat request. +---@param res table +---@param opts Options +---@return function +M.on_chat_completion = function(res, opts) + return function(obj) + local finish_reason = obj.choices[1].finish_reason + local content = obj.choices[1].message.content + if finish_reason == "stop" then + local lines = vim.split(content, "\n", { plain = true, trimempty = false }) + vim.api.nvim_buf_set_lines(res.buf, -2, -1, true, lines) + if opts.verbose and obj.usage then + vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) + end + vim.notify("Done.", vim.log.levels.INFO) + else + vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) + end + end +end + +---Callback function to handle the completion chunk of a chat request. +---@param res table +---@param opts Options +---@return function +M.on_chat_completion_chunk = function(res, opts) + return function(obj) + local finish_reason = obj.choices[1].finish_reason + local content = obj.choices[1].delta.content + if finish_reason == vim.NIL then + local lines = vim.split(content, "\n", { plain = true, trimempty = false }) + local last_line, last_column = require("dante.utils").last(res.buf) + vim.api.nvim_buf_set_text(res.buf, last_line, last_column, last_line, last_column, lines) + if opts.verbose and obj.usage then + vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) + end + elseif finish_reason == "stop" then + vim.notify("Done.", vim.log.levels.INFO) + else + vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) + end + end +end + +---Callback function to handle the exit of a chat request. +---@param res table +---@param req table +---@param opts Options +---@param after_lines string[] +---@return function +M.on_exit = function(res, req, opts, after_lines) + return function() + vim.api.nvim_buf_set_lines(res.buf, -1, -1, true, after_lines) + vim.api.nvim_set_current_win(res.win) + vim.cmd("diffthis") + vim.api.nvim_set_current_win(req.win) + vim.cmd("diffthis") + end +end + +return M diff --git a/lua/dante/init.lua b/lua/dante/init.lua index 4dc147b..e38e31d 100644 --- a/lua/dante/init.lua +++ b/lua/dante/init.lua @@ -1,168 +1,112 @@ local ai = require("ai") +local utils = require("dante.utils") +local callbacks = require("dante.callbacks") local dante = {} ----Setup global options for Dante ----@param options Options: global optoins for Dante. +--- Setup global options for Dante +---@param options Options: global options for Dante function dante.setup(options) require("dante.config").setup(options) end ----Replace all placeholders in the content string ----@param content string: The content string to be formatted ----@param start_line integer: The start line of the selected text ----@param end_line integer: The end line of the selected text ----@param buf integer: The buffer number of the selected text ----@return string: The formatted content string -function dante.format(content, start_line, end_line, buf) - local result = "" - local last_end = 1 - - -- Function to get the replacement for a placeholder - local function get_replacement(placeholder) - if placeholder == "{{SELECTED_LINES}}" then - local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) - return table.concat(range_lines, "\n") - elseif placeholder == "{{NOW}}" then - return os.date("Today is %a, %d %b %Y %H:%M:%S %z") - -- Add other placeholders here... - else - -- If not recognized, keep the original placeholder - vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN) - return placeholder - end - end +--- Setup the UI for Dante +---@param opts Options: The options for Dante +local function setup_ui(opts) + -- Diff options + local diff = { + "internal", + "filler", + "closeoff", + "followwrap", + "iblank", + } + vim.cmd("set diffopt=" .. table.concat(diff, ",")) + + -- Request + local req = { + name = vim.api.nvim_buf_get_name(0), + buf = vim.api.nvim_get_current_buf(), + win = vim.api.nvim_get_current_win(), + } + + local buf_opts = { + ft = vim.bo.filetype, + } + + local win_opts = { + wrap = vim.wo.wrap, + lbr = vim.wo.linebreak, + bri = vim.wo.breakindent, + } - -- Find and replace all placeholders - for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do - local placeholder = content:sub(placeholder_start, placeholder_end - 1) - result = result .. content:sub(last_end, placeholder_start - 1) - result = result .. get_replacement(placeholder) - last_end = placeholder_end + -- Response + local res = { + buf = vim.api.nvim_create_buf(false, true), + name = utils.generate_buf_name(), + } + vim.api.nvim_buf_set_name(res.buf, res.name) + for opt, value in pairs(buf_opts) do + vim.api.nvim_set_option_value(opt, value, { buf = res.buf }) end - -- Append any remaining content after the last placeholder - result = result .. content:sub(last_end) - - return result -end - ----Get the last line, column and line count in the chat buffer ----Thanks to Oli Morris for this function ----https://github.com/olimorris/codecompanion.nvim/blob/ ---- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5 ---- /lua/codecompanion/strategies/chat.lua#L987 ----@param buf integer: The buffer number ----@return integer: number of the last line ----@return integer: number of columns in the last line -local function last(buf) - local line_count = vim.api.nvim_buf_line_count(buf) - local last_line = line_count - 1 - if last_line < 0 then - return 0, 0 - end - local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false) - if not last_line_content or #last_line_content == 0 then - return last_line, 0 + -- TODO: do not open the window if overlay | vertical | horizontal + res.win = vim.api.nvim_open_win(res.buf, true, { split = "right", win = req.win }) + for opt, value in pairs(win_opts) do + vim.api.nvim_set_option_value(opt, value, { win = res.win }) end - local last_column = #last_line_content[1] - return last_line, last_column + + return req, res end ----Represents a chat completion request to be sent to the model. ----Reference: https://platform.openai.com/docs/api-reference/chat/create +--- Represents a chat completion request to be sent to the model. +--- Reference: https://platform.openai.com/docs/api-reference/chat/create ----@alias RequestObject table (already defined in ai.nvim) ----Run Dante with the given preset ----@param preset_key string: One of the preset from options +--- Run Dante with the given preset +---@param preset_key string: One of the presets from options ---@param start_line integer: The start line of the selected text ---@param end_line integer: The end line of the selected text ---@return integer: Job id of the running job (curl request) function dante.main(preset_key, start_line, end_line) + local opts = require("dante.config").options + -- LLM Client - local options = require("dante.config").options - local preset = vim.deepcopy(options.presets[preset_key]) + local preset = vim.deepcopy(opts.presets[preset_key]) local client = ai.Client:new(preset.client.base_url, preset.client.api_key) -- Format the messages content (e.g. substitute selected text) for _, message in ipairs(preset.request.messages) do - message.content = dante.format(message.content, start_line, end_line) + message.content = utils.format(message.content, start_line, end_line) end - -- Request UI - local req_win = vim.api.nvim_get_current_win() - local req_buf = vim.api.nvim_get_current_buf() - local filetype = vim.bo.filetype - local wrap = vim.wo.wrap - local linebreak = vim.wo.linebreak - local breakindent = vim.wo.breakindent - vim.cmd("set diffopt=internal,filler,closeoff,followwrap,iblank") - - -- Generate a unique buffer name - local current_time = os.time() - local hex_time = string.format("%x", current_time) - local res_buf_name = "[Dante] " .. hex_time - - -- Response - vim.cmd("vsplit") - local res_win = vim.api.nvim_get_current_win() - local res_buf = vim.api.nvim_create_buf(false, true) - vim.api.nvim_win_set_buf(res_win, res_buf) - vim.api.nvim_buf_set_name(res_buf, res_buf_name) - vim.api.nvim_set_option_value("filetype", filetype, { buf = res_buf }) - vim.api.nvim_set_option_value("wrap", wrap, { win = res_win }) - vim.api.nvim_set_option_value("linebreak", linebreak, { win = res_win }) - vim.api.nvim_set_option_value("breakindent", breakindent, { win = res_win }) + local req, res = setup_ui(opts) -- Partition request buffer - local before_lines = vim.api.nvim_buf_get_lines(req_buf, 0, start_line - 1, true) - local after_lines = vim.api.nvim_buf_get_lines(req_buf, end_line, -1, true) + local before_lines = vim.api.nvim_buf_get_lines(req.buf, 0, start_line - 1, true) + local after_lines = vim.api.nvim_buf_get_lines(req.buf, end_line, -1, true) -- Add line before the response - vim.api.nvim_buf_set_lines(res_buf, 0, 0, true, before_lines) - vim.api.nvim_win_set_cursor(res_win, { start_line, 0 }) - - local function on_chat_completion(obj) - local finish_reason = obj.choices[1].finish_reason - local content = obj.choices[1].message.content - if finish_reason == "stop" then - local lines = vim.split(content, "\n", { plain = true, trimempty = false }) - vim.api.nvim_buf_set_lines(res_buf, -2, -1, true, lines) - if options.verbose and obj.usage then - vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) - end - vim.notify("Done.", vim.log.levels.INFO) - else - vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) - end - end - - local function on_chat_completion_chunk(obj) - local finish_reason = obj.choices[1].finish_reason - local content = obj.choices[1].delta.content - if finish_reason == vim.NIL then - local lines = vim.split(content, "\n", { plain = true, trimempty = false }) - local last_line, last_column = last(res_buf) - vim.api.nvim_buf_set_text(res_buf, last_line, last_column, last_line, last_column, lines) - if options.verbose and obj.usage then - vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) - end - elseif finish_reason == "stop" then - vim.notify("Done.", vim.log.levels.INFO) - else - vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) - end - end - - local function on_exit() - vim.api.nvim_buf_set_lines(res_buf, -1, -1, true, after_lines) - vim.api.nvim_set_current_win(res_win) - vim.cmd("diffthis") - vim.api.nvim_set_current_win(req_win) - vim.cmd("diffthis") - end - - return client:chat_completion_create(preset.request, on_chat_completion, on_chat_completion_chunk, nil, nil, on_exit) + vim.api.nvim_buf_set_lines(res.buf, 0, 0, true, before_lines) + + -- NOTE: maybe remove this + -- vim.api.nvim_win_set_cursor(res_win, { start_line, 0 }) + + local on_chat_completion = callbacks.on_chat_completion(res, opts) + local on_chat_completion_chunk = callbacks.on_chat_completion_chunk(res, opts) + + local on_stdout = nil -- use ai.nvim on_stdout + local on_stderr = nil -- use ai.nvim on_stderr + local on_exit = callbacks.on_exit(res, req, opts, after_lines) + + return client:chat_completion_create( + preset.request, + on_chat_completion, + on_chat_completion_chunk, + on_stdout, + on_stderr, + on_exit + ) end return dante diff --git a/lua/dante/utils.lua b/lua/dante/utils.lua new file mode 100644 index 0000000..cc32eea --- /dev/null +++ b/lua/dante/utils.lua @@ -0,0 +1,80 @@ +local M = {} + +---Get the last line, column and line count in the chat buffer +---Thanks to Oli Morris for this function +---https://github.com/olimorris/codecompanion.nvim/blob/ +--- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5 +--- /lua/codecompanion/strategies/chat.lua#L987 +---@param buf integer: The buffer number +---@return integer: number of the last line +---@return integer: number of columns in the last line +M.last = function(buf) + local line_count = vim.api.nvim_buf_line_count(buf) + local last_line = line_count - 1 + if last_line < 0 then + return 0, 0 + end + local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false) + if not last_line_content or #last_line_content == 0 then + return last_line, 0 + end + local last_column = #last_line_content[1] + return last_line, last_column +end + +---Generate a unique buffer name. +---It uses the current time in hexadecimal format prefixed with "[Dante] ". +---@return string: The generated buffer name +M.generate_buf_name = function() + -- Generate a unique buffer name + local current_time = os.time() + local hex_time = string.format("%x", current_time) + local res_buf_name = "[Dante] " .. hex_time + return res_buf_name +end + +---Replace all placeholders in the content string +---This is the core function that enable the user specify how to add context in LLM requests. +---Supported place holders are: +--- - `{{SELECTED_LINES}}`: The selected lines in the editor from VISUAL mode +--- - `{{NOW}}`: The current date and time +---@param content string: The content string to be formatted +---@param start_line integer: The start line of the selected text +---@param end_line integer: The end line of the selected text +---@return string: The formatted content string +M.format = function(content, start_line, end_line) + local result = "" + local last_end = 1 + + ---Function to get the replacement for a placeholder + ---@param placeholder string: The placeholder to be replaced + ---@return string + local function get_replacement(placeholder) + if placeholder == "{{SELECTED_LINES}}" then + local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) + return table.concat(range_lines, "\n") + elseif placeholder == "{{NOW}}" then + return tostring(os.date("Today is %a, %d %b %Y %H:%M:%S %z")) + -- Add other placeholders here... + else + -- If not recognized, keep the original placeholder + vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN) + return placeholder + end + end + + -- Find and replace all placeholders + for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do + ---@diagnostic disable-next-line: param-type-mismatch + local placeholder = content:sub(placeholder_start, placeholder_end - 1) + result = result .. content:sub(last_end, placeholder_start - 1) + result = result .. get_replacement(placeholder) + last_end = placeholder_end + end + + -- Append any remaining content after the last placeholder + result = result .. content:sub(last_end) + return result +end + +return M