Skip to content

Commit

Permalink
feat: add callback and utility functions for chat completion
Browse files Browse the repository at this point in the history
- Added `callbacks.lua` with functions to handle chat completion, completion chunks, and exit events.
- Refactored `init.lua` to use the new callback functions and utility functions.
- Added `utils.lua` with helper functions for buffer operations and placeholder replacements.
  • Loading branch information
S1M0N38 committed Sep 14, 2024
1 parent ead7680 commit fde4cc2
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 134 deletions.
63 changes: 63 additions & 0 deletions lua/dante/callbacks.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
local M = {}

---Callback function to handle the completion of a chat request.
---@param res table
---@param opts Options
---@return function
M.on_chat_completion = function(res, opts)
return function(obj)
local finish_reason = obj.choices[1].finish_reason
local content = obj.choices[1].message.content
if finish_reason == "stop" then
local lines = vim.split(content, "\n", { plain = true, trimempty = false })
vim.api.nvim_buf_set_lines(res.buf, -2, -1, true, lines)
if opts.verbose and obj.usage then
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO)
end
vim.notify("Done.", vim.log.levels.INFO)
else
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR)
end
end
end

---Callback function to handle the completion chunk of a chat request.
---@param res table
---@param opts Options
---@return function
M.on_chat_completion_chunk = function(res, opts)
return function(obj)
local finish_reason = obj.choices[1].finish_reason
local content = obj.choices[1].delta.content
if finish_reason == vim.NIL then
local lines = vim.split(content, "\n", { plain = true, trimempty = false })
local last_line, last_column = require("dante.utils").last(res.buf)
vim.api.nvim_buf_set_text(res.buf, last_line, last_column, last_line, last_column, lines)
if opts.verbose and obj.usage then
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO)
end
elseif finish_reason == "stop" then
vim.notify("Done.", vim.log.levels.INFO)
else
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR)
end
end
end

---Callback function to handle the exit of a chat request.
---@param res table
---@param req table
---@param opts Options
---@param after_lines string[]
---@return function
M.on_exit = function(res, req, opts, after_lines)
return function()
vim.api.nvim_buf_set_lines(res.buf, -1, -1, true, after_lines)
vim.api.nvim_set_current_win(res.win)
vim.cmd("diffthis")
vim.api.nvim_set_current_win(req.win)
vim.cmd("diffthis")
end
end

return M
212 changes: 78 additions & 134 deletions lua/dante/init.lua
Original file line number Diff line number Diff line change
@@ -1,168 +1,112 @@
local ai = require("ai")
local utils = require("dante.utils")
local callbacks = require("dante.callbacks")

local dante = {}

---Setup global options for Dante
---@param options Options: global optoins for Dante.
--- Setup global options for Dante
---@param options Options: global options for Dante
function dante.setup(options)
require("dante.config").setup(options)
end

---Replace all placeholders in the content string
---@param content string: The content string to be formatted
---@param start_line integer: The start line of the selected text
---@param end_line integer: The end line of the selected text
---@param buf integer: The buffer number of the selected text
---@return string: The formatted content string
function dante.format(content, start_line, end_line, buf)
local result = ""
local last_end = 1

-- Function to get the replacement for a placeholder
local function get_replacement(placeholder)
if placeholder == "{{SELECTED_LINES}}" then
local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false)
return table.concat(range_lines, "\n")
elseif placeholder == "{{NOW}}" then
return os.date("Today is %a, %d %b %Y %H:%M:%S %z")
-- Add other placeholders here...
else
-- If not recognized, keep the original placeholder
vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN)
return placeholder
end
end
--- Setup the UI for Dante
---@param opts Options: The options for Dante
local function setup_ui(opts)
-- Diff options
local diff = {
"internal",
"filler",
"closeoff",
"followwrap",
"iblank",
}
vim.cmd("set diffopt=" .. table.concat(diff, ","))

-- Request
local req = {
name = vim.api.nvim_buf_get_name(0),
buf = vim.api.nvim_get_current_buf(),
win = vim.api.nvim_get_current_win(),
}

local buf_opts = {
ft = vim.bo.filetype,
}

local win_opts = {
wrap = vim.wo.wrap,
lbr = vim.wo.linebreak,
bri = vim.wo.breakindent,
}

-- Find and replace all placeholders
for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do
local placeholder = content:sub(placeholder_start, placeholder_end - 1)
result = result .. content:sub(last_end, placeholder_start - 1)
result = result .. get_replacement(placeholder)
last_end = placeholder_end
-- Response
local res = {
buf = vim.api.nvim_create_buf(false, true),
name = utils.generate_buf_name(),
}
vim.api.nvim_buf_set_name(res.buf, res.name)
for opt, value in pairs(buf_opts) do
vim.api.nvim_set_option_value(opt, value, { buf = res.buf })
end

-- Append any remaining content after the last placeholder
result = result .. content:sub(last_end)

return result
end

---Get the last line, column and line count in the chat buffer
---Thanks to Oli Morris for this function
---https://github.com/olimorris/codecompanion.nvim/blob/
--- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5
--- /lua/codecompanion/strategies/chat.lua#L987
---@param buf integer: The buffer number
---@return integer: number of the last line
---@return integer: number of columns in the last line
local function last(buf)
local line_count = vim.api.nvim_buf_line_count(buf)
local last_line = line_count - 1
if last_line < 0 then
return 0, 0
end
local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false)
if not last_line_content or #last_line_content == 0 then
return last_line, 0
-- TODO: do not open the window if overlay | vertical | horizontal
res.win = vim.api.nvim_open_win(res.buf, true, { split = "right", win = req.win })
for opt, value in pairs(win_opts) do
vim.api.nvim_set_option_value(opt, value, { win = res.win })
end
local last_column = #last_line_content[1]
return last_line, last_column

return req, res
end

---Represents a chat completion request to be sent to the model.
---Reference: https://platform.openai.com/docs/api-reference/chat/create
--- Represents a chat completion request to be sent to the model.
--- Reference: https://platform.openai.com/docs/api-reference/chat/create
----@alias RequestObject table (already defined in ai.nvim)

---Run Dante with the given preset
---@param preset_key string: One of the preset from options
--- Run Dante with the given preset
---@param preset_key string: One of the presets from options
---@param start_line integer: The start line of the selected text
---@param end_line integer: The end line of the selected text
---@return integer: Job id of the running job (curl request)
function dante.main(preset_key, start_line, end_line)
local opts = require("dante.config").options

-- LLM Client
local options = require("dante.config").options
local preset = vim.deepcopy(options.presets[preset_key])
local preset = vim.deepcopy(opts.presets[preset_key])
local client = ai.Client:new(preset.client.base_url, preset.client.api_key)

-- Format the messages content (e.g. substitute selected text)
for _, message in ipairs(preset.request.messages) do
message.content = dante.format(message.content, start_line, end_line)
message.content = utils.format(message.content, start_line, end_line)
end

-- Request UI
local req_win = vim.api.nvim_get_current_win()
local req_buf = vim.api.nvim_get_current_buf()
local filetype = vim.bo.filetype
local wrap = vim.wo.wrap
local linebreak = vim.wo.linebreak
local breakindent = vim.wo.breakindent
vim.cmd("set diffopt=internal,filler,closeoff,followwrap,iblank")

-- Generate a unique buffer name
local current_time = os.time()
local hex_time = string.format("%x", current_time)
local res_buf_name = "[Dante] " .. hex_time

-- Response
vim.cmd("vsplit")
local res_win = vim.api.nvim_get_current_win()
local res_buf = vim.api.nvim_create_buf(false, true)
vim.api.nvim_win_set_buf(res_win, res_buf)
vim.api.nvim_buf_set_name(res_buf, res_buf_name)
vim.api.nvim_set_option_value("filetype", filetype, { buf = res_buf })
vim.api.nvim_set_option_value("wrap", wrap, { win = res_win })
vim.api.nvim_set_option_value("linebreak", linebreak, { win = res_win })
vim.api.nvim_set_option_value("breakindent", breakindent, { win = res_win })
local req, res = setup_ui(opts)

-- Partition request buffer
local before_lines = vim.api.nvim_buf_get_lines(req_buf, 0, start_line - 1, true)
local after_lines = vim.api.nvim_buf_get_lines(req_buf, end_line, -1, true)
local before_lines = vim.api.nvim_buf_get_lines(req.buf, 0, start_line - 1, true)
local after_lines = vim.api.nvim_buf_get_lines(req.buf, end_line, -1, true)

-- Add line before the response
vim.api.nvim_buf_set_lines(res_buf, 0, 0, true, before_lines)
vim.api.nvim_win_set_cursor(res_win, { start_line, 0 })

local function on_chat_completion(obj)
local finish_reason = obj.choices[1].finish_reason
local content = obj.choices[1].message.content
if finish_reason == "stop" then
local lines = vim.split(content, "\n", { plain = true, trimempty = false })
vim.api.nvim_buf_set_lines(res_buf, -2, -1, true, lines)
if options.verbose and obj.usage then
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO)
end
vim.notify("Done.", vim.log.levels.INFO)
else
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR)
end
end

local function on_chat_completion_chunk(obj)
local finish_reason = obj.choices[1].finish_reason
local content = obj.choices[1].delta.content
if finish_reason == vim.NIL then
local lines = vim.split(content, "\n", { plain = true, trimempty = false })
local last_line, last_column = last(res_buf)
vim.api.nvim_buf_set_text(res_buf, last_line, last_column, last_line, last_column, lines)
if options.verbose and obj.usage then
vim.notify("usage = " .. vim.inspect(obj.usage), vim.log.levels.INFO)
end
elseif finish_reason == "stop" then
vim.notify("Done.", vim.log.levels.INFO)
else
vim.notify("An error occured during text genereation.", vim.log.levels.ERROR)
end
end

local function on_exit()
vim.api.nvim_buf_set_lines(res_buf, -1, -1, true, after_lines)
vim.api.nvim_set_current_win(res_win)
vim.cmd("diffthis")
vim.api.nvim_set_current_win(req_win)
vim.cmd("diffthis")
end

return client:chat_completion_create(preset.request, on_chat_completion, on_chat_completion_chunk, nil, nil, on_exit)
vim.api.nvim_buf_set_lines(res.buf, 0, 0, true, before_lines)

-- NOTE: maybe remove this
-- vim.api.nvim_win_set_cursor(res_win, { start_line, 0 })

local on_chat_completion = callbacks.on_chat_completion(res, opts)
local on_chat_completion_chunk = callbacks.on_chat_completion_chunk(res, opts)

local on_stdout = nil -- use ai.nvim on_stdout
local on_stderr = nil -- use ai.nvim on_stderr
local on_exit = callbacks.on_exit(res, req, opts, after_lines)

return client:chat_completion_create(
preset.request,
on_chat_completion,
on_chat_completion_chunk,
on_stdout,
on_stderr,
on_exit
)
end

return dante
80 changes: 80 additions & 0 deletions lua/dante/utils.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
local M = {}

---Get the last line, column and line count in the chat buffer
---Thanks to Oli Morris for this function
---https://github.com/olimorris/codecompanion.nvim/blob/
--- f8db284e1197a8cc4235afa30dcc3e8d4f3f45a5
--- /lua/codecompanion/strategies/chat.lua#L987
---@param buf integer: The buffer number
---@return integer: number of the last line
---@return integer: number of columns in the last line
M.last = function(buf)
local line_count = vim.api.nvim_buf_line_count(buf)
local last_line = line_count - 1
if last_line < 0 then
return 0, 0
end
local last_line_content = vim.api.nvim_buf_get_lines(buf, -2, -1, false)
if not last_line_content or #last_line_content == 0 then
return last_line, 0
end
local last_column = #last_line_content[1]
return last_line, last_column
end

---Generate a unique buffer name.
---It uses the current time in hexadecimal format prefixed with "[Dante] ".
---@return string: The generated buffer name
M.generate_buf_name = function()
-- Generate a unique buffer name
local current_time = os.time()
local hex_time = string.format("%x", current_time)
local res_buf_name = "[Dante] " .. hex_time
return res_buf_name
end

---Replace all placeholders in the content string
---This is the core function that enable the user specify how to add context in LLM requests.
---Supported place holders are:
--- - `{{SELECTED_LINES}}`: The selected lines in the editor from VISUAL mode
--- - `{{NOW}}`: The current date and time
---@param content string: The content string to be formatted
---@param start_line integer: The start line of the selected text
---@param end_line integer: The end line of the selected text
---@return string: The formatted content string
M.format = function(content, start_line, end_line)
local result = ""
local last_end = 1

---Function to get the replacement for a placeholder
---@param placeholder string: The placeholder to be replaced
---@return string
local function get_replacement(placeholder)
if placeholder == "{{SELECTED_LINES}}" then
local range_lines = vim.api.nvim_buf_get_lines(0, start_line - 1, end_line, false)
return table.concat(range_lines, "\n")
elseif placeholder == "{{NOW}}" then
return tostring(os.date("Today is %a, %d %b %Y %H:%M:%S %z"))
-- Add other placeholders here...
else
-- If not recognized, keep the original placeholder
vim.notify("Unrecognized placeholder: " .. placeholder, vim.log.levels.WARN)
return placeholder
end
end

-- Find and replace all placeholders
for placeholder_start, placeholder_end in content:gmatch("(){{.-}}()") do
---@diagnostic disable-next-line: param-type-mismatch
local placeholder = content:sub(placeholder_start, placeholder_end - 1)
result = result .. content:sub(last_end, placeholder_start - 1)
result = result .. get_replacement(placeholder)
last_end = placeholder_end
end

-- Append any remaining content after the last placeholder
result = result .. content:sub(last_end)
return result
end

return M

0 comments on commit fde4cc2

Please sign in to comment.