-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add callback and utility functions for chat completion
- Added `callbacks.lua` with functions to handle chat completion, completion chunks, and exit events. - Refactored `init.lua` to use the new callback functions and utility functions. - Added `utils.lua` with helper functions for buffer operations and placeholder replacements.
- Loading branch information
Showing
3 changed files
with
221 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
local M = {} | ||
|
||
---Callback function to handle the completion of a chat request. | ||
---@param res table | ||
---@param opts Options | ||
---@return function | ||
M.on_chat_completion = function(res, opts) | ||
return function(obj) | ||
local finish_reason = obj.choices[1].finish_reason | ||
local content = obj.choices[1].message.content | ||
if finish_reason == "stop" then | ||
local lines = vim.split(content, "\n", { plain = true, trimempty = false }) | ||
vim.api.nvim_buf_set_lines(res.buf, -2, -1, true, lines) | ||
if opts.verbose and obj.usage then | ||
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) | ||
end | ||
vim.notify("Done.", vim.log.levels.INFO) | ||
else | ||
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) | ||
end | ||
end | ||
end | ||
|
||
---Callback function to handle the completion chunk of a chat request. | ||
---@param res table | ||
---@param opts Options | ||
---@return function | ||
M.on_chat_completion_chunk = function(res, opts) | ||
return function(obj) | ||
local finish_reason = obj.choices[1].finish_reason | ||
local content = obj.choices[1].delta.content | ||
if finish_reason == vim.NIL then | ||
local lines = vim.split(content, "\n", { plain = true, trimempty = false }) | ||
local last_line, last_column = require("dante.utils").last(res.buf) | ||
vim.api.nvim_buf_set_text(res.buf, last_line, last_column, last_line, last_column, lines) | ||
if opts.verbose and obj.usage then | ||
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) | ||
end | ||
elseif finish_reason == "stop" then | ||
vim.notify("Done.", vim.log.levels.INFO) | ||
else | ||
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) | ||
end | ||
end | ||
end | ||
|
||
---Callback function to handle the exit of a chat request. | ||
---@param res table | ||
---@param req table | ||
---@param opts Options | ||
---@param after_lines string[] | ||
---@return function | ||
M.on_exit = function(res, req, opts, after_lines) | ||
return function() | ||
vim.api.nvim_buf_set_lines(res.buf, -1, -1, true, after_lines) | ||
vim.api.nvim_set_current_win(res.win) | ||
vim.cmd("diffthis") | ||
vim.api.nvim_set_current_win(req.win) | ||
vim.cmd("diffthis") | ||
end | ||
end | ||
|
||
return M |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,168 +1,112 @@ | ||
local ai = require("ai") | ||
local utils = require("dante.utils") | ||
local callbacks = require("dante.callbacks") | ||
|
||
local dante = {} | ||
|
||
---Setup global options for Dante | ||
---@param options Options: global optoins for Dante. | ||
--- Setup global options for Dante | ||
---@param options Options: global options for Dante | ||
function dante.setup(options) | ||
require("dante.config").setup(options) | ||
end | ||
|
||
---Replace all placeholders in the content string | ||
---@param content string: The content string to be formatted | ||
---@param start_line integer: The start line of the selected text | ||
---@param end_line integer: The end line of the selected text | ||
---@param buf integer: The buffer number of the selected text | ||
---@return string: The formatted content string | ||
function dante.format(content, start_line, end_line, buf) | ||
local result = "" | ||
local last_end = 1 | ||
|
||
-- Function to get the replacement for a placeholder | ||
local function get_replacement(placeholder) | ||
if placeholder == "{{SELECTED_LINES}}" then | ||
local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) | ||
return table.concat(range_lines, "\n") | ||
elseif placeholder == "{{NOW}}" then | ||
return os.date("Today is %a, %d %b %Y %H:%M:%S %z") | ||
-- Add other placeholders here... | ||
else | ||
-- If not recognized, keep the original placeholder | ||
vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN) | ||
return placeholder | ||
end | ||
end | ||
--- Setup the UI for Dante | ||
---@param opts Options: The options for Dante | ||
local function setup_ui(opts) | ||
-- Diff options | ||
local diff = { | ||
"internal", | ||
"filler", | ||
"closeoff", | ||
"followwrap", | ||
"iblank", | ||
} | ||
vim.cmd("set diffopt=" .. table.concat(diff, ",")) | ||
|
||
-- Request | ||
local req = { | ||
name = vim.api.nvim_buf_get_name(0), | ||
buf = vim.api.nvim_get_current_buf(), | ||
win = vim.api.nvim_get_current_win(), | ||
} | ||
|
||
local buf_opts = { | ||
ft = vim.bo.filetype, | ||
} | ||
|
||
local win_opts = { | ||
wrap = vim.wo.wrap, | ||
lbr = vim.wo.linebreak, | ||
bri = vim.wo.breakindent, | ||
} | ||
|
||
-- Find and replace all placeholders | ||
for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do | ||
local placeholder = content:sub(placeholder_start, placeholder_end - 1) | ||
result = result .. content:sub(last_end, placeholder_start - 1) | ||
result = result .. get_replacement(placeholder) | ||
last_end = placeholder_end | ||
-- Response | ||
local res = { | ||
buf = vim.api.nvim_create_buf(false, true), | ||
name = utils.generate_buf_name(), | ||
} | ||
vim.api.nvim_buf_set_name(res.buf, res.name) | ||
for opt, value in pairs(buf_opts) do | ||
vim.api.nvim_set_option_value(opt, value, { buf = res.buf }) | ||
end | ||
|
||
-- Append any remaining content after the last placeholder | ||
result = result .. content:sub(last_end) | ||
|
||
return result | ||
end | ||
|
||
---Get the last line, column and line count in the chat buffer | ||
---Thanks to Oli Morris for this function | ||
---https://github.com/olimorris/codecompanion.nvim/blob/ | ||
--- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5 | ||
--- /lua/codecompanion/strategies/chat.lua#L987 | ||
---@param buf integer: The buffer number | ||
---@return integer: number of the last line | ||
---@return integer: number of columns in the last line | ||
local function last(buf) | ||
local line_count = vim.api.nvim_buf_line_count(buf) | ||
local last_line = line_count - 1 | ||
if last_line < 0 then | ||
return 0, 0 | ||
end | ||
local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false) | ||
if not last_line_content or #last_line_content == 0 then | ||
return last_line, 0 | ||
-- TODO: do not open the window if overlay | vertical | horizontal | ||
res.win = vim.api.nvim_open_win(res.buf, true, { split = "right", win = req.win }) | ||
for opt, value in pairs(win_opts) do | ||
vim.api.nvim_set_option_value(opt, value, { win = res.win }) | ||
end | ||
local last_column = #last_line_content[1] | ||
return last_line, last_column | ||
|
||
return req, res | ||
end | ||
|
||
---Represents a chat completion request to be sent to the model. | ||
---Reference: https://platform.openai.com/docs/api-reference/chat/create | ||
--- Represents a chat completion request to be sent to the model. | ||
--- Reference: https://platform.openai.com/docs/api-reference/chat/create | ||
----@alias RequestObject table (already defined in ai.nvim) | ||
|
||
---Run Dante with the given preset | ||
---@param preset_key string: One of the preset from options | ||
--- Run Dante with the given preset | ||
---@param preset_key string: One of the presets from options | ||
---@param start_line integer: The start line of the selected text | ||
---@param end_line integer: The end line of the selected text | ||
---@return integer: Job id of the running job (curl request) | ||
function dante.main(preset_key, start_line, end_line) | ||
local opts = require("dante.config").options | ||
|
||
-- LLM Client | ||
local options = require("dante.config").options | ||
local preset = vim.deepcopy(options.presets[preset_key]) | ||
local preset = vim.deepcopy(opts.presets[preset_key]) | ||
local client = ai.Client:new(preset.client.base_url, preset.client.api_key) | ||
|
||
-- Format the messages content (e.g. substitute selected text) | ||
for _, message in ipairs(preset.request.messages) do | ||
message.content = dante.format(message.content, start_line, end_line) | ||
message.content = utils.format(message.content, start_line, end_line) | ||
end | ||
|
||
-- Request UI | ||
local req_win = vim.api.nvim_get_current_win() | ||
local req_buf = vim.api.nvim_get_current_buf() | ||
local filetype = vim.bo.filetype | ||
local wrap = vim.wo.wrap | ||
local linebreak = vim.wo.linebreak | ||
local breakindent = vim.wo.breakindent | ||
vim.cmd("set diffopt=internal,filler,closeoff,followwrap,iblank") | ||
|
||
-- Generate a unique buffer name | ||
local current_time = os.time() | ||
local hex_time = string.format("%x", current_time) | ||
local res_buf_name = "[Dante] " .. hex_time | ||
|
||
-- Response | ||
vim.cmd("vsplit") | ||
local res_win = vim.api.nvim_get_current_win() | ||
local res_buf = vim.api.nvim_create_buf(false, true) | ||
vim.api.nvim_win_set_buf(res_win, res_buf) | ||
vim.api.nvim_buf_set_name(res_buf, res_buf_name) | ||
vim.api.nvim_set_option_value("filetype", filetype, { buf = res_buf }) | ||
vim.api.nvim_set_option_value("wrap", wrap, { win = res_win }) | ||
vim.api.nvim_set_option_value("linebreak", linebreak, { win = res_win }) | ||
vim.api.nvim_set_option_value("breakindent", breakindent, { win = res_win }) | ||
local req, res = setup_ui(opts) | ||
|
||
-- Partition request buffer | ||
local before_lines = vim.api.nvim_buf_get_lines(req_buf, 0, start_line - 1, true) | ||
local after_lines = vim.api.nvim_buf_get_lines(req_buf, end_line, -1, true) | ||
local before_lines = vim.api.nvim_buf_get_lines(req.buf, 0, start_line - 1, true) | ||
local after_lines = vim.api.nvim_buf_get_lines(req.buf, end_line, -1, true) | ||
|
||
-- Add line before the response | ||
vim.api.nvim_buf_set_lines(res_buf, 0, 0, true, before_lines) | ||
vim.api.nvim_win_set_cursor(res_win, { start_line, 0 }) | ||
|
||
local function on_chat_completion(obj) | ||
local finish_reason = obj.choices[1].finish_reason | ||
local content = obj.choices[1].message.content | ||
if finish_reason == "stop" then | ||
local lines = vim.split(content, "\n", { plain = true, trimempty = false }) | ||
vim.api.nvim_buf_set_lines(res_buf, -2, -1, true, lines) | ||
if options.verbose and obj.usage then | ||
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) | ||
end | ||
vim.notify("Done.", vim.log.levels.INFO) | ||
else | ||
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) | ||
end | ||
end | ||
|
||
local function on_chat_completion_chunk(obj) | ||
local finish_reason = obj.choices[1].finish_reason | ||
local content = obj.choices[1].delta.content | ||
if finish_reason == vim.NIL then | ||
local lines = vim.split(content, "\n", { plain = true, trimempty = false }) | ||
local last_line, last_column = last(res_buf) | ||
vim.api.nvim_buf_set_text(res_buf, last_line, last_column, last_line, last_column, lines) | ||
if options.verbose and obj.usage then | ||
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO) | ||
end | ||
elseif finish_reason == "stop" then | ||
vim.notify("Done.", vim.log.levels.INFO) | ||
else | ||
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR) | ||
end | ||
end | ||
|
||
local function on_exit() | ||
vim.api.nvim_buf_set_lines(res_buf, -1, -1, true, after_lines) | ||
vim.api.nvim_set_current_win(res_win) | ||
vim.cmd("diffthis") | ||
vim.api.nvim_set_current_win(req_win) | ||
vim.cmd("diffthis") | ||
end | ||
|
||
return client:chat_completion_create(preset.request, on_chat_completion, on_chat_completion_chunk, nil, nil, on_exit) | ||
vim.api.nvim_buf_set_lines(res.buf, 0, 0, true, before_lines) | ||
|
||
-- NOTE: maybe remove this | ||
-- vim.api.nvim_win_set_cursor(res_win, { start_line, 0 }) | ||
|
||
local on_chat_completion = callbacks.on_chat_completion(res, opts) | ||
local on_chat_completion_chunk = callbacks.on_chat_completion_chunk(res, opts) | ||
|
||
local on_stdout = nil -- use ai.nvim on_stdout | ||
local on_stderr = nil -- use ai.nvim on_stderr | ||
local on_exit = callbacks.on_exit(res, req, opts, after_lines) | ||
|
||
return client:chat_completion_create( | ||
preset.request, | ||
on_chat_completion, | ||
on_chat_completion_chunk, | ||
on_stdout, | ||
on_stderr, | ||
on_exit | ||
) | ||
end | ||
|
||
return dante |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
local M = {} | ||
|
||
---Get the last line, column and line count in the chat buffer | ||
---Thanks to Oli Morris for this function | ||
---https://github.com/olimorris/codecompanion.nvim/blob/ | ||
--- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5 | ||
--- /lua/codecompanion/strategies/chat.lua#L987 | ||
---@param buf integer: The buffer number | ||
---@return integer: number of the last line | ||
---@return integer: number of columns in the last line | ||
M.last = function(buf) | ||
local line_count = vim.api.nvim_buf_line_count(buf) | ||
local last_line = line_count - 1 | ||
if last_line < 0 then | ||
return 0, 0 | ||
end | ||
local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false) | ||
if not last_line_content or #last_line_content == 0 then | ||
return last_line, 0 | ||
end | ||
local last_column = #last_line_content[1] | ||
return last_line, last_column | ||
end | ||
|
||
---Generate a unique buffer name. | ||
---It uses the current time in hexadecimal format prefixed with "[Dante] ". | ||
---@return string: The generated buffer name | ||
M.generate_buf_name = function() | ||
-- Generate a unique buffer name | ||
local current_time = os.time() | ||
local hex_time = string.format("%x", current_time) | ||
local res_buf_name = "[Dante] " .. hex_time | ||
return res_buf_name | ||
end | ||
|
||
---Replace all placeholders in the content string | ||
---This is the core function that enable the user specify how to add context in LLM requests. | ||
---Supported place holders are: | ||
--- - `{{SELECTED_LINES}}`: The selected lines in the editor from VISUAL mode | ||
--- - `{{NOW}}`: The current date and time | ||
---@param content string: The content string to be formatted | ||
---@param start_line integer: The start line of the selected text | ||
---@param end_line integer: The end line of the selected text | ||
---@return string: The formatted content string | ||
M.format = function(content, start_line, end_line) | ||
local result = "" | ||
local last_end = 1 | ||
|
||
---Function to get the replacement for a placeholder | ||
---@param placeholder string: The placeholder to be replaced | ||
---@return string | ||
local function get_replacement(placeholder) | ||
if placeholder == "{{SELECTED_LINES}}" then | ||
local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false) | ||
return table.concat(range_lines, "\n") | ||
elseif placeholder == "{{NOW}}" then | ||
return tostring(os.date("Today is %a, %d %b %Y %H:%M:%S %z")) | ||
-- Add other placeholders here... | ||
else | ||
-- If not recognized, keep the original placeholder | ||
vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN) | ||
return placeholder | ||
end | ||
end | ||
|
||
-- Find and replace all placeholders | ||
for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do | ||
---@diagnostic disable-next-line: param-type-mismatch | ||
local placeholder = content:sub(placeholder_start, placeholder_end - 1) | ||
result = result .. content:sub(last_end, placeholder_start - 1) | ||
result = result .. get_replacement(placeholder) | ||
last_end = placeholder_end | ||
end | ||
|
||
-- Append any remaining content after the last placeholder | ||
result = result .. content:sub(last_end) | ||
return result | ||
end | ||
|
||
return M |