This is an automated email from the ASF dual-hosted git repository.
nic443 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
new cc7441fb1 feat(plugin): support `ai-proxy-multi` (#11986)
cc7441fb1 is described below
commit cc7441fb1e89a234994489bbd5ead8440b94ccb3
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Mon Feb 24 08:46:10 2025 +0545
feat(plugin): support `ai-proxy-multi` (#11986)
---
Makefile | 4 +-
apisix/cli/config.lua | 1 +
apisix/plugins/ai-drivers/deepseek.lua | 24 +
.../openai-compatible.lua} | 40 +-
apisix/plugins/ai-drivers/openai.lua | 24 +
apisix/plugins/ai-proxy-multi.lua | 236 +++++++
apisix/plugins/ai-proxy.lua | 31 +-
apisix/plugins/ai-proxy/schema.lua | 88 ++-
conf/config.yaml.example | 1 +
docs/en/latest/config.json | 1 +
docs/en/latest/plugins/ai-proxy-multi.md | 195 ++++++
t/admin/plugins.t | 1 +
t/plugin/ai-proxy-multi.balancer.t | 470 ++++++++++++++
t/plugin/ai-proxy-multi.t | 723 +++++++++++++++++++++
t/plugin/ai-proxy-multi2.t | 361 ++++++++++
15 files changed, 2176 insertions(+), 24 deletions(-)
diff --git a/Makefile b/Makefile
index a24e8f7b8..c288463c9 100644
--- a/Makefile
+++ b/Makefile
@@ -374,8 +374,8 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
$(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
- $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
- $(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+ $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers
+ $(ENV_INSTALL) apisix/plugins/ai-drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
$(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 6a05fed5d..376b5ed15 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -223,6 +223,7 @@ local _M = {
"workflow",
"api-breaker",
"ai-proxy",
+ "ai-proxy-multi",
"limit-conn",
"limit-count",
"limit-req",
diff --git a/apisix/plugins/ai-drivers/deepseek.lua b/apisix/plugins/ai-drivers/deepseek.lua
new file mode 100644
index 000000000..ab441c636
--- /dev/null
+++ b/apisix/plugins/ai-drivers/deepseek.lua
@@ -0,0 +1,24 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
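+-- DeepSeek exposes an OpenAI-compatible chat completions API, so this driver
+-- only supplies the connection defaults consumed by the shared
+-- openai-compatible driver.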
+return require("apisix.plugins.ai-drivers.openai-compatible").new(
+ {
+ host = "api.deepseek.com",
+ path = "/chat/completions",
+ port = 443
+ }
+)
diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-drivers/openai-compatible.lua
similarity index 74%
rename from apisix/plugins/ai-proxy/drivers/openai.lua
rename to apisix/plugins/ai-drivers/openai-compatible.lua
index af0bc9758..fd5d2163c 100644
--- a/apisix/plugins/ai-proxy/drivers/openai.lua
+++ b/apisix/plugins/ai-drivers/openai-compatible.lua
@@ -16,27 +16,38 @@
--
local _M = {}
+local mt = {
+ __index = _M
+}
+
local core = require("apisix.core")
local http = require("resty.http")
local url = require("socket.url")
local pairs = pairs
local type = type
+local setmetatable = setmetatable
+
+
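+-- Build a driver instance bound to a provider's connection defaults.
+-- `opts` carries the host, port and request path used when no
+-- `override.endpoint` is configured.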
+function _M.new(opts)
--- globals
-local DEFAULT_HOST = "api.openai.com"
-local DEFAULT_PORT = 443
-local DEFAULT_PATH = "/v1/chat/completions"
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ }
+ return setmetatable(self, mt)
+end
-function _M.request(conf, request_table, ctx)
+function _M.request(self, conf, request_table, extra_opts)
local httpc, err = http.new()
if not httpc then
        return nil, "failed to create http client to send request to LLM server: " .. err
end
httpc:set_timeout(conf.timeout)
- local endpoint = core.table.try_read_attr(conf, "override", "endpoint")
+ local endpoint = extra_opts.endpoint
local parsed_url
if endpoint then
parsed_url = url.parse(endpoint)
@@ -44,10 +55,10 @@ function _M.request(conf, request_table, ctx)
local ok, err = httpc:connect({
scheme = endpoint and parsed_url.scheme or "https",
- host = endpoint and parsed_url.host or DEFAULT_HOST,
- port = endpoint and parsed_url.port or DEFAULT_PORT,
+ host = endpoint and parsed_url.host or self.host,
+ port = endpoint and parsed_url.port or self.port,
ssl_verify = conf.ssl_verify,
- ssl_server_name = endpoint and parsed_url.host or DEFAULT_HOST,
+ ssl_server_name = endpoint and parsed_url.host or self.host,
pool_size = conf.keepalive and conf.keepalive_pool,
})
@@ -55,7 +66,7 @@ function _M.request(conf, request_table, ctx)
return nil, "failed to connect to LLM server: " .. err
end
- local query_params = conf.auth.query or {}
+ local query_params = extra_opts.query_params
    if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
local args_tab = core.string.decode_args(parsed_url.query)
@@ -64,9 +75,9 @@ function _M.request(conf, request_table, ctx)
end
end
- local path = (endpoint and parsed_url.path or DEFAULT_PATH)
+ local path = (endpoint and parsed_url.path or self.path)
- local headers = (conf.auth.header or {})
+ local headers = extra_opts.headers
headers["Content-Type"] = "application/json"
local params = {
method = "POST",
@@ -77,13 +88,14 @@ function _M.request(conf, request_table, ctx)
query = query_params
}
- if conf.model.options then
- for opt, val in pairs(conf.model.options) do
+ if extra_opts.model_options then
+ for opt, val in pairs(extra_opts.model_options) do
request_table[opt] = val
end
end
params.body = core.json.encode(request_table)
+ httpc:set_timeout(conf.keepalive_timeout)
local res, err = httpc:request(params)
if not res then
return nil, err
diff --git a/apisix/plugins/ai-drivers/openai.lua b/apisix/plugins/ai-drivers/openai.lua
new file mode 100644
index 000000000..785ede193
--- /dev/null
+++ b/apisix/plugins/ai-drivers/openai.lua
@@ -0,0 +1,24 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
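+-- Default connection settings for the OpenAI chat completions API; the
+-- request logic itself lives in the shared openai-compatible driver.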
+return require("apisix.plugins.ai-drivers.openai-compatible").new(
+ {
+ host = "api.openai.com",
+ path = "/v1/chat/completions",
+ port = 443
+ }
+)
diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua
new file mode 100644
index 000000000..48f0dea94
--- /dev/null
+++ b/apisix/plugins/ai-proxy-multi.lua
@@ -0,0 +1,236 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+local core = require("apisix.core")
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ai_proxy = require("apisix.plugins.ai-proxy")
+local plugin = require("apisix.plugin")
+
+local require = require
+local pcall = pcall
+local ipairs = ipairs
+local unpack = unpack
+local type = type
+
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local priority_balancer = require("apisix.balancer.priority")
+
+local pickers = {}
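+-- one server picker is cached per (route key, plugin conf version), so a
+-- picker is only rebuilt when the plugin configuration changes (see pick_target)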
+local lrucache_server_picker = core.lrucache.new({
+ ttl = 300, count = 256
+})
+
+local plugin_name = "ai-proxy-multi"
+local _M = {
+ version = 0.5,
+ priority = 998,
+ name = plugin_name,
+ schema = schema.ai_proxy_multi_schema,
+}
+
+
+local function get_chash_key_schema(hash_on)
+ if hash_on == "vars" then
+ return core.schema.upstream_hash_vars_schema
+ end
+
+ if hash_on == "header" or hash_on == "cookie" then
+ return core.schema.upstream_hash_header_schema
+ end
+
+ if hash_on == "consumer" then
+ return nil, nil
+ end
+
+ if hash_on == "vars_combinations" then
+ return core.schema.upstream_hash_vars_combinations_schema
+ end
+
+ return nil, "invalid hash_on type " .. hash_on
+end
+
+
+function _M.check_schema(conf)
+ for _, provider in ipairs(conf.providers) do
+        local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. provider.name)
+ if not ai_driver then
+ return false, "provider: " .. provider.name .. " is not supported."
+ end
+ end
+ local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+ local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+ local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+
+ if type(algo) == "string" and algo == "chash" then
+ if not hash_on then
+            return false, "must configure `hash_on` when balancer algorithm is chash"
+ end
+
+ if hash_on ~= "consumer" and not hash_key then
+            return false, "must configure `hash_key` when balancer `hash_on` is not set to consumer"
+ end
+
+ local key_schema, err = get_chash_key_schema(hash_on)
+ if err then
+ return false, "type is chash, err: " .. err
+ end
+
+ if key_schema then
+ local ok, err = core.schema.check(key_schema, hash_key)
+ if not ok then
+ return false, "invalid configuration: " .. err
+ end
+ end
+ end
+
+ return core.schema.check(schema.ai_proxy_multi_schema, conf)
+end
+
+
+local function transform_providers(new_providers, provider)
+ if not new_providers._priority_index then
+ new_providers._priority_index = {}
+ end
+
+ if not new_providers[provider.priority] then
+ new_providers[provider.priority] = {}
+ core.table.insert(new_providers._priority_index, provider.priority)
+ end
+
+ new_providers[provider.priority][provider.name] = provider.weight
+end
+
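+-- transform_providers() groups provider weights by priority; for example,
+-- openai (weight 4, priority 1) and deepseek (weight 1, priority 0) become:
+-- { [1] = { openai = 4 }, [0] = { deepseek = 1 }, _priority_index = { 1, 0 } }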
+
+local function create_server_picker(conf, ups_tab)
+ local picker = pickers[conf.balancer.algorithm] -- nil check
+ if not picker then
+        pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm)
+ picker = pickers[conf.balancer.algorithm]
+ end
+ local new_providers = {}
+ for i, provider in ipairs(conf.providers) do
+ transform_providers(new_providers, provider)
+ end
+
+ if #new_providers._priority_index > 1 then
+ core.log.info("new providers: ", core.json.delay_encode(new_providers))
+ return priority_balancer.new(new_providers, ups_tab, picker)
+ end
+    core.log.info("upstream nodes: ",
+                  core.json.delay_encode(new_providers[new_providers._priority_index[1]]))
+ return picker.new(new_providers[new_providers._priority_index[1]], ups_tab)
+end
+
+
+local function get_provider_conf(providers, name)
+ for i, provider in ipairs(providers) do
+ if provider.name == name then
+ return provider
+ end
+ end
+end
+
+
+local function pick_target(ctx, conf, ups_tab)
+ ctx.ai_balancer_try_count = (ctx.ai_balancer_try_count or 0) + 1
+ if ctx.ai_balancer_try_count > 1 then
+ if ctx.server_picker and ctx.server_picker.after_balance then
+ ctx.server_picker.after_balance(ctx, true)
+ end
+ end
+
+ local server_picker = ctx.server_picker
+ if not server_picker then
+        server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf),
+                                               create_server_picker, conf, ups_tab)
+ end
+ if not server_picker then
+ return internal_server_error, "failed to fetch server picker"
+ end
+
+ local provider_name = server_picker.get(ctx)
+ local provider_conf = get_provider_conf(conf.providers, provider_name)
+
+ ctx.balancer_server = provider_name
+ ctx.server_picker = server_picker
+
+ return provider_name, provider_conf
+end
+
+
+local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+ local provider_name, provider_conf
+ if #conf.providers == 1 then
+ provider_name = conf.providers[1].name
+ provider_conf = conf.providers[1]
+ else
+ provider_name, provider_conf = pick_target(ctx, conf, ups_tab)
+ end
+
+ core.log.info("picked provider: ", provider_name)
+ if provider_conf.model then
+ request_table.model = provider_conf.model
+ end
+
+ provider_conf.__name = provider_name
+ return provider_name, provider_conf
+end
+
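+-- ai-proxy-multi reuses ai-proxy's access phase and overrides two of its
+-- hooks: get_model_name() becomes a no-op because each provider sets its own
+-- model in get_load_balanced_provider(), and proxy_request_to_llm() below
+-- adds load balancing and retries on top of the single-provider version.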
+ai_proxy.get_model_name = function (...)
+end
+
+
+ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+ local ups_tab = {}
+ local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+ if algo == "chash" then
+ local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+ local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+ ups_tab["key"] = hash_key
+ ups_tab["hash_on"] = hash_on
+ end
+
+ ::retry::
+    local provider, provider_conf = get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+ local extra_opts = {
+        endpoint = core.table.try_read_attr(provider_conf, "override", "endpoint"),
+ query_params = provider_conf.auth.query or {},
+ headers = (provider_conf.auth.header or {}),
+ model_options = provider_conf.options,
+ }
+
+ local ai_driver = require("apisix.plugins.ai-drivers." .. provider)
+ local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
+ if not res then
+ if (ctx.balancer_try_count or 0) < 1 then
+            core.log.warn("failed to send request to LLM: ", err, ". Retrying...")
+ goto retry
+ end
+ return nil, err, nil
+ end
+
+ request_table.model = provider_conf.model
+ return res, nil, httpc
+end
+
+
+function _M.access(conf, ctx)
+ local rets = {ai_proxy.access(conf, ctx)}
+ return unpack(rets)
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua
index 8a0d8fa97..c27ca9a3b 100644
--- a/apisix/plugins/ai-proxy.lua
+++ b/apisix/plugins/ai-proxy.lua
@@ -34,11 +34,11 @@ local _M = {
function _M.check_schema(conf)
-    local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
+    local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. conf.model.provider)
if not ai_driver then
        return false, "provider: " .. conf.model.provider .. " is not supported."
end
- return core.schema.check(schema.plugin_schema, conf)
+ return core.schema.check(schema.ai_proxy_schema, conf)
end
@@ -54,6 +54,26 @@ local function keepalive_or_close(conf, httpc)
end
+function _M.get_model_name(conf)
+ return conf.model.name
+end
+
+
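+-- single-provider request path; ai-proxy-multi overrides this hook to add
+-- load balancing and retries across multiple providers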
+function _M.proxy_request_to_llm(conf, request_table, ctx)
+    local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
+ local extra_opts = {
+ endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
+ query_params = conf.auth.query or {},
+ headers = (conf.auth.header or {}),
+ model_options = conf.model.options
+ }
+ local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
+ if not res then
+ return nil, err, nil
+ end
+ return res, nil, httpc
+end
+
function _M.access(conf, ctx)
local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
@@ -70,16 +90,13 @@ function _M.access(conf, ctx)
return bad_request, "request format doesn't match schema: " .. err
end
- if conf.model.name then
- request_table.model = conf.model.name
- end
+ request_table.model = _M.get_model_name(conf)
if core.table.try_read_attr(conf, "model", "options", "stream") then
request_table.stream = true
end
-    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
- local res, err, httpc = ai_driver.request(conf, request_table, ctx)
+ local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
if not res then
core.log.error("failed to send request to LLM service: ", err)
return internal_server_error
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
index 382644dc2..d0ba33fdc 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -105,7 +105,48 @@ local model_schema = {
required = {"provider", "name"}
}
-_M.plugin_schema = {
+local provider_schema = {
+ type = "array",
+ minItems = 1,
+ items = {
+ type = "object",
+ properties = {
+ name = {
+ type = "string",
+ description = "Name of the AI service provider.",
+ enum = { "openai", "deepseek" }, -- add more providers later
+
+ },
+ model = {
+ type = "string",
+ description = "Model to execute.",
+ },
+ priority = {
+ type = "integer",
+ description = "Priority of the provider for load balancing",
+ default = 0,
+ },
+ weight = {
+ type = "integer",
+ },
+ auth = auth_schema,
+ options = model_options_schema,
+ override = {
+ type = "object",
+ properties = {
+ endpoint = {
+ type = "string",
+                        description = "To be specified to override the host of the AI provider",
+ },
+ },
+ },
+ },
+ required = {"name", "model", "auth"}
+ },
+}
+
+
+_M.ai_proxy_schema = {
type = "object",
properties = {
auth = auth_schema,
@@ -126,6 +167,51 @@ _M.plugin_schema = {
required = {"model", "auth"}
}
+_M.ai_proxy_multi_schema = {
+ type = "object",
+ properties = {
+ balancer = {
+ type = "object",
+ properties = {
+ algorithm = {
+ type = "string",
+ enum = { "chash", "roundrobin" },
+ },
+ hash_on = {
+ type = "string",
+ default = "vars",
+ enum = {
+ "vars",
+ "header",
+ "cookie",
+ "consumer",
+ "vars_combinations",
+ },
+ },
+ key = {
+                    description = "the key of chash for dynamic load balancing",
+ type = "string",
+ },
+ },
+ default = { algorithm = "roundrobin" }
+ },
+ providers = provider_schema,
+ passthrough = { type = "boolean", default = false },
+ timeout = {
+ type = "integer",
+ minimum = 1,
+ maximum = 60000,
+ default = 3000,
+ description = "timeout in milliseconds",
+ },
+ keepalive = {type = "boolean", default = true},
+        keepalive_timeout = {type = "integer", minimum = 1000, default = 60000},
+ keepalive_pool = {type = "integer", minimum = 1, default = 30},
+ ssl_verify = {type = "boolean", default = true },
+ },
+ required = {"providers", }
+}
+
_M.chat_request_schema = {
type = "object",
properties = {
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index 8052beef6..780340dcb 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -491,6 +491,7 @@ plugins:                          # plugin list (sorted by priority)
- limit-req # priority: 1001
#- node-status # priority: 1000
- ai-proxy # priority: 999
+ - ai-proxy-multi # priority: 998
#- brotli # priority: 996
- gzip # priority: 995
- server-info # priority: 990
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index a17a6ae48..c8bf09ca7 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -100,6 +100,7 @@
"plugins/degraphql",
"plugins/body-transformer",
"plugins/ai-proxy",
+ "plugins/ai-proxy-multi",
"plugins/attach-consumer-label",
"plugins/ai-rag"
]
diff --git a/docs/en/latest/plugins/ai-proxy-multi.md b/docs/en/latest/plugins/ai-proxy-multi.md
new file mode 100644
index 000000000..72d8a9cfa
--- /dev/null
+++ b/docs/en/latest/plugins/ai-proxy-multi.md
@@ -0,0 +1,195 @@
+---
+title: ai-proxy-multi
+keywords:
+ - Apache APISIX
+ - API Gateway
+ - Plugin
+ - ai-proxy-multi
+description: This document contains information about the Apache APISIX ai-proxy-multi Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-proxy-multi` plugin simplifies access to LLM providers and models by defining a standard request format that allows key fields in the plugin configuration to be embedded into the request.
+
+This plugin extends the existing `ai-proxy` plugin with load balancing and retries.
+
+Proxying requests to OpenAI and DeepSeek is currently supported; other LLM services will be supported soon.
+
+## Request Format
+
+### OpenAI
+
+- Chat API
+
+| Name               | Type   | Required | Description                                          |
+| ------------------ | ------ | -------- | ---------------------------------------------------- |
+| `messages`         | Array  | Yes      | An array of message objects                          |
+| `messages.role`    | String | Yes      | Role of the message (`system`, `user`, `assistant`)  |
+| `messages.content` | String | Yes      | Content of the message                               |
+
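+For illustration, a minimal request body in this format (the same shape used by the examples below) is:
+
+```json
+{
+  "messages": [
+    { "role": "system", "content": "You are a mathematician" },
+    { "role": "user", "content": "What is 1+1?" }
+  ]
+}
+```
+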
+## Plugin Attributes
+
+| **Name**                     | **Required** | **Type** | **Description**                                                                                                | **Default** |
+| ---------------------------- | ------------ | -------- | -------------------------------------------------------------------------------------------------------------- | ----------- |
+| providers                    | Yes          | array    | List of AI providers, each following the provider schema.                                                       |             |
+| provider.name                | Yes          | string   | Name of the AI service provider. Allowed values: `openai`, `deepseek`.                                          |             |
+| provider.model               | Yes          | string   | Name of the AI model to execute. Example: `gpt-4o`.                                                             |             |
+| provider.priority            | No           | integer  | Priority of the provider for load balancing.                                                                    | 0           |
+| provider.weight              | No           | integer  | Load balancing weight.                                                                                          |             |
+| balancer.algorithm           | No           | string   | Load balancing algorithm. Allowed values: `chash`, `roundrobin`.                                                | roundrobin  |
+| balancer.hash_on             | No           | string   | Defines what to hash on for consistent hashing (`vars`, `header`, `cookie`, `consumer`, `vars_combinations`).   | vars        |
+| balancer.key                 | No           | string   | Key for consistent hashing in dynamic load balancing.                                                           |             |
+| provider.auth                | Yes          | object   | Authentication details, including headers and query parameters.                                                 |             |
+| provider.auth.header         | No           | object   | Authentication details sent via headers. Header name must match `^[a-zA-Z0-9._-]+$`.                            |             |
+| provider.auth.query          | No           | object   | Authentication details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`.                          |             |
+| provider.options.max_tokens  | No           | integer  | Defines the maximum tokens for chat or completion models.                                                       | 256         |
+| provider.options.input_cost  | No           | number   | Cost per 1M tokens in the input prompt. Minimum is 0.                                                           |             |
+| provider.options.output_cost | No           | number   | Cost per 1M tokens in the AI-generated output. Minimum is 0.                                                    |             |
+| provider.options.temperature | No           | number   | Defines the model's temperature (0.0 - 5.0) for randomness in responses.                                        |             |
+| provider.options.top_p       | No           | number   | Defines the top-p probability mass (0 - 1) for nucleus sampling.                                                |             |
+| provider.options.stream      | No           | boolean  | Enables streaming responses via SSE.                                                                            | false       |
+| provider.override.endpoint   | No           | string   | Custom host override for the AI provider.                                                                       |             |
+| passthrough                  | No           | boolean  | If true, requests are forwarded without processing.                                                             | false       |
+| timeout                      | No           | integer  | Request timeout in milliseconds (1-60000).                                                                      | 3000        |
+| keepalive                    | No           | boolean  | Enables keepalive connections.                                                                                  | true        |
+| keepalive_timeout            | No           | integer  | Timeout for keepalive connections (minimum 1000ms).                                                             | 60000       |
+| keepalive_pool               | No           | integer  | Maximum keepalive connections.                                                                                  | 30          |
+| ssl_verify                   | No           | boolean  | Enables SSL certificate verification.                                                                           | true        |
+
+## Example usage
+
+Create a route with the `ai-proxy-multi` plugin like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes" -X PUT \
+ -H "X-API-KEY: ${ADMIN_API_KEY}" \
+ -d '{
+ "id": "ai-proxy-multi-route",
+ "uri": "/anything",
+ "methods": ["POST"],
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 1,
+ "priority": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ {
+ "name": "deepseek",
+ "model": "deepseek-chat",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ }
+ ],
+ "passthrough": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "httpbin.org": 1
+ }
+ }
+ }'
+```
+
+In the above configuration, requests are balanced equally between the `openai` and `deepseek` providers because both have a `weight` of `1`.
+
+### Retry and fallback
+
+The `priority` attribute can be adjusted to implement fallback and retry: providers with a higher `priority` are tried first, and lower-priority providers are used only when the higher-priority ones fail.
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes" -X PUT \
+ -H "X-API-KEY: ${ADMIN_API_KEY}" \
+ -d '{
+ "id": "ai-proxy-multi-route",
+ "uri": "/anything",
+ "methods": ["POST"],
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 1,
+ "priority": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ {
+ "name": "deepseek",
+ "model": "deepseek-chat",
+ "weight": 1,
+ "priority": 0,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ }
+ ],
+ "passthrough": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "httpbin.org": 1
+ }
+ }
+ }'
+```
+
+In the above configuration, `priority` for the `deepseek` provider is set to `0` while `openai` has `priority` `1`. If the `openai` provider is unavailable, the `ai-proxy-multi` plugin retries the request against `deepseek` on the second attempt.
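+
+To verify the fallback, send a chat request to the route (assuming APISIX listens on the default port `9080`); when the `openai` provider is unreachable, the response is served by `deepseek` on the retry:
+
+```shell
+curl "http://127.0.0.1:9080/anything" -X POST \
+  -H "Content-Type: application/json" \
+  -d '{ "messages": [ { "role": "user", "content": "What is 1+1?" } ] }'
+```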
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index 6c574c2a4..7cb852cbf 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -106,6 +106,7 @@ limit-conn
limit-count
limit-req
ai-proxy
+ai-proxy-multi
gzip
server-info
traffic-split
diff --git a/t/plugin/ai-proxy-multi.balancer.t b/t/plugin/ai-proxy-multi.balancer.t
new file mode 100644
index 000000000..da26957fb
--- /dev/null
+++ b/t/plugin/ai-proxy-multi.balancer.t
@@ -0,0 +1,470 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+ my ($block) = @_;
+
+ if (!defined $block->request) {
+ $block->set_value("request", "GET /t");
+ }
+
+ my $user_yaml_config = <<_EOC_;
+plugins:
+ - ai-proxy-multi
+_EOC_
+ $block->set_value("extra_yaml_config", $user_yaml_config);
+
+ my $http_config = $block->http_config // <<_EOC_;
+ server {
+ server_name openai;
+ listen 6724;
+
+ default_type 'application/json';
+
+ location /v1/chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local header_auth = ngx.req.get_headers()["authorization"]
+ local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                if header_auth ~= "Bearer token" and query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+                if header_auth == "Bearer token" or query_auth == "apikey" then
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ if not body.messages or #body.messages < 1 then
+ ngx.status = 400
+ ngx.say([[{ "error": "bad request"}]])
+ return
+ end
+
+ ngx.status = 200
+ ngx.print("openai")
+ return
+ end
+
+
+ ngx.status = 503
+ ngx.say("reached the end of the test suite")
+ }
+ }
+
+ location /chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local header_auth = ngx.req.get_headers()["authorization"]
+ local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                if header_auth ~= "Bearer token" and query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+                if header_auth == "Bearer token" or query_auth == "apikey" then
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ if not body.messages or #body.messages < 1 then
+ ngx.status = 400
+ ngx.say([[{ "error": "bad request"}]])
+ return
+ end
+
+ ngx.status = 200
+ ngx.print("deepseek")
+ return
+ end
+
+
+ ngx.status = 503
+ ngx.say("reached the end of the test suite")
+ }
+ }
+ }
+_EOC_
+
+ $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route with roundrobin balancer, weight 4 and 1
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 4,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ },
+ {
+ "name": "deepseek",
+ "model": "gpt-4",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+                                    "endpoint": "http://localhost:6724/chat/completions"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 2: send requests and verify weighted round-robin distribution
+--- config
+ location /t {
+ content_by_lua_block {
+ local http = require "resty.http"
+ local uri = "http://127.0.0.1:" .. ngx.var.server_port
+ .. "/anything"
+
+ local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+ for i = 1, 10 do
+ local httpc = http.new()
+                local res, err = httpc:request_uri(uri, {method = "POST", body = body})
+ if not res then
+ ngx.say(err)
+ return
+ end
+ table.insert(restab, res.body)
+ end
+
+ table.sort(restab)
+            ngx.log(ngx.WARN, "test picked providers: ", table.concat(restab, "."))
+
+ }
+ }
+--- request
+GET /t
+--- error_log
+deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai
+
+
+
+=== TEST 3: set route with chash balancer, weight 4 and 1
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "balancer": {
+ "algorithm": "chash",
+ "hash_on": "vars",
+ "key": "query_string"
+ },
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 4,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ },
+ {
+ "name": "deepseek",
+ "model": "gpt-4",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+                                    "endpoint": "http://localhost:6724/chat/completions"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: send requests and verify chash distribution
+--- config
+ location /t {
+ content_by_lua_block {
+ local http = require "resty.http"
+ local uri = "http://127.0.0.1:" .. ngx.var.server_port
+ .. "/anything"
+
+ local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+ for i = 1, 10 do
+ local httpc = http.new()
+ local query = {
+ index = i
+ }
+                local res, err = httpc:request_uri(uri, {method = "POST", body = body, query = query})
+ if not res then
+ ngx.say(err)
+ return
+ end
+ table.insert(restab, res.body)
+ end
+
+ local count = {}
+ for _, value in ipairs(restab) do
+ count[value] = (count[value] or 0) + 1
+ end
+
+ for p, num in pairs(count) do
+ ngx.log(ngx.WARN, "distribution: ", p, ": ", num)
+ end
+
+ }
+ }
+--- request
+GET /t
+--- timeout: 10
+--- error_log
+distribution: deepseek: 2
+distribution: openai: 8
+
+
+
+=== TEST 5: retry logic with different priorities
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 1,
+ "priority": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:9999"
+ }
+ },
+ {
+ "name": "deepseek",
+ "model": "gpt-4",
+ "priority": 0,
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+                                    "endpoint": "http://localhost:6724/chat/completions"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 6: verify fallback to the lower-priority provider
+--- config
+ location /t {
+ content_by_lua_block {
+ local http = require "resty.http"
+ local uri = "http://127.0.0.1:" .. ngx.var.server_port
+ .. "/anything"
+
+ local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+ local httpc = http.new()
+            local res, err = httpc:request_uri(uri, {method = "POST", body = body})
+ if not res then
+ ngx.say(err)
+ return
+ end
+ ngx.say(res.body)
+
+ }
+ }
+--- request
+GET /t
+--- response_body
+deepseek
+--- error_log
+failed to send request to LLM: failed to connect to LLM server: connection refused. Retrying...
diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t
new file mode 100644
index 000000000..68eed015d
--- /dev/null
+++ b/t/plugin/ai-proxy-multi.t
@@ -0,0 +1,723 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+ my ($block) = @_;
+
+ if (!defined $block->request) {
+ $block->set_value("request", "GET /t");
+ }
+
+ my $user_yaml_config = <<_EOC_;
+plugins:
+ - ai-proxy-multi
+_EOC_
+ $block->set_value("extra_yaml_config", $user_yaml_config);
+
+ my $http_config = $block->http_config // <<_EOC_;
+ server {
+ server_name openai;
+ listen 6724;
+
+ default_type 'application/json';
+
+ location /anything {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body = ngx.req.get_body_data()
+
+ if body ~= "SELECT * FROM STUDENTS" then
+ ngx.status = 503
+ ngx.say("passthrough doesn't work")
+ return
+ end
+ ngx.say('{"foo", "bar"}')
+ }
+ }
+
+ location /v1/chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local test_type = ngx.req.get_headers()["test-type"]
+ if test_type == "options" then
+ if body.foo == "bar" then
+ ngx.status = 200
+ ngx.say("options works")
+ else
+ ngx.status = 500
+ ngx.say("model options feature doesn't work")
+ end
+ return
+ end
+
+ local header_auth = ngx.req.get_headers()["authorization"]
+ local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                if header_auth ~= "Bearer token" and query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+                if header_auth == "Bearer token" or query_auth == "apikey" then
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ if not body.messages or #body.messages < 1 then
+ ngx.status = 400
+ ngx.say([[{ "error": "bad request"}]])
+ return
+ end
+
+                        if body.messages[1].content == "write an SQL query to get all rows from student table" then
+ ngx.print("SELECT * FROM STUDENTS")
+ return
+ end
+
+ ngx.status = 200
+ ngx.say([[$resp]])
+ return
+ end
+
+
+ ngx.status = 503
+ ngx.say("reached the end of the test suite")
+ }
+ }
+
+ location /random {
+ content_by_lua_block {
+ ngx.say("path override works")
+ }
+ }
+ }
+_EOC_
+
+ $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: minimal viable configuration
+--- config
+ location /t {
+ content_by_lua_block {
+ local plugin = require("apisix.plugins.ai-proxy-multi")
+ local ok, err = plugin.check_schema({
+ providers = {
+ {
+ name = "openai",
+ model = "gpt-4",
+ weight = 1,
+ auth = {
+ header = {
+ some_header = "some_value"
+ }
+ }
+ }
+ }
+ })
+
+ if not ok then
+ ngx.say(err)
+ else
+ ngx.say("passed")
+ end
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 2: unsupported provider
+--- config
+ location /t {
+ content_by_lua_block {
+ local plugin = require("apisix.plugins.ai-proxy-multi")
+ local ok, err = plugin.check_schema({
+ providers = {
+ {
+ name = "some-unique",
+ model = "gpt-4",
+ weight = 1,
+ auth = {
+ header = {
+ some_header = "some_value"
+ }
+ }
+ }
+ }
+ })
+
+ if not ok then
+ ngx.say(err)
+ else
+ ngx.say("passed")
+ end
+ }
+ }
+--- response_body eval
+qr/.*provider: some-unique is not supported.*/
+
+
+
+=== TEST 3: set route with wrong auth header
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 1,
+ "auth": {
+ "header": {
+                                        "Authorization": "Bearer wrongtoken"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 5: set route with right auth header
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-4",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 7: send request with empty body
+--- request
+POST /anything
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body_chomp
+failed to get request body: request body is empty
+
+
+
+=== TEST 8: send request with wrong method (GET) should work
+--- request
+GET /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 9: wrong JSON in request body should give error
+--- request
+GET /anything
+{}"messages": [ { "role": "system", "cont
+--- error_code: 400
+--- response_body
+{"message":"could not get parse JSON request body: Expected the end but found T_STRING at character 3"}
+
+
+
+=== TEST 10: content-type should be JSON
+--- request
+POST /anything
+prompt%3Dwhat%2520is%25201%2520%252B%25201
+--- more_headers
+Content-Type: application/x-www-form-urlencoded
+--- error_code: 400
+--- response_body chomp
+unsupported content-type: application/x-www-form-urlencoded
+
+
+
+=== TEST 11: request schema validity check
+--- request
+POST /anything
+{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+request format doesn't match schema: property "messages" is required
+
+
+
+=== TEST 12: model options being merged to request body
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "some-model",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "foo": "bar",
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+
+ local code, body, actual_body = t("/anything",
+ ngx.HTTP_POST,
+ [[{
+ "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+ { "role": "user", "content": "What is 1+1?" }
+ ]
+ }]],
+ nil,
+ {
+ ["test-type"] = "options",
+ ["Content-Type"] = "application/json",
+ }
+ )
+
+ ngx.status = code
+ ngx.say(actual_body)
+
+ }
+ }
+--- error_code: 200
+--- response_body_chomp
+options works
+
+
+
+=== TEST 13: override path
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "some-model",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "foo": "bar",
+ "temperature": 1.0
+ },
+ "override": {
+                                    "endpoint": "http://localhost:6724/random"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+
+ local code, body, actual_body = t("/anything",
+ ngx.HTTP_POST,
+ [[{
+ "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+ { "role": "user", "content": "What is 1+1?" }
+ ]
+ }]],
+ nil,
+ {
+ ["test-type"] = "path",
+ ["Content-Type"] = "application/json",
+ }
+ )
+
+ ngx.status = code
+ ngx.say(actual_body)
+
+ }
+ }
+--- response_body_chomp
+path override works
+
+
+
+=== TEST 14: set route with passthrough enabled
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false,
+ "passthrough": true
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:6724": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 15: send request and verify the LLM response is passed through to the upstream
+--- request
+POST /anything
+{ "messages": [ { "role": "user", "content": "write an SQL query to get all rows from student table" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body
+{"foo", "bar"}
+
+
+
+=== TEST 16: set route with stream = true (SSE)
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0,
+ "stream": true
+ },
+ "override": {
+ "endpoint": "http://localhost:7737"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 17: test if SSE works as expected
+--- config
+ location /t {
+ content_by_lua_block {
+ local http = require("resty.http")
+ local httpc = http.new()
+ local core = require("apisix.core")
+
+ local ok, err = httpc:connect({
+ scheme = "http",
+ host = "localhost",
+ port = ngx.var.server_port,
+ })
+
+ if not ok then
+ ngx.status = 500
+ ngx.say(err)
+ return
+ end
+
+ local params = {
+ method = "POST",
+ headers = {
+ ["Content-Type"] = "application/json",
+ },
+ path = "/anything",
+ body = [[{
+ "messages": [
+ { "role": "system", "content": "some content" }
+ ]
+ }]],
+ }
+
+ local res, err = httpc:request(params)
+ if not res then
+ ngx.status = 500
+ ngx.say(err)
+ return
+ end
+
+ local final_res = {}
+ while true do
+                local chunk, err = res.body_reader() -- will read chunk by chunk
+ if err then
+ core.log.error("failed to read response chunk: ", err)
+ break
+ end
+ if not chunk then
+ break
+ end
+ core.table.insert_tail(final_res, chunk)
+ end
+
+ ngx.print(#final_res .. final_res[6])
+ }
+ }
+--- response_body_like eval
+qr/6data: \[DONE\]\n\n/
diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t
new file mode 100644
index 000000000..af5c4e880
--- /dev/null
+++ b/t/plugin/ai-proxy-multi2.t
@@ -0,0 +1,361 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+ my ($block) = @_;
+
+ if (!defined $block->request) {
+ $block->set_value("request", "GET /t");
+ }
+
+ my $user_yaml_config = <<_EOC_;
+plugins:
+ - ai-proxy-multi
+_EOC_
+ $block->set_value("extra_yaml_config", $user_yaml_config);
+
+ my $http_config = $block->http_config // <<_EOC_;
+ server {
+ server_name openai;
+ listen 6724;
+
+ default_type 'application/json';
+
+ location /v1/chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local query_auth = ngx.req.get_uri_args()["api_key"]
+
+ if query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+
+ ngx.status = 200
+ ngx.say("passed")
+ }
+ }
+
+
+ location /test/params/in/overridden/endpoint {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+ local core = require("apisix.core")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+                    ngx.say("Unsupported request method: ", ngx.req.get_method())
+ end
+
+ local query_auth = ngx.req.get_uri_args()["api_key"]
+                ngx.log(ngx.INFO, "found query params: ", core.json.stably_encode(ngx.req.get_uri_args()))
+
+ if query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+ ngx.status = 200
+ ngx.say("passed")
+ }
+ }
+ }
+_EOC_
+
+ $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route with wrong query param
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "query": {
+ "api_key": "wrong_key"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 2: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 3: set route with right query param
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "query": {
+ "api_key": "apikey"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- response_body
+passed
+
+
+
+=== TEST 5: set route without overriding the endpoint
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "header": {
+ "Authorization": "some-key"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "httpbin.org": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- custom_trusted_cert: /etc/ssl/cert.pem
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+
+
+
+=== TEST 7: query params in override.endpoint should be sent to LLM
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy-multi": {
+ "providers": [
+ {
+ "name": "openai",
+ "model": "gpt-35-turbo-instruct",
+ "weight": 1,
+ "auth": {
+ "query": {
+ "api_key": "apikey"
+ }
+ },
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ },
+ "override": {
+                                    "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes"
+ }
+ }
+ ],
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 8: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- error_log
+found query params: {"api_key":"apikey","some_query":"yes"}
+--- response_body
+passed