From 9502bb4322d446560363af6a06aaabaf282678ab Mon Sep 17 00:00:00 2001 From: Shreemaan Abhishek Date: Mon, 30 Dec 2024 14:42:18 +0545 Subject: [PATCH] fix(ai-proxy): query params in override.endpoint not being sent to LLMs (#11863) --- apisix/plugins/ai-proxy/drivers/openai.lua | 12 ++- t/plugin/ai-proxy2.t | 87 ++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-proxy/drivers/openai.lua index cefc5a728eaa..af0bc97588d5 100644 --- a/apisix/plugins/ai-proxy/drivers/openai.lua +++ b/apisix/plugins/ai-proxy/drivers/openai.lua @@ -21,6 +21,7 @@ local http = require("resty.http") local url = require("socket.url") local pairs = pairs +local type = type -- globals local DEFAULT_HOST = "api.openai.com" @@ -54,6 +55,15 @@ function _M.request(conf, request_table, ctx) return nil, "failed to connect to LLM server: " .. err end + local query_params = conf.auth.query or {} + + if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then + local args_tab = core.string.decode_args(parsed_url.query) + if type(args_tab) == "table" then + core.table.merge(query_params, args_tab) + end + end + local path = (endpoint and parsed_url.path or DEFAULT_PATH) local headers = (conf.auth.header or {}) @@ -64,7 +74,7 @@ function _M.request(conf, request_table, ctx) keepalive = conf.keepalive, ssl_verify = conf.ssl_verify, path = path, - query = conf.auth.query + query = query_params } if conf.model.options then diff --git a/t/plugin/ai-proxy2.t b/t/plugin/ai-proxy2.t index cda3786b7e75..f372e4fbdd5a 100644 --- a/t/plugin/ai-proxy2.t +++ b/t/plugin/ai-proxy2.t @@ -67,6 +67,31 @@ add_block_preprocessor(sub { end + ngx.status = 200 + ngx.say("passed") + } + } + + + location /test/params/in/overridden/endpoint { + content_by_lua_block { + local json = require("cjson.safe") + local core = require("apisix.core") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + + local query_auth = ngx.req.get_uri_args()["api_key"] + ngx.log(ngx.INFO, "found query params: ", core.json.stably_encode(ngx.req.get_uri_args())) + + if query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + ngx.status = 200 ngx.say("passed") } @@ -253,3 +278,65 @@ passed POST /anything { "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } --- error_code: 401 + + + +=== TEST 7: query params in override.endpoint should be sent to LLM +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy": { + "auth": { + "query": { + "api_key": "apikey" + } + }, + "model": { + "provider": "openai", + "name": "gpt-35-turbo-instruct", + "options": { + "max_tokens": 512, + "temperature": 1.0 + } + }, + "override": { + "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes" + }, + "ssl_verify": false + } + }, + "upstream": { + "type": "roundrobin", + "nodes": { + "canbeanything.com": 1 + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 8: send request +--- request +POST /anything +{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- error_log +found query params: {"api_key":"apikey","some_query":"yes"} +--- response_body +passed