From 1576e003301f61b016e1a2ad115d76764dc2005b Mon Sep 17 00:00:00 2001
From: Jack Tysoe
Date: Thu, 21 Dec 2023 23:04:14 +0000
Subject: [PATCH] feat(plugins): ai-prompt-decorator plugin

---
 .github/labeler.yml                           |   4 +
 .../kong/add-ai-prompt-decorator-plugin.yml   |   3 +
 kong-3.6.0-0.rockspec                         |   3 +
 kong/plugins/ai-prompt-decorator/access.lua   |  42 ++++
 kong/plugins/ai-prompt-decorator/handler.lua  |  67 +++++++
 kong/plugins/ai-prompt-decorator/schema.lua   |  80 ++++++++
 .../40-ai-prompt-decorator/00-config_spec.lua |  52 +++++
 .../40-ai-prompt-decorator/01-unit_spec.lua   | 184 ++++++++++++++++++
 .../02-integration_spec.lua                   | 123 ++++++++++++
 9 files changed, 558 insertions(+)
 create mode 100644 changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml
 create mode 100644 kong/plugins/ai-prompt-decorator/access.lua
 create mode 100644 kong/plugins/ai-prompt-decorator/handler.lua
 create mode 100644 kong/plugins/ai-prompt-decorator/schema.lua
 create mode 100644 spec/03-plugins/40-ai-prompt-decorator/00-config_spec.lua
 create mode 100644 spec/03-plugins/40-ai-prompt-decorator/01-unit_spec.lua
 create mode 100644 spec/03-plugins/40-ai-prompt-decorator/02-integration_spec.lua

diff --git a/.github/labeler.yml b/.github/labeler.yml
index d75a21fa48a..cccbc1b2da9 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -86,6 +86,10 @@ plugins/acme:
 - changed-files:
   - any-glob-to-any-file: kong/plugins/acme/**/*

+plugins/ai-prompt-decorator:
+- changed-files:
+  - any-glob-to-any-file: kong/plugins/ai-prompt-decorator/**/*
+
 plugins/aws-lambda:
 - changed-files:
   - any-glob-to-any-file: kong/plugins/aws-lambda/**/*
diff --git a/changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml b/changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml
new file mode 100644
index 00000000000..1e8db2b788e
--- /dev/null
+++ b/changelog/unreleased/kong/add-ai-prompt-decorator-plugin.yml
@@ -0,0 +1,3 @@
+message: Introduced the new **AI Prompt Decorator** plugin that allows Kong administrators to inject chat messages into LLM consumer prompts.
+type: feature
+scope: Plugin
\ No newline at end of file
diff --git a/kong-3.6.0-0.rockspec b/kong-3.6.0-0.rockspec
index 4e07f3823b0..5f4911ebbd4 100644
--- a/kong-3.6.0-0.rockspec
+++ b/kong-3.6.0-0.rockspec
@@ -550,6 +550,9 @@
     ["kong.plugins.opentelemetry.proto"] = "kong/plugins/opentelemetry/proto.lua",
     ["kong.plugins.opentelemetry.otlp"] = "kong/plugins/opentelemetry/otlp.lua",

+    ["kong.plugins.ai-prompt-decorator.handler"] = "kong/plugins/ai-prompt-decorator/handler.lua",
+    ["kong.plugins.ai-prompt-decorator.schema"] = "kong/plugins/ai-prompt-decorator/schema.lua",
+
     ["kong.vaults.env"] = "kong/vaults/env/init.lua",
     ["kong.vaults.env.schema"] = "kong/vaults/env/schema.lua",

diff --git a/kong/plugins/ai-prompt-decorator/access.lua b/kong/plugins/ai-prompt-decorator/access.lua
new file mode 100644
index 00000000000..e43d1417cce
--- /dev/null
+++ b/kong/plugins/ai-prompt-decorator/access.lua
local _M = {}

-- imports
local table_insert = table.insert
--

-- builds one chat message record; only the "v1" chat format is supported for now
local function to_chat_prompt(version, role, content)
  if version == "v1" then
    return { role = role, content = content }
  else
    return nil
  end
end

function _M.execute(request, conf)
  -- 1. add in-order to the head of the chat
  if conf.prompts.prepend and #conf.prompts.prepend > 0 then
    for i, v in ipairs(conf.prompts.prepend) do
      table_insert(request.messages, i, to_chat_prompt("v1", v.role, v.content))
    end
  end

  -- 2. add in-order to the tail of the chat
  if conf.prompts.append and #conf.prompts.append > 0 then
    for _, v in ipairs(conf.prompts.append) do
      table_insert(request.messages, to_chat_prompt("v1", v.role, v.content))
    end
  end

  return nil, nil
end

return _M
diff --git a/kong/plugins/ai-prompt-decorator/handler.lua b/kong/plugins/ai-prompt-decorator/handler.lua
new file mode 100644
index 00000000000..3a368f6c5f3
--- /dev/null
+++ b/kong/plugins/ai-prompt-decorator/handler.lua
local _M = {}

-- imports
local kong_meta = require "kong.meta"
local access_handler = require("kong.plugins.ai-prompt-decorator.access")
--

_M.PRIORITY = 772
_M.VERSION = kong_meta.version


local function do_bad_request(msg)
  kong.log.warn(msg)
  kong.response.exit(400, { error = true, message = msg })
end


function _M:access(conf)
  kong.log.debug("IN: ai-prompt-decorator/access")
  kong.service.request.enable_buffering()
  kong.ctx.shared.prompt_decorated = true

  -- if plugin ordering was altered, receive the already-decorated request
  local request, err
  if not kong.ctx.shared.replacement_request then
    request, err = kong.request.get_body("application/json")

    if err then
      do_bad_request("ai-prompt-decorator only supports application/json requests")
    end
  else
    request = kong.ctx.shared.replacement_request
  end

  if not request.messages or #request.messages < 1 then
    do_bad_request("ai-prompt-decorator only supports llm/chat type requests")
  end

  -- run the access handler to decorate the messages[] block
  local code, err_msg = access_handler.execute(request, conf)
  if err_msg then
    -- don't run header_filter and body_filter from the ai-proxy plugin
    kong.ctx.shared.skip_response_transformer = true

    if code == 500 then kong.log.err(err_msg) end
    kong.response.exit(code, err_msg)
  end

  -- stash the result for parsing later (in ai-proxy)
  kong.ctx.shared.replacement_request = request

  -- all good
  kong.log.debug("OUT: ai-prompt-decorator/access")
end


return _M
diff --git a/kong/plugins/ai-prompt-decorator/schema.lua b/kong/plugins/ai-prompt-decorator/schema.lua
new file mode 100644
index 00000000000..86bb8e9dd83
--- /dev/null
+++ b/kong/plugins/ai-prompt-decorator/schema.lua
local typedefs = require "kong.db.schema.typedefs"

local prompt_record = {
  type = "record",
  required = false,
  fields = {
    { role = { type = "string", required = true, one_of = { "system", "assistant", "user" }, default = "system" }},
    { content = { type = "string", required = true } },
  }
}

local prompts_record = {
  type = "record",
  required = false,
  fields = {
    { prepend = {
      type = "array",
      description = [[Insert chat messages at the beginning of the chat message array.
                      This array preserves exact order when adding messages.]],
      elements = prompt_record,
      required = false,
    }},
    { append = {
      type = "array",
      description = [[Insert chat messages at the end of the chat message array.
                      This array preserves exact order when adding messages.]],
      elements = prompt_record,
      required = false,
    }},
  }
}

return {
  name = "ai-prompt-decorator",
  fields = {
    { protocols = typedefs.protocols_http },
    { config = {
      type = "record",
      fields = {
        { prompts = prompts_record },
      },
    }},
  },
  entity_checks = {
    {
      custom_entity_check = {
        field_sources = { "config" },
        fn = function(entity)
          local config = entity.config

          -- at least one of prompts.prepend[] or prompts.append[] must contain an entry
          if config and config.prompts ~= ngx.null then
            local head_prompts_set = (config.prompts.prepend ~= ngx.null) and (#config.prompts.prepend > 0)
            local tail_prompts_set = (config.prompts.append ~= ngx.null) and (#config.prompts.append > 0)

            if (not head_prompts_set) and (not tail_prompts_set) then
              return nil, "must set one array item in either [prompts.prepend] or [prompts.append]"
            end

          else
            return nil, "must specify one or more [prompts.prepend] or [prompts.append] to add to requests"
          end

          return true
        end
      }
    }
  }
}
diff --git a/spec/03-plugins/40-ai-prompt-decorator/00-config_spec.lua b/spec/03-plugins/40-ai-prompt-decorator/00-config_spec.lua
new file mode 100644
index 00000000000..a64728074f5
--- /dev/null
+++ b/spec/03-plugins/40-ai-prompt-decorator/00-config_spec.lua
local PLUGIN_NAME = "ai-prompt-decorator"


-- helper function to validate data against a schema
local validate do
  local validate_entity = require("spec.helpers").validate_plugin_config_schema
  local plugin_schema = require("kong.plugins."..PLUGIN_NAME..".schema")

  function validate(data)
    return validate_entity(data, plugin_schema)
  end
end

describe(PLUGIN_NAME .. ": (schema)", function()
  it("won't allow an empty config object", function()
    local config = {}

    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal("must specify one or more [prompts.prepend] or [prompts.append] to add to requests", err["@entity"][1])
  end)

  it("won't allow both head and tail to be unset", function()
    local config = {
      prompts = {},
    }

    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal("must set one array item in either [prompts.prepend] or [prompts.append]", err["@entity"][1])
  end)

  it("won't allow both prepend and append to be empty arrays", function()
    local config = {
      prompts = {
        prepend = {},
        append = {},
      },
    }

    local ok, err = validate(config)

    assert.is_falsy(ok)
    assert.not_nil(err)
    assert.equal("must set one array item in either [prompts.prepend] or [prompts.append]", err["@entity"][1])
  end)
end)
diff --git a/spec/03-plugins/40-ai-prompt-decorator/01-unit_spec.lua b/spec/03-plugins/40-ai-prompt-decorator/01-unit_spec.lua
new file mode 100644
index 00000000000..e314bba2769
--- /dev/null
+++ b/spec/03-plugins/40-ai-prompt-decorator/01-unit_spec.lua
local PLUGIN_NAME = "ai-prompt-decorator"

-- imports
local access_handler = require("kong.plugins.ai-prompt-decorator.access")
--


local function deepcopy(o, seen)
  seen = seen or {}
  if o == nil then return nil end
  if seen[o] then return seen[o] end

  local no
  if type(o) == 'table' then
    no = {}
    seen[o] = no

    for k, v in next, o, nil do
      no[deepcopy(k, seen)] = deepcopy(v, seen)
    end
    setmetatable(no, deepcopy(getmetatable(o), seen))
  else -- number, string, boolean, etc
    no = o
  end
  return no
end


local general_chat_request = {
  messages = {
    [1] = {
      role = "system",
      content = "You are a mathematician."
    },
    [2] = {
      role = "user",
      content = "What is 1 + 1?"
    },
    [3] = {
      role = "assistant",
      content = "The answer is 2."
    },
    [4] = {
      role = "user",
      content = "Now double it."
    },
  },
}

local injector_conf_prepend = {
  prompts = {
    prepend = {
      [1] = {
        role = "system",
        content = "Give me answers in French language."
      },
      [2] = {
        role = "user",
        content = "Consider you are a mathematician."
      },
      [3] = {
        role = "assistant",
        content = "Okay I am a mathematician. What is your maths question?"
      },
    },
  },
}

local injector_conf_append = {
  prompts = {
    append = {
      [1] = {
        role = "system",
        content = "Give me answers in French language."
      },
      [2] = {
        role = "system",
        content = "Give me the answer in JSON format."
      },
    },
  },
}

local injector_conf_both = {
  prompts = {
    prepend = {
      [1] = {
        role = "system",
        content = "Give me answers in French language."
      },
      [2] = {
        role = "user",
        content = "Consider you are a mathematician."
      },
      [3] = {
        role = "assistant",
        content = "Okay I am a mathematician. What is your maths question?"
      },
    },
    append = {
      [1] = {
        role = "system",
        content = "Give me answers in French language."
      },
      [2] = {
        role = "system",
        content = "Give me the answer in JSON format."
      },
    },
  },
}


describe(PLUGIN_NAME .. 
": (unit)", function() + + describe("chat v1 operations", function() + + + it("adds messages to the start of the array", function() + local request_copy = deepcopy(general_chat_request) + local expected_request_copy = deepcopy(general_chat_request) + + -- combine the tables manually, and check the code does the same + table.insert(expected_request_copy.messages, 1, injector_conf_prepend.prompts.prepend[1]) + table.insert(expected_request_copy.messages, 2, injector_conf_prepend.prompts.prepend[2]) + table.insert(expected_request_copy.messages, 3, injector_conf_prepend.prompts.prepend[3]) + + local code, err = access_handler.execute(request_copy, injector_conf_prepend) + + assert.is_nil(code) + assert.is_nil(err) + assert.same(expected_request_copy, request_copy) + end) + + it("adds messages to the end of the array", function() + local request_copy = deepcopy(general_chat_request) + local expected_request_copy = deepcopy(general_chat_request) + + -- combine the tables manually, and check the code does the same + table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_append.prompts.append[1]) + table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_append.prompts.append[2]) + + local code, err = access_handler.execute(request_copy, injector_conf_append) + + assert.is_nil(code) + assert.is_nil(err) + assert.same(expected_request_copy, request_copy) + end) + + it("adds messages to the start and the end of the array", function() + local request_copy = deepcopy(general_chat_request) + local expected_request_copy = deepcopy(general_chat_request) + + -- combine the tables manually, and check the code does the same + table.insert(expected_request_copy.messages, 1, injector_conf_both.prompts.prepend[1]) + table.insert(expected_request_copy.messages, 2, injector_conf_both.prompts.prepend[2]) + table.insert(expected_request_copy.messages, 3, injector_conf_both.prompts.prepend[3]) + table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_both.prompts.append[1]) + table.insert(expected_request_copy.messages, #expected_request_copy.messages + 1, injector_conf_both.prompts.append[2]) + + local code, err = access_handler.execute(request_copy, injector_conf_both) + + assert.is_nil(code) + assert.is_nil(err) + assert.same(expected_request_copy, request_copy) + end) + + + end) + +end) diff --git a/spec/03-plugins/40-ai-prompt-decorator/02-integration_spec.lua b/spec/03-plugins/40-ai-prompt-decorator/02-integration_spec.lua new file mode 100644 index 00000000000..4937f01ce2a --- /dev/null +++ b/spec/03-plugins/40-ai-prompt-decorator/02-integration_spec.lua @@ -0,0 +1,123 @@ +local helpers = require "spec.helpers" +local cjson = require "cjson" + +local PLUGIN_NAME = "ai-prompt-decorator" + + +for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then + describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function() + local client + + lazy_setup(function() + + local bp = helpers.get_db_utils(strategy == "off" and "postgres" or strategy, nil, { PLUGIN_NAME }) + + -- Inject a test route. No need to create a service, there is a default + -- service which will echo the request. 
      local route1 = bp.routes:insert({
        hosts = { "test1.com" },
      })
      -- add the plugin under test to the route we created
      bp.plugins:insert {
        name = PLUGIN_NAME,
        route = { id = route1.id },
        config = {
          prompts = {
            prepend = {
              [1] = {
                role = "system",
                content = "Prepend text 1 here.",
              },
              [2] = {
                role = "system",
                content = "Prepend text 2 here.",
              },
            },
            append = {
              [1] = {
                role = "system",
                content = "Append text 1 here.",
              },
              [2] = {
                role = "system",
                content = "Append text 2 here.",
              },
            },
          },
        },
      }

      -- start kong
      assert(helpers.start_kong({
        -- set the strategy
        database = strategy,
        -- use the custom test template to create a local mock server
        nginx_conf = "spec/fixtures/custom_nginx.template",
        -- make sure our plugin gets loaded
        plugins = "bundled," .. PLUGIN_NAME,
        -- write & load declarative config, only if 'strategy=off'
        declarative_config = strategy == "off" and helpers.make_yaml_file() or nil,
      }))
    end)

    lazy_teardown(function()
      helpers.stop_kong(nil, true)
    end)

    before_each(function()
      client = helpers.proxy_client()
    end)

    after_each(function()
      if client then client:close() end
    end)


    describe("request", function()
      it("sends in a non-chat message", function()
        local r = client:get("/request", {
          headers = {
            host = "test1.com",
            ["Content-Type"] = "application/json",
          },
          body = [[
            {
              "anything": [
                {
                  "random": "data"
                }
              ]
            }]],
          method = "POST",
        })

        local body = assert.res_status(400, r)
        local json = cjson.decode(body)

        assert.same({ error = true, message = "ai-prompt-decorator only supports llm/chat type requests" }, json)
      end)

      it("sends in an empty messages array", function()
        local r = client:get("/request", {
          headers = {
            host = "test1.com",
            ["Content-Type"] = "application/json",
          },
          body = [[
            {
              "messages": []
            }]],
          method = "POST",
        })

        local body = assert.res_status(400, r)
        local json = cjson.decode(body)

        assert.same({ error = true, message = "ai-prompt-decorator only supports llm/chat type requests" }, json)
      end)
    end)

  end)

end end
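
--
Reviewer note, not part of the patch: a minimal standalone sketch of the decoration
that access.lua performs, runnable with plain `lua` and no Kong dependencies. The
`decorate` helper below is a hypothetical stand-in for `access.execute(request, conf)`;
it mirrors the same prepend/append ordering:

  -- prepend[] goes in-order to the head of messages[]; append[] in-order to the tail
  local function decorate(request, conf)
    if conf.prompts.prepend then
      for i, v in ipairs(conf.prompts.prepend) do
        table.insert(request.messages, i, { role = v.role, content = v.content })
      end
    end
    if conf.prompts.append then
      for _, v in ipairs(conf.prompts.append) do
        table.insert(request.messages, { role = v.role, content = v.content })
      end
    end
    return request
  end

  local request = { messages = { { role = "user", content = "What is 1 + 1?" } } }
  local conf = {
    prompts = {
      prepend = { { role = "system", content = "Answer in French." } },
      append  = { { role = "system", content = "Respond in JSON format." } },
    },
  }

  decorate(request, conf)
  for i, m in ipairs(request.messages) do
    print(i, m.role, m.content)
  end
  -- prints:
  -- 1  system  Answer in French.
  -- 2  user    What is 1 + 1?
  -- 3  system  Respond in JSON format.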
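
--
Reviewer note, not part of the patch: once merged, the plugin can be enabled on a
route through the Admin API with a JSON body. The route name `my-llm-route` and the
Admin API address below are hypothetical, for illustration only:

  curl -X POST http://localhost:8001/routes/my-llm-route/plugins \
    --header "Content-Type: application/json" \
    --data '{
      "name": "ai-prompt-decorator",
      "config": {
        "prompts": {
          "prepend": [
            { "role": "system", "content": "You are a mathematician." }
          ],
          "append": [
            { "role": "system", "content": "Respond in JSON format." }
          ]
        }
      }
    }'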