This is an automated email from the ASF dual-hosted git repository.

shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push: new 53e5b0227 refactor(ai): ai-proxy and ai-proxy-multi (#12030) 53e5b0227 is described below commit 53e5b02270e307942f93fc2f94e1dc1a2fed9cd6 Author: Shreemaan Abhishek <shreemaan.abhis...@gmail.com> AuthorDate: Tue Mar 11 16:11:50 2025 +0545 refactor(ai): ai-proxy and ai-proxy-multi (#12030) --- apisix/cli/config.lua | 4 +- apisix/plugins/ai-aws-content-moderation.lua | 2 +- apisix/plugins/ai-drivers/openai-base.lua | 128 +++++++++++++++++++++++++-- apisix/plugins/ai-drivers/schema.lua | 44 +++++++++ apisix/plugins/ai-proxy-multi.lua | 123 ++++++++++--------------- apisix/plugins/ai-proxy.lua | 36 +++----- apisix/plugins/ai-proxy/base.lua | 95 ++++++-------------- apisix/plugins/ai-proxy/schema.lua | 73 +++++++-------- t/admin/plugins.t | 4 +- t/plugin/ai-proxy-multi.balancer.t | 120 +++---------------------- t/plugin/ai-proxy-multi.openai-compatible.t | 34 ++----- t/plugin/ai-proxy-multi.t | 57 +++++++----- t/plugin/ai-proxy-multi2.t | 93 ++++--------------- t/plugin/ai-proxy.openai-compatible.t | 57 ++++-------- t/plugin/ai-proxy.t | 78 +++++++--------- t/plugin/ai-proxy2.t | 47 ++++------ 16 files changed, 432 insertions(+), 563 deletions(-) diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua index be7694130..56af978c2 100644 --- a/apisix/cli/config.lua +++ b/apisix/cli/config.lua @@ -219,13 +219,13 @@ local _M = { "ai-prompt-decorator", "ai-prompt-guard", "ai-rag", + "ai-proxy-multi", + "ai-proxy", "ai-aws-content-moderation", "proxy-mirror", "proxy-rewrite", "workflow", "api-breaker", - "ai-proxy", - "ai-proxy-multi", "limit-conn", "limit-count", "limit-req", diff --git a/apisix/plugins/ai-aws-content-moderation.lua b/apisix/plugins/ai-aws-content-moderation.lua index c7b54ed4e..e5a870bd3 100644 --- a/apisix/plugins/ai-aws-content-moderation.lua +++ b/apisix/plugins/ai-aws-content-moderation.lua @@ -72,7 +72,7 @@ local schema = { local _M = { version = 0.1, - priority = 1040, -- TODO: might change + priority = 1050, name = "ai-aws-content-moderation", schema = schema, } diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua index a9eb31059..4f0b38afe 100644 --- a/apisix/plugins/ai-drivers/openai-base.lua +++ b/apisix/plugins/ai-drivers/openai-base.lua @@ -20,12 +20,20 @@ local mt = { __index = _M } +local CONTENT_TYPE_JSON = "application/json" + local core = require("apisix.core") local http = require("resty.http") local url = require("socket.url") +local schema = require("apisix.plugins.ai-drivers.schema") +local ngx_re = require("ngx.re") + +local ngx_print = ngx.print +local ngx_flush = ngx.flush local pairs = pairs local type = type +local ipairs = ipairs local setmetatable = setmetatable @@ -40,6 +48,26 @@ function _M.new(opts) end +function _M.validate_request(ctx) + local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON + if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then + return nil, "unsupported content-type: " .. ct .. ", only application/json is supported" + end + + local request_table, err = core.request.get_json_request_body_table() + if not request_table then + return nil, err + end + + local ok, err = core.schema.check(schema.chat_request_schema, request_table) + if not ok then + return nil, "request format doesn't match schema: " .. 
err + end + + return request_table, nil +end + + function _M.request(self, conf, request_table, extra_opts) local httpc, err = http.new() if not httpc then @@ -54,11 +82,11 @@ function _M.request(self, conf, request_table, extra_opts) end local ok, err = httpc:connect({ - scheme = endpoint and parsed_url.scheme or "https", - host = endpoint and parsed_url.host or self.host, - port = endpoint and parsed_url.port or self.port, + scheme = parsed_url and parsed_url.scheme or "https", + host = parsed_url and parsed_url.host or self.host, + port = parsed_url and parsed_url.port or self.port, ssl_verify = conf.ssl_verify, - ssl_server_name = endpoint and parsed_url.host or self.host, + ssl_server_name = parsed_url and parsed_url.host or self.host, pool_size = conf.keepalive and conf.keepalive_pool, }) @@ -75,7 +103,7 @@ function _M.request(self, conf, request_table, extra_opts) end end - local path = (endpoint and parsed_url.path or self.path) + local path = (parsed_url and parsed_url.path or self.path) local headers = extra_opts.headers headers["Content-Type"] = "application/json" @@ -106,7 +134,95 @@ function _M.request(self, conf, request_table, extra_opts) return nil, err end - return res, nil, httpc + return res, nil end + +function _M.read_response(ctx, res) + local body_reader = res.body_reader + if not body_reader then + core.log.error("AI service sent no response body") + return 500 + end + + local content_type = res.headers["Content-Type"] + core.response.set_header("Content-Type", content_type) + + if core.string.find(content_type, "text/event-stream") then + while true do + local chunk, err = body_reader() -- will read chunk by chunk + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + + ngx_print(chunk) + ngx_flush(true) + + local events, err = ngx_re.split(chunk, "\n") + if err then + core.log.warn("failed to split response chunk [", chunk, "] to events: ", err) + goto CONTINUE + end + + for _, event in ipairs(events) do + if not core.string.find(event, "data:") or core.string.find(event, "[DONE]") then + goto CONTINUE + end + + local parts, err = ngx_re.split(event, ":", nil, nil, 2) + if err then + core.log.warn("failed to split data event [", event, "] to parts: ", err) + goto CONTINUE + end + + if #parts ~= 2 then + core.log.warn("malformed data event: ", event) + goto CONTINUE + end + + local data, err = core.json.decode(parts[2]) + if err then + core.log.warn("failed to decode data event [", parts[2], "] to json: ", err) + goto CONTINUE + end + + -- usage field is null for non-last events, null is parsed as userdata type + if data and data.usage and type(data.usage) ~= "userdata" then + ctx.ai_token_usage = { + prompt_tokens = data.usage.prompt_tokens or 0, + completion_tokens = data.usage.completion_tokens or 0, + total_tokens = data.usage.total_tokens or 0, + } + end + end + + ::CONTINUE:: + end + return + end + + local raw_res_body, err = res:read_body() + if not raw_res_body then + core.log.error("failed to read response body: ", err) + return 500 + end + local res_body, err = core.json.decode(raw_res_body) + if err then + core.log.warn("invalid response body from ai service: ", raw_res_body, " err: ", err, + ", it will cause token usage not available") + else + ctx.ai_token_usage = { + prompt_tokens = res_body.usage and res_body.usage.prompt_tokens or 0, + completion_tokens = res_body.usage and res_body.usage.completion_tokens or 0, + total_tokens = res_body.usage and res_body.usage.total_tokens or 0, + } 
+ end + return res.status, raw_res_body +end + + return _M diff --git a/apisix/plugins/ai-drivers/schema.lua b/apisix/plugins/ai-drivers/schema.lua new file mode 100644 index 000000000..7a469bd01 --- /dev/null +++ b/apisix/plugins/ai-drivers/schema.lua @@ -0,0 +1,44 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local _M = {} + +_M.chat_request_schema = { + type = "object", + properties = { + messages = { + type = "array", + minItems = 1, + items = { + properties = { + role = { + type = "string", + enum = {"system", "user", "assistant"} + }, + content = { + type = "string", + minLength = "1", + }, + }, + additionalProperties = false, + required = {"role", "content"}, + }, + } + }, + required = {"messages"} +} + +return _M diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua index 4993270b9..3b4fc7e84 100644 --- a/apisix/plugins/ai-proxy-multi.lua +++ b/apisix/plugins/ai-proxy-multi.lua @@ -17,8 +17,8 @@ local core = require("apisix.core") local schema = require("apisix.plugins.ai-proxy.schema") +local base = require("apisix.plugins.ai-proxy.base") local plugin = require("apisix.plugin") -local base = require("apisix.plugins.ai-proxy.base") local require = require local pcall = pcall @@ -36,7 +36,7 @@ local lrucache_server_picker = core.lrucache.new({ local plugin_name = "ai-proxy-multi" local _M = { version = 0.5, - priority = 998, + priority = 1041, name = plugin_name, schema = schema.ai_proxy_multi_schema, } @@ -64,10 +64,16 @@ end function _M.check_schema(conf) - for _, provider in ipairs(conf.providers) do - local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. provider.name) + local ok, err = core.schema.check(schema.ai_proxy_multi_schema, conf) + if not ok then + return false, err + end + + for _, instance in ipairs(conf.instances) do + local ai_driver, err = pcall(require, "apisix.plugins.ai-drivers." .. instance.provider) if not ai_driver then - return false, "provider: " .. provider.name .. " is not supported." + core.log.warn("fail to require ai provider: ", instance.provider, ", err", err) + return false, "ai provider: " .. instance.provider .. " is not supported." 
end end local algo = core.table.try_read_attr(conf, "balancer", "algorithm") @@ -96,21 +102,21 @@ function _M.check_schema(conf) end end - return core.schema.check(schema.ai_proxy_multi_schema, conf) + return ok end -local function transform_providers(new_providers, provider) - if not new_providers._priority_index then - new_providers._priority_index = {} +local function transform_instances(new_instances, instance) + if not new_instances._priority_index then + new_instances._priority_index = {} end - if not new_providers[provider.priority] then - new_providers[provider.priority] = {} - core.table.insert(new_providers._priority_index, provider.priority) + if not new_instances[instance.priority] then + new_instances[instance.priority] = {} + core.table.insert(new_instances._priority_index, instance.priority) end - new_providers[provider.priority][provider.name] = provider.weight + new_instances[instance.priority][instance.name] = instance.weight end @@ -120,37 +126,31 @@ local function create_server_picker(conf, ups_tab) pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm) picker = pickers[conf.balancer.algorithm] end - local new_providers = {} - for i, provider in ipairs(conf.providers) do - transform_providers(new_providers, provider) + local new_instances = {} + for _, ins in ipairs(conf.instances) do + transform_instances(new_instances, ins) end - if #new_providers._priority_index > 1 then - core.log.info("new providers: ", core.json.delay_encode(new_providers)) - return priority_balancer.new(new_providers, ups_tab, picker) + if #new_instances._priority_index > 1 then + core.log.info("new instances: ", core.json.delay_encode(new_instances)) + return priority_balancer.new(new_instances, ups_tab, picker) end core.log.info("upstream nodes: ", - core.json.delay_encode(new_providers[new_providers._priority_index[1]])) - return picker.new(new_providers[new_providers._priority_index[1]], ups_tab) + core.json.delay_encode(new_instances[new_instances._priority_index[1]])) + return picker.new(new_instances[new_instances._priority_index[1]], ups_tab) end -local function get_provider_conf(providers, name) - for i, provider in ipairs(providers) do - if provider.name == name then - return provider +local function get_instance_conf(instances, name) + for _, ins in ipairs(instances) do + if ins.name == name then + return ins end end end local function pick_target(ctx, conf, ups_tab) - if ctx.ai_balancer_try_count > 1 then - if ctx.server_picker and ctx.server_picker.after_balance then - ctx.server_picker.after_balance(ctx, true) - end - end - local server_picker = ctx.server_picker if not server_picker then server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf), @@ -160,40 +160,31 @@ local function pick_target(ctx, conf, ups_tab) return internal_server_error, "failed to fetch server picker" end - local provider_name = server_picker.get(ctx) - local provider_conf = get_provider_conf(conf.providers, provider_name) +local instance_name = server_picker.get(ctx) + local instance_conf = get_instance_conf(conf.instances, instance_name) - ctx.balancer_server = provider_name + ctx.balancer_server = instance_name ctx.server_picker = server_picker - return provider_name, provider_conf + return instance_name, instance_conf end -local function get_load_balanced_provider(ctx, conf, ups_tab, request_table) - ctx.ai_balancer_try_count = (ctx.ai_balancer_try_count or 0) + 1 - local provider_name, provider_conf - if #conf.providers == 1 then - 
provider_name = conf.providers[1].name - provider_conf = conf.providers[1] +local function pick_ai_instance(ctx, conf, ups_tab) + local instance_name, instance_conf + if #conf.instances == 1 then + instance_name = conf.instances[1].name + instance_conf = conf.instances[1] else - provider_name, provider_conf = pick_target(ctx, conf, ups_tab) - end - - core.log.info("picked provider: ", provider_name) - if provider_conf.model then - request_table.model = provider_conf.model + instance_name, instance_conf = pick_target(ctx, conf, ups_tab) end - provider_conf.__name = provider_name - return provider_name, provider_conf -end - -local function get_model_name(...) + core.log.info("picked instance: ", instance_name) + return instance_name, instance_conf end -local function proxy_request_to_llm(conf, request_table, ctx) +function _M.access(conf, ctx) local ups_tab = {} local algo = core.table.try_read_attr(conf, "balancer", "algorithm") if algo == "chash" then @@ -203,31 +194,13 @@ local function proxy_request_to_llm(conf, request_table, ctx) ups_tab["hash_on"] = hash_on end - ::retry:: - local provider, provider_conf = get_load_balanced_provider(ctx, conf, ups_tab, request_table) - local extra_opts = { - endpoint = core.table.try_read_attr(provider_conf, "override", "endpoint"), - query_params = provider_conf.auth.query or {}, - headers = (provider_conf.auth.header or {}), - model_options = provider_conf.options, - } - - local ai_driver = require("apisix.plugins.ai-drivers." .. provider) - local res, err, httpc = ai_driver:request(conf, request_table, extra_opts) - if not res then - if (ctx.ai_balancer_try_count or 0) < 2 then - core.log.warn("failed to send request to LLM: ", err, ". Retrying...") - goto retry - end - return nil, err, nil - end - - request_table.model = provider_conf.model - return res, nil, httpc + local name, ai_instance = pick_ai_instance(ctx, conf, ups_tab) + ctx.picked_ai_instance_name = name + ctx.picked_ai_instance = ai_instance end -_M.access = base.new(proxy_request_to_llm, get_model_name) +_M.before_proxy = base.before_proxy return _M diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua index ffc82f856..2301a65e6 100644 --- a/apisix/plugins/ai-proxy.lua +++ b/apisix/plugins/ai-proxy.lua @@ -24,41 +24,33 @@ local pcall = pcall local plugin_name = "ai-proxy" local _M = { version = 0.5, - priority = 999, + priority = 1040, name = plugin_name, - schema = schema, + schema = schema.ai_proxy_schema, } function _M.check_schema(conf) - local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. conf.model.provider) + local ok, err = core.schema.check(schema.ai_proxy_schema, conf) + if not ok then + return false, err + end + local ai_driver, err = pcall(require, "apisix.plugins.ai-drivers." .. conf.provider) if not ai_driver then - return false, "provider: " .. conf.model.provider .. " is not supported." + core.log.warn("fail to require ai provider: ", conf.provider, ", err", err) + return false, "ai provider: " .. conf.provider .. " is not supported." end - return core.schema.check(schema.ai_proxy_schema, conf) + return ok end -local function get_model_name(conf) - return conf.model.name +function _M.access(conf, ctx) + ctx.picked_ai_instance_name = "ai-proxy" + ctx.picked_ai_instance = conf end -local function proxy_request_to_llm(conf, request_table, ctx) - local ai_driver = require("apisix.plugins.ai-drivers." .. 
conf.model.provider) - local extra_opts = { - endpoint = core.table.try_read_attr(conf, "override", "endpoint"), - query_params = conf.auth.query or {}, - headers = (conf.auth.header or {}), - model_options = conf.model.options - } - local res, err, httpc = ai_driver:request(conf, request_table, extra_opts) - if not res then - return nil, err, nil - end - return res, nil, httpc -end +_M.before_proxy = base.before_proxy -_M.access = base.new(proxy_request_to_llm, get_model_name) return _M diff --git a/apisix/plugins/ai-proxy/base.lua b/apisix/plugins/ai-proxy/base.lua index 6de6ceb8c..d8f1a8944 100644 --- a/apisix/plugins/ai-proxy/base.lua +++ b/apisix/plugins/ai-proxy/base.lua @@ -15,84 +15,43 @@ -- limitations under the License. -- -local CONTENT_TYPE_JSON = "application/json" local core = require("apisix.core") +local require = require local bad_request = ngx.HTTP_BAD_REQUEST local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR -local schema = require("apisix.plugins.ai-proxy.schema") -local ngx_print = ngx.print -local ngx_flush = ngx.flush - -local function keepalive_or_close(conf, httpc) - if conf.set_keepalive then - httpc:set_keepalive(10000, 100) - return - end - httpc:close() -end local _M = {} -function _M.new(proxy_request_to_llm_func, get_model_name_func) - return function(conf, ctx) - local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON - if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then - return bad_request, "unsupported content-type: " .. ct - end - - local request_table, err = core.request.get_json_request_body_table() - if not request_table then - return bad_request, err - end - - local ok, err = core.schema.check(schema.chat_request_schema, request_table) - if not ok then - return bad_request, "request format doesn't match schema: " .. err - end - - request_table.model = get_model_name_func(conf) +function _M.before_proxy(conf, ctx) + local ai_instance = ctx.picked_ai_instance + local ai_driver = require("apisix.plugins.ai-drivers." .. 
ai_instance.provider) - if core.table.try_read_attr(conf, "model", "options", "stream") then - request_table.stream = true - end - - local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx) - if not res then - core.log.error("failed to send request to LLM service: ", err) - return internal_server_error - end + local request_body, err = ai_driver.validate_request(ctx) + if not request_body then + return bad_request, err + end - local body_reader = res.body_reader - if not body_reader then - core.log.error("LLM sent no response body") - return internal_server_error - end + local extra_opts = { + endpoint = core.table.try_read_attr(ai_instance, "override", "endpoint"), + query_params = ai_instance.auth.query or {}, + headers = (ai_instance.auth.header or {}), + model_options = ai_instance.options, + } + + if request_body.stream then + request_body.stream_options = { + include_usage = true + } + end - if request_table.stream then - while true do - local chunk, err = body_reader() -- will read chunk by chunk - if err then - core.log.error("failed to read response chunk: ", err) - break - end - if not chunk then - break - end - ngx_print(chunk) - ngx_flush(true) - end - keepalive_or_close(conf, httpc) - return - else - local res_body, err = res:read_body() - if not res_body then - core.log.error("failed to read response body: ", err) - return internal_server_error - end - keepalive_or_close(conf, httpc) - return res.status, res_body - end + local res, err = ai_driver:request(conf, request_body, extra_opts) + if not res then + core.log.error("failed to send request to AI service: ", err) + return internal_server_error end + + return ai_driver.read_response(ctx, res) end + return _M diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua index a2c25e924..7170b5bfc 100644 --- a/apisix/plugins/ai-proxy/schema.lua +++ b/apisix/plugins/ai-proxy/schema.lua @@ -38,6 +38,10 @@ local model_options_schema = { description = "Key/value settings for the model", type = "object", properties = { + model = { + type = "string", + description = "Model to execute.", + }, max_tokens = { type = "integer", description = "Defines the max_tokens, if using chat or completion models.", @@ -74,36 +78,10 @@ local model_options_schema = { description = "Stream response by SSE", type = "boolean", } - } -} - -local model_schema = { - type = "object", - properties = { - provider = { - type = "string", - description = "Name of the AI service provider.", - enum = { "openai", "openai-compatible", "deepseek" }, -- add more providers later - }, - name = { - type = "string", - description = "Model name to execute.", - }, - options = model_options_schema, - override = { - type = "object", - properties = { - endpoint = { - type = "string", - description = "To be specified to override the host of the AI provider", - }, - } - } }, - required = {"provider", "name"} } -local provider_schema = { +local ai_instance_schema = { type = "array", minItems = 1, items = { @@ -111,13 +89,15 @@ local provider_schema = { properties = { name = { type = "string", - description = "Name of the AI service provider.", - enum = { "openai", "deepseek", "openai-compatible" }, -- add more providers later - + minLength = 1, + maxLength = 100, + description = "Name of the AI service instance.", }, - model = { + provider = { type = "string", - description = "Model to execute.", + description = "Type of the AI service instance.", + enum = { "openai", "deepseek", "openai-compatible" }, -- add more providers later + }, 
priority = { type = "integer", @@ -126,6 +106,7 @@ local provider_schema = { }, weight = { type = "integer", + minimum = 0, }, auth = auth_schema, options = model_options_schema, @@ -134,12 +115,12 @@ local provider_schema = { properties = { endpoint = { type = "string", - description = "To be specified to override the host of the AI provider", + description = "To be specified to override the endpoint of the AI Instance", }, }, }, }, - required = {"name", "model", "auth"} + required = {"name", "provider", "auth"} }, } @@ -147,8 +128,14 @@ local provider_schema = { _M.ai_proxy_schema = { type = "object", properties = { + provider = { + type = "string", + description = "Type of the AI service instance.", + enum = { "openai", "deepseek", "openai-compatible" }, -- add more providers later + + }, auth = auth_schema, - model = model_schema, + options = model_options_schema, timeout = { type = "integer", minimum = 1, @@ -159,8 +146,17 @@ _M.ai_proxy_schema = { keepalive = {type = "boolean", default = true}, keepalive_pool = {type = "integer", minimum = 1, default = 30}, ssl_verify = {type = "boolean", default = true }, + override = { + type = "object", + properties = { + endpoint = { + type = "string", + description = "To be specified to override the endpoint of the AI Instance", + }, + }, + }, }, - required = {"model", "auth"} + required = {"provider", "auth"} } _M.ai_proxy_multi_schema = { @@ -191,7 +187,7 @@ _M.ai_proxy_multi_schema = { }, default = { algorithm = "roundrobin" } }, - providers = provider_schema, + instances = ai_instance_schema, timeout = { type = "integer", minimum = 1, @@ -200,11 +196,10 @@ _M.ai_proxy_multi_schema = { description = "timeout in milliseconds", }, keepalive = {type = "boolean", default = true}, - keepalive_timeout = {type = "integer", minimum = 1000, default = 60000}, keepalive_pool = {type = "integer", minimum = 1, default = 30}, ssl_verify = {type = "boolean", default = true }, }, - required = {"providers", } + required = {"instances"} } _M.chat_request_schema = { diff --git a/t/admin/plugins.t b/t/admin/plugins.t index 20cf4a8fc..c43d5ffeb 100644 --- a/t/admin/plugins.t +++ b/t/admin/plugins.t @@ -99,6 +99,8 @@ ai-prompt-template ai-prompt-decorator ai-rag ai-aws-content-moderation +ai-proxy-multi +ai-proxy proxy-mirror proxy-rewrite workflow @@ -106,8 +108,6 @@ api-breaker limit-conn limit-count limit-req -ai-proxy -ai-proxy-multi gzip server-info traffic-split diff --git a/t/plugin/ai-proxy-multi.balancer.t b/t/plugin/ai-proxy-multi.balancer.t index da26957fb..09076e4a8 100644 --- a/t/plugin/ai-proxy-multi.balancer.t +++ b/t/plugin/ai-proxy-multi.balancer.t @@ -158,10 +158,10 @@ __DATA__ "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { "name": "openai", - "model": "gpt-4", + "provider": "openai", "weight": 4, "auth": { "header": { @@ -169,6 +169,7 @@ __DATA__ } }, "options": { + "model": "gpt-4", "max_tokens": 512, "temperature": 1.0 }, @@ -178,7 +179,7 @@ __DATA__ }, { "name": "deepseek", - "model": "gpt-4", + "provider": "deepseek", "weight": 1, "auth": { "header": { @@ -186,6 +187,7 @@ __DATA__ } }, "options": { + "model": "deepseek-chat", "max_tokens": 512, "temperature": 1.0 }, @@ -239,7 +241,7 @@ passed end table.sort(restab) - ngx.log(ngx.WARN, "test picked providers: ", table.concat(restab, ".")) + ngx.log(ngx.WARN, "test picked instances: ", table.concat(restab, ".")) } } @@ -266,10 +268,10 @@ deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai "hash_on": "vars", "key": "query_string" 
}, - "providers": [ + "instances": [ { "name": "openai", - "model": "gpt-4", + "provider": "openai", "weight": 4, "auth": { "header": { @@ -277,6 +279,7 @@ deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai } }, "options": { + "model": "gpt-4", "max_tokens": 512, "temperature": 1.0 }, @@ -286,7 +289,7 @@ deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai }, { "name": "deepseek", - "model": "gpt-4", + "provider": "deepseek", "weight": 1, "auth": { "header": { @@ -294,6 +297,7 @@ deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai } }, "options": { + "model": "deepseek-chat", "max_tokens": 512, "temperature": 1.0 }, @@ -366,105 +370,3 @@ GET /t --- error_log distribution: deepseek: 2 distribution: openai: 8 - - - -=== TEST 5: retry logic with different priorities ---- config - location /t { - content_by_lua_block { - local t = require("lib.test_admin").test - local code, body = t('/apisix/admin/routes/1', - ngx.HTTP_PUT, - [[{ - "uri": "/anything", - "plugins": { - "ai-proxy-multi": { - "providers": [ - { - "name": "openai", - "model": "gpt-4", - "weight": 1, - "priority": 1, - "auth": { - "header": { - "Authorization": "Bearer token" - } - }, - "options": { - "max_tokens": 512, - "temperature": 1.0 - }, - "override": { - "endpoint": "http://localhost:9999" - } - }, - { - "name": "deepseek", - "model": "gpt-4", - "priority": 0, - "weight": 1, - "auth": { - "header": { - "Authorization": "Bearer token" - } - }, - "options": { - "max_tokens": 512, - "temperature": 1.0 - }, - "override": { - "endpoint": "http://localhost:6724/chat/completions" - } - } - ], - "ssl_verify": false - } - }, - "upstream": { - "type": "roundrobin", - "nodes": { - "canbeanything.com": 1 - } - } - }]] - ) - - if code >= 300 then - ngx.status = code - end - ngx.say(body) - } - } ---- response_body -passed - - - -=== TEST 6: test ---- config - location /t { - content_by_lua_block { - local http = require "resty.http" - local uri = "http://127.0.0.1:" .. ngx.var.server_port - .. "/anything" - - local restab = {} - - local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]] - local httpc = http.new() - local res, err = httpc:request_uri(uri, {method = "POST", body = body}) - if not res then - ngx.say(err) - return - end - ngx.say(res.body) - - } - } ---- request -GET /t ---- response_body -deepseek ---- error_log -failed to send request to LLM: failed to connect to LLM server: connection refused. Retrying... 
diff --git a/t/plugin/ai-proxy-multi.openai-compatible.t b/t/plugin/ai-proxy-multi.openai-compatible.t index f80be6dc4..923c12d37 100644 --- a/t/plugin/ai-proxy-multi.openai-compatible.t +++ b/t/plugin/ai-proxy-multi.openai-compatible.t @@ -52,26 +52,6 @@ _EOC_ default_type 'application/json'; - location /anything { - content_by_lua_block { - local json = require("cjson.safe") - - if ngx.req.get_method() ~= "POST" then - ngx.status = 400 - ngx.say("Unsupported request method: ", ngx.req.get_method()) - end - ngx.req.read_body() - local body = ngx.req.get_body_data() - - if body ~= "SELECT * FROM STUDENTS" then - ngx.status = 503 - ngx.say("passthrough doesn't work") - return - end - ngx.say('{"foo", "bar"}') - } - } - location /v1/chat/completions { content_by_lua_block { local json = require("cjson.safe") @@ -158,10 +138,10 @@ __DATA__ "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai-compatible", - "model": "custom", + "name": "self-hosted", + "provider": "openai-compatible", "weight": 1, "auth": { "header": { @@ -169,6 +149,7 @@ __DATA__ } }, "options": { + "model": "custom", "max_tokens": 512, "temperature": 1.0 }, @@ -223,10 +204,10 @@ qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai-compatible", - "model": "custom-instruct", + "name": "self-hosted", + "provider": "openai-compatible", "weight": 1, "auth": { "header": { @@ -234,6 +215,7 @@ qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ } }, "options": { + "model": "custom-instruct", "max_tokens": 512, "temperature": 1.0, "stream": true diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t index 7969dbd81..0da04c9de 100644 --- a/t/plugin/ai-proxy-multi.t +++ b/t/plugin/ai-proxy-multi.t @@ -133,10 +133,13 @@ __DATA__ content_by_lua_block { local plugin = require("apisix.plugins.ai-proxy-multi") local ok, err = plugin.check_schema({ - providers = { + instances = { { - name = "openai", - model = "gpt-4", + name = "openai-official", + provider = "openai", + options = { + model = "gpt-4", + }, weight = 1, auth = { header = { @@ -165,10 +168,13 @@ passed content_by_lua_block { local plugin = require("apisix.plugins.ai-proxy-multi") local ok, err = plugin.check_schema({ - providers = { + instances = { { - name = "some-unique", - model = "gpt-4", + name = "self-hosted", + provider = "some-unique", + options = { + model = "gpt-4", + }, weight = 1, auth = { header = { @@ -187,7 +193,7 @@ passed } } --- response_body eval -qr/.*provider: some-unique is not supported.*/ +qr/.*property "provider" validation failed: matches none of the enum values*/ @@ -202,10 +208,10 @@ qr/.*provider: some-unique is not supported.*/ "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-4", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -213,6 +219,7 @@ qr/.*provider: some-unique is not supported.*/ } }, "options": { + "model": "gpt-4", "max_tokens": 512, "temperature": 1.0 }, @@ -265,10 +272,10 @@ Unauthorized "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-4", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -276,6 +283,7 @@ Unauthorized } }, "options": { + "model": "gpt-4", "max_tokens": 512, "temperature": 1.0 }, @@ -360,7 +368,7 @@ 
prompt%3Dwhat%2520is%25201%2520%252B%25201 Content-Type: application/x-www-form-urlencoded --- error_code: 400 --- response_body chomp -unsupported content-type: application/x-www-form-urlencoded +unsupported content-type: application/x-www-form-urlencoded, only application/json is supported @@ -387,10 +395,10 @@ request format doesn't match schema: property "messages" is required "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "some-model", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -398,6 +406,7 @@ request format doesn't match schema: property "messages" is required } }, "options": { + "model": "some-model", "foo": "bar", "temperature": 1.0 }, @@ -461,10 +470,10 @@ options_works "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "some-model", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -472,6 +481,7 @@ options_works } }, "options": { + "model": "some-model", "foo": "bar", "temperature": 1.0 }, @@ -534,10 +544,10 @@ path override works "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-35-turbo-instruct", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -545,6 +555,7 @@ path override works } }, "options": { + "model": "gpt-35-turbo-instruct", "max_tokens": 512, "temperature": 1.0, "stream": true diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t index 00c1714a3..c54e7a67e 100644 --- a/t/plugin/ai-proxy-multi2.t +++ b/t/plugin/ai-proxy-multi2.t @@ -123,10 +123,10 @@ __DATA__ "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-35-turbo-instruct", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "query": { @@ -134,6 +134,7 @@ __DATA__ } }, "options": { + "model": "gpt-35-turbo-instruct", "max_tokens": 512, "temperature": 1.0 }, @@ -186,10 +187,10 @@ Unauthorized "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-35-turbo-instruct", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "query": { @@ -197,6 +198,7 @@ Unauthorized } }, "options": { + "model": "gpt-35-turbo-instruct", "max_tokens": 512, "temperature": 1.0 }, @@ -249,10 +251,10 @@ passed "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-35-turbo-instruct", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "header": { @@ -260,6 +262,7 @@ passed } }, "options": { + "model": "gpt-35-turbo-instruct", "max_tokens": 512, "temperature": 1.0 } @@ -308,10 +311,10 @@ POST /anything "uri": "/anything", "plugins": { "ai-proxy-multi": { - "providers": [ + "instances": [ { - "name": "openai", - "model": "gpt-35-turbo-instruct", + "name": "openai-official", + "provider": "openai", "weight": 1, "auth": { "query": { @@ -319,6 +322,7 @@ POST /anything } }, "options": { + "model": "gpt-35-turbo-instruct", "max_tokens": 512, "temperature": 1.0 }, @@ -359,68 +363,3 @@ POST /anything found query params: {"api_key":"apikey","some_query":"yes"} --- response_body passed - - - -=== TEST 9: set route with unavailable endpoint ---- config - location /t { - content_by_lua_block { - local t = 
require("lib.test_admin").test - local code, body = t('/apisix/admin/routes/1', - ngx.HTTP_PUT, - [[{ - "uri": "/anything", - "plugins": { - "ai-proxy-multi": { - "providers": [ - { - "name": "openai", - "model": "gpt-4", - "weight": 1, - "auth": { - "header": { - "Authorization": "Bearer token" - } - }, - "options": { - "max_tokens": 512, - "temperature": 1.0 - }, - "override": { - "endpoint": "http://unavailable.endpoint.ehfwuehr:404" - } - } - ], - "ssl_verify": false - } - }, - "upstream": { - "type": "roundrobin", - "nodes": { - "canbeanything.com": 1 - } - } - }]] - ) - - if code >= 300 then - ngx.status = code - end - ngx.say(body) - } - } ---- response_body -passed - - - -=== TEST 10: ai-proxy-multi should retry once and fail -# i.e it should not attempt to proxy request endlessly ---- request -POST /anything -{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] } ---- error_code: 500 ---- error_log -parse_domain(): failed to parse domain: unavailable.endpoint.ehfwuehr, error: failed to query the DNS server: dns -phase_func(): failed to send request to LLM service: failed to connect to LLM server: failed to parse domain diff --git a/t/plugin/ai-proxy.openai-compatible.t b/t/plugin/ai-proxy.openai-compatible.t index a98161a48..84ae175da 100644 --- a/t/plugin/ai-proxy.openai-compatible.t +++ b/t/plugin/ai-proxy.openai-compatible.t @@ -132,18 +132,16 @@ __DATA__ "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai-compatible", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai-compatible", - "name": "custom", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "custom", + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724/v1/chat/completions" @@ -194,18 +192,16 @@ qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/ "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai-compatible", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai-compatible", - "name": "some-model", - "options": { - "foo": "bar", - "temperature": 1.0 - } + "options": { + "model": "some-model", + "foo": "bar", + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724/random" @@ -264,19 +260,17 @@ path override works "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai-compatible", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai-compatible", - "name": "custom", - "options": { - "max_tokens": 512, - "temperature": 1.0, - "stream": true - } + "options": { + "model": "custom", + "max_tokens": 512, + "temperature": 1.0, + "stream": true }, "override": { "endpoint": "http://localhost:7737/v1/chat/completions" @@ -343,22 +337,3 @@ passed ngx.say(err) return end - - local final_res = {} - while true do - local chunk, err = res.body_reader() -- will read chunk by chunk - if err then - core.log.error("failed to read response chunk: ", err) - break - end - if not chunk then - break - end - core.table.insert_tail(final_res, chunk) - end - - ngx.print(#final_res .. 
final_res[6]) - } - } ---- response_body_like eval -qr/6data: \[DONE\]\n\n/ diff --git a/t/plugin/ai-proxy.t b/t/plugin/ai-proxy.t index 8cfd88018..08220fc3c 100644 --- a/t/plugin/ai-proxy.t +++ b/t/plugin/ai-proxy.t @@ -127,9 +127,9 @@ __DATA__ content_by_lua_block { local plugin = require("apisix.plugins.ai-proxy") local ok, err = plugin.check_schema({ - model = { - provider = "openai", - name = "gpt-4", + provider = "openai", + options = { + model = "gpt-4", }, auth = { header = { @@ -156,9 +156,9 @@ passed content_by_lua_block { local plugin = require("apisix.plugins.ai-proxy") local ok, err = plugin.check_schema({ - model = { - provider = "some-unique", - name = "gpt-4", + provider = "some-unique", + options = { + model = "gpt-4", }, auth = { header = { @@ -175,7 +175,7 @@ passed } } --- response_body eval -qr/.*provider: some-unique is not supported.*/ +qr/.*property "provider" validation failed: matches none of the enum values.*/ @@ -190,18 +190,16 @@ qr/.*provider: some-unique is not supported.*/ "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "header": { "Authorization": "Bearer wrongtoken" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "gpt-35-turbo-instruct", + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724" @@ -250,18 +248,16 @@ Unauthorized "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "gpt-35-turbo-instruct", + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724" @@ -342,7 +338,7 @@ prompt%3Dwhat%2520is%25201%2520%252B%25201 Content-Type: application/x-www-form-urlencoded --- error_code: 400 --- response_body chomp -unsupported content-type: application/x-www-form-urlencoded +unsupported content-type: application/x-www-form-urlencoded, only application/json is supported @@ -369,18 +365,16 @@ request format doesn't match schema: property "messages" is required "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai", - "name": "some-model", - "options": { - "foo": "bar", - "temperature": 1.0 - } + "options": { + "model": "some-model", + "foo": "bar", + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724" @@ -440,18 +434,16 @@ options_works "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", + "model": "some-model", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai", - "name": "some-model", - "options": { - "foo": "bar", - "temperature": 1.0 - } + "options": { + "foo": "bar", + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724/random" @@ -510,19 +502,17 @@ path override works "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "header": { "Authorization": "Bearer token" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0, - "stream": true - } + "options": { + "model": "gpt-35-turbo-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true }, "override": { "endpoint": 
"http://localhost:7737" diff --git a/t/plugin/ai-proxy2.t b/t/plugin/ai-proxy2.t index f372e4fbd..942f449cd 100644 --- a/t/plugin/ai-proxy2.t +++ b/t/plugin/ai-proxy2.t @@ -117,18 +117,16 @@ __DATA__ "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "query": { "api_key": "wrong_key" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "gpt-35-turbo-instruct", + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724" @@ -177,18 +175,16 @@ Unauthorized "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "query": { "api_key": "apikey" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "gpt-35-turbo-instruct", + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724" @@ -237,18 +233,16 @@ passed "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "header": { "Authorization": "some-key" } }, - "model": { - "provider": "openai", - "name": "gpt-4", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "model": "gpt-4", + "max_tokens": 512, + "temperature": 1.0 } } }, @@ -292,18 +286,15 @@ POST /anything "uri": "/anything", "plugins": { "ai-proxy": { + "provider": "openai", "auth": { "query": { "api_key": "apikey" } }, - "model": { - "provider": "openai", - "name": "gpt-35-turbo-instruct", - "options": { - "max_tokens": 512, - "temperature": 1.0 - } + "options": { + "max_tokens": 512, + "temperature": 1.0 }, "override": { "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes"