This is an automated email from the ASF dual-hosted git repository.

nic443 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git


The following commit(s) were added to refs/heads/master by this push:
     new cc7441fb1 feat(plugin): support `ai-proxy-multi` (#11986)
cc7441fb1 is described below

commit cc7441fb1e89a234994489bbd5ead8440b94ccb3
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Mon Feb 24 08:46:10 2025 +0545

    feat(plugin): support `ai-proxy-multi` (#11986)
---
 Makefile                                           |   4 +-
 apisix/cli/config.lua                              |   1 +
 apisix/plugins/ai-drivers/deepseek.lua             |  24 +
 .../openai-compatible.lua}                         |  40 +-
 apisix/plugins/ai-drivers/openai.lua               |  24 +
 apisix/plugins/ai-proxy-multi.lua                  | 236 +++++++
 apisix/plugins/ai-proxy.lua                        |  31 +-
 apisix/plugins/ai-proxy/schema.lua                 |  88 ++-
 conf/config.yaml.example                           |   1 +
 docs/en/latest/config.json                         |   1 +
 docs/en/latest/plugins/ai-proxy-multi.md           | 195 ++++++
 t/admin/plugins.t                                  |   1 +
 t/plugin/ai-proxy-multi.balancer.t                 | 470 ++++++++++++++
 t/plugin/ai-proxy-multi.t                          | 723 +++++++++++++++++++++
 t/plugin/ai-proxy-multi2.t                         | 361 ++++++++++
 15 files changed, 2176 insertions(+), 24 deletions(-)

diff --git a/Makefile b/Makefile
index a24e8f7b8..c288463c9 100644
--- a/Makefile
+++ b/Makefile
@@ -374,8 +374,8 @@ install: runtime
        $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
        $(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
 
-       $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
-       $(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+       $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers
+       $(ENV_INSTALL) apisix/plugins/ai-drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-drivers
 
        $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
        $(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 6a05fed5d..376b5ed15 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -223,6 +223,7 @@ local _M = {
     "workflow",
     "api-breaker",
     "ai-proxy",
+    "ai-proxy-multi",
     "limit-conn",
     "limit-count",
     "limit-req",
diff --git a/apisix/plugins/ai-drivers/deepseek.lua b/apisix/plugins/ai-drivers/deepseek.lua
new file mode 100644
index 000000000..ab441c636
--- /dev/null
+++ b/apisix/plugins/ai-drivers/deepseek.lua
@@ -0,0 +1,24 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+return require("apisix.plugins.ai-drivers.openai-compatible").new(
+    {
+        host = "api.deepseek.com",
+        path = "/chat/completions",
+        port = 443
+    }
+)
diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-drivers/openai-compatible.lua
similarity index 74%
rename from apisix/plugins/ai-proxy/drivers/openai.lua
rename to apisix/plugins/ai-drivers/openai-compatible.lua
index af0bc9758..fd5d2163c 100644
--- a/apisix/plugins/ai-proxy/drivers/openai.lua
+++ b/apisix/plugins/ai-drivers/openai-compatible.lua
@@ -16,27 +16,38 @@
 --
 local _M = {}
 
+local mt = {
+    __index = _M
+}
+
 local core = require("apisix.core")
 local http = require("resty.http")
 local url  = require("socket.url")
 
 local pairs = pairs
 local type  = type
+local setmetatable = setmetatable
+
+
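+-- Build a driver instance bound to a provider's default host, port and path.
+-- A per-route `override.endpoint` can still replace these at request time.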
+function _M.new(opts)
 
--- globals
-local DEFAULT_HOST = "api.openai.com"
-local DEFAULT_PORT = 443
-local DEFAULT_PATH = "/v1/chat/completions"
+    local self = {
+        host = opts.host,
+        port = opts.port,
+        path = opts.path,
+    }
+    return setmetatable(self, mt)
+end
 
 
-function _M.request(conf, request_table, ctx)
+function _M.request(self, conf, request_table, extra_opts)
     local httpc, err = http.new()
     if not httpc then
         return nil, "failed to create http client to send request to LLM 
server: " .. err
     end
     httpc:set_timeout(conf.timeout)
 
-    local endpoint = core.table.try_read_attr(conf, "override", "endpoint")
+    local endpoint = extra_opts.endpoint
     local parsed_url
     if endpoint then
         parsed_url = url.parse(endpoint)
@@ -44,10 +55,10 @@ function _M.request(conf, request_table, ctx)
 
     local ok, err = httpc:connect({
         scheme = endpoint and parsed_url.scheme or "https",
-        host = endpoint and parsed_url.host or DEFAULT_HOST,
-        port = endpoint and parsed_url.port or DEFAULT_PORT,
+        host = endpoint and parsed_url.host or self.host,
+        port = endpoint and parsed_url.port or self.port,
         ssl_verify = conf.ssl_verify,
-        ssl_server_name = endpoint and parsed_url.host or DEFAULT_HOST,
+        ssl_server_name = endpoint and parsed_url.host or self.host,
         pool_size = conf.keepalive and conf.keepalive_pool,
     })
 
@@ -55,7 +66,7 @@ function _M.request(conf, request_table, ctx)
         return nil, "failed to connect to LLM server: " .. err
     end
 
-    local query_params = conf.auth.query or {}
+    local query_params = extra_opts.query_params
 
     if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query 
> 0 then
         local args_tab = core.string.decode_args(parsed_url.query)
@@ -64,9 +75,9 @@ function _M.request(conf, request_table, ctx)
         end
     end
 
-    local path = (endpoint and parsed_url.path or DEFAULT_PATH)
+    local path = (endpoint and parsed_url.path or self.path)
 
-    local headers = (conf.auth.header or {})
+    local headers = extra_opts.headers
     headers["Content-Type"] = "application/json"
     local params = {
         method = "POST",
@@ -77,13 +88,14 @@ function _M.request(conf, request_table, ctx)
         query = query_params
     }
 
-    if conf.model.options then
-        for opt, val in pairs(conf.model.options) do
+    if extra_opts.model_options then
+        for opt, val in pairs(extra_opts.model_options) do
             request_table[opt] = val
         end
     end
     params.body = core.json.encode(request_table)
 
+    httpc:set_timeout(conf.keepalive_timeout)
     local res, err = httpc:request(params)
     if not res then
         return nil, err
diff --git a/apisix/plugins/ai-drivers/openai.lua b/apisix/plugins/ai-drivers/openai.lua
new file mode 100644
index 000000000..785ede193
--- /dev/null
+++ b/apisix/plugins/ai-drivers/openai.lua
@@ -0,0 +1,24 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+return require("apisix.plugins.ai-drivers.openai-compatible").new(
+    {
+        host = "api.openai.com",
+        path = "/v1/chat/completions",
+        port = 443
+    }
+)
diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua
new file mode 100644
index 000000000..48f0dea94
--- /dev/null
+++ b/apisix/plugins/ai-proxy-multi.lua
@@ -0,0 +1,236 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+local core = require("apisix.core")
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ai_proxy = require("apisix.plugins.ai-proxy")
+local plugin = require("apisix.plugin")
+
+local require = require
+local pcall = pcall
+local ipairs = ipairs
+local unpack = unpack
+local type = type
+
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local priority_balancer = require("apisix.balancer.priority")
+
+local pickers = {}
+local lrucache_server_picker = core.lrucache.new({
+    ttl = 300, count = 256
+})
+
+local plugin_name = "ai-proxy-multi"
+local _M = {
+    version = 0.5,
+    priority = 998,
+    name = plugin_name,
+    schema = schema.ai_proxy_multi_schema,
+}
+
+
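+-- Map a `hash_on` value to the schema used to validate the chash key.
+-- "consumer" needs no key; any other unknown value is an error.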
+local function get_chash_key_schema(hash_on)
+    if hash_on == "vars" then
+        return core.schema.upstream_hash_vars_schema
+    end
+
+    if hash_on == "header" or hash_on == "cookie" then
+        return core.schema.upstream_hash_header_schema
+    end
+
+    if hash_on == "consumer" then
+        return nil, nil
+    end
+
+    if hash_on == "vars_combinations" then
+        return core.schema.upstream_hash_vars_combinations_schema
+    end
+
+    return nil, "invalid hash_on type " .. hash_on
+end
+
+
+function _M.check_schema(conf)
+    for _, provider in ipairs(conf.providers) do
+        local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. provider.name)
+        if not ai_driver then
+            return false, "provider: " .. provider.name .. " is not supported."
+        end
+    end
+    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+    local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+    local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+
+    if type(algo) == "string" and algo == "chash" then
+        if not hash_on then
+            return false, "must configure `hash_on` when balancer algorithm is 
chash"
+        end
+
+        if hash_on ~= "consumer" and not hash_key then
+            return false, "must configure `hash_key` when balancer `hash_on` 
is not set to cookie"
+        end
+
+        local key_schema, err = get_chash_key_schema(hash_on)
+        if err then
+            return false, "type is chash, err: " .. err
+        end
+
+        if key_schema then
+            local ok, err = core.schema.check(key_schema, hash_key)
+            if not ok then
+                return false, "invalid configuration: " .. err
+            end
+        end
+    end
+
+    return core.schema.check(schema.ai_proxy_multi_schema, conf)
+end
+
+
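+-- Group providers by priority: new_providers[priority][name] = weight, with
+-- _priority_index recording each distinct priority level seen.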
+local function transform_providers(new_providers, provider)
+    if not new_providers._priority_index then
+        new_providers._priority_index = {}
+    end
+
+    if not new_providers[provider.priority] then
+        new_providers[provider.priority] = {}
+        core.table.insert(new_providers._priority_index, provider.priority)
+    end
+
+    new_providers[provider.priority][provider.name] = provider.weight
+end
+
+
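+-- Build a server picker over the configured providers. A single priority
+-- level uses the configured algorithm directly; multiple levels are wrapped
+-- by the priority balancer to provide fallback.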
+local function create_server_picker(conf, ups_tab)
+    local picker = pickers[conf.balancer.algorithm] -- nil check
+    if not picker then
+        pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm)
+        picker = pickers[conf.balancer.algorithm]
+    end
+    local new_providers = {}
+    for i, provider in ipairs(conf.providers) do
+        transform_providers(new_providers, provider)
+    end
+
+    if #new_providers._priority_index > 1 then
+        core.log.info("new providers: ", core.json.delay_encode(new_providers))
+        return priority_balancer.new(new_providers, ups_tab, picker)
+    end
+    core.log.info("upstream nodes: ",
+                
core.json.delay_encode(new_providers[new_providers._priority_index[1]]))
+    return picker.new(new_providers[new_providers._priority_index[1]], ups_tab)
+end
+
+
+local function get_provider_conf(providers, name)
+    for i, provider in ipairs(providers) do
+        if provider.name == name then
+            return provider
+        end
+    end
+end
+
+
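+-- Pick a provider for this request. On a retry, first report the previous
+-- failure back to the picker so it can adjust (e.g. fall back by priority).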
+local function pick_target(ctx, conf, ups_tab)
+    ctx.ai_balancer_try_count = (ctx.ai_balancer_try_count or 0) + 1
+    if ctx.ai_balancer_try_count > 1 then
+        if ctx.server_picker and ctx.server_picker.after_balance then
+            ctx.server_picker.after_balance(ctx, true)
+        end
+    end
+
+    local server_picker = ctx.server_picker
+    if not server_picker then
+        server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf),
+                                               create_server_picker, conf, ups_tab)
+    end
+    if not server_picker then
+        return internal_server_error, "failed to fetch server picker"
+    end
+
+    local provider_name = server_picker.get(ctx)
+    local provider_conf = get_provider_conf(conf.providers, provider_name)
+
+    ctx.balancer_server = provider_name
+    ctx.server_picker = server_picker
+
+    return provider_name, provider_conf
+end
+
+
+local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+    local provider_name, provider_conf
+    if #conf.providers == 1 then
+        provider_name = conf.providers[1].name
+        provider_conf = conf.providers[1]
+    else
+        provider_name, provider_conf = pick_target(ctx, conf, ups_tab)
+    end
+
+    core.log.info("picked provider: ", provider_name)
+    if provider_conf.model then
+        request_table.model = provider_conf.model
+    end
+
+    provider_conf.__name = provider_name
+    return provider_name, provider_conf
+end
+
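+-- Override ai-proxy's get_model_name with a no-op: the model name is set per
+-- picked provider in get_load_balanced_provider instead.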
+ai_proxy.get_model_name = function (...)
+end
+
+
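+-- Override ai-proxy's proxy_request_to_llm: pick a provider via the balancer,
+-- assemble its auth/override options, and retry with another pick on failure.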
+ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+    local ups_tab = {}
+    local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
+    if algo == "chash" then
+        local hash_on = core.table.try_read_attr(conf, "balancer", "hash_on")
+        local hash_key = core.table.try_read_attr(conf, "balancer", "key")
+        ups_tab["key"] = hash_key
+        ups_tab["hash_on"] = hash_on
+    end
+
+    ::retry::
+    local provider, provider_conf = get_load_balanced_provider(ctx, conf, ups_tab, request_table)
+    local extra_opts = {
+        endpoint = core.table.try_read_attr(provider_conf, "override", "endpoint"),
+        query_params = provider_conf.auth.query or {},
+        headers = (provider_conf.auth.header or {}),
+        model_options = provider_conf.options,
+    }
+
+    local ai_driver = require("apisix.plugins.ai-drivers." .. provider)
+    local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
+    if not res then
+        if (ctx.balancer_try_count or 0) < 1 then
+            core.log.warn("failed to send request to LLM: ", err, ". 
Retrying...")
+            goto retry
+        end
+        return nil, err, nil
+    end
+
+    request_table.model = provider_conf.model
+    return res, nil, httpc
+end
+
+
+function _M.access(conf, ctx)
+    local rets = {ai_proxy.access(conf, ctx)}
+    return unpack(rets)
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua
index 8a0d8fa97..c27ca9a3b 100644
--- a/apisix/plugins/ai-proxy.lua
+++ b/apisix/plugins/ai-proxy.lua
@@ -34,11 +34,11 @@ local _M = {
 
 
 function _M.check_schema(conf)
-    local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. 
conf.model.provider)
+    local ai_driver = pcall(require, "apisix.plugins.ai-drivers." .. 
conf.model.provider)
     if not ai_driver then
         return false, "provider: " .. conf.model.provider .. " is not 
supported."
     end
-    return core.schema.check(schema.plugin_schema, conf)
+    return core.schema.check(schema.ai_proxy_schema, conf)
 end
 
 
@@ -54,6 +54,26 @@ local function keepalive_or_close(conf, httpc)
 end
 
 
+function _M.get_model_name(conf)
+    return conf.model.name
+end
+
+
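+-- Kept module-level (rather than local) so wrapper plugins such as
+-- ai-proxy-multi can override it.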
+function _M.proxy_request_to_llm(conf, request_table, ctx)
+    local ai_driver = require("apisix.plugins.ai-drivers." .. 
conf.model.provider)
+    local extra_opts = {
+        endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
+        query_params = conf.auth.query or {},
+        headers = (conf.auth.header or {}),
+        model_options = conf.model.options
+    }
+    local res, err, httpc = ai_driver:request(conf, request_table, extra_opts)
+    if not res then
+        return nil, err, nil
+    end
+    return res, nil, httpc
+end
+
 function _M.access(conf, ctx)
     local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
     if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
@@ -70,16 +90,13 @@ function _M.access(conf, ctx)
         return bad_request, "request format doesn't match schema: " .. err
     end
 
-    if conf.model.name then
-        request_table.model = conf.model.name
-    end
+    request_table.model = _M.get_model_name(conf)
 
     if core.table.try_read_attr(conf, "model", "options", "stream") then
         request_table.stream = true
     end
 
-    local ai_driver = require("apisix.plugins.ai-proxy.drivers." .. 
conf.model.provider)
-    local res, err, httpc = ai_driver.request(conf, request_table, ctx)
+    local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
     if not res then
         core.log.error("failed to send request to LLM service: ", err)
         return internal_server_error
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
index 382644dc2..d0ba33fdc 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -105,7 +105,48 @@ local model_schema = {
     required = {"provider", "name"}
 }
 
-_M.plugin_schema = {
+local provider_schema = {
+    type = "array",
+    minItems = 1,
+    items = {
+        type = "object",
+        properties = {
+            name = {
+                type = "string",
+                description = "Name of the AI service provider.",
+                enum = { "openai", "deepseek" }, -- add more providers later
+
+            },
+            model = {
+                type = "string",
+                description = "Model to execute.",
+            },
+            priority = {
+                type = "integer",
+                description = "Priority of the provider for load balancing",
+                default = 0,
+            },
+            weight = {
+                type = "integer",
+            },
+            auth = auth_schema,
+            options = model_options_schema,
+            override = {
+                type = "object",
+                properties = {
+                    endpoint = {
+                        type = "string",
+                        description = "To be specified to override the host of 
the AI provider",
+                    },
+                },
+            },
+        },
+        required = {"name", "model", "auth"}
+    },
+}
+
+
+_M.ai_proxy_schema = {
     type = "object",
     properties = {
         auth = auth_schema,
@@ -126,6 +167,51 @@ _M.plugin_schema = {
     required = {"model", "auth"}
 }
 
+_M.ai_proxy_multi_schema = {
+    type = "object",
+    properties = {
+        balancer = {
+            type = "object",
+            properties = {
+                algorithm = {
+                    type = "string",
+                    enum = { "chash", "roundrobin" },
+                },
+                hash_on = {
+                    type = "string",
+                    default = "vars",
+                    enum = {
+                      "vars",
+                      "header",
+                      "cookie",
+                      "consumer",
+                      "vars_combinations",
+                    },
+                },
+                key = {
+                    description = "the key of chash for dynamic load 
balancing",
+                    type = "string",
+                },
+            },
+            default = { algorithm = "roundrobin" }
+        },
+        providers = provider_schema,
+        passthrough = { type = "boolean", default = false },
+        timeout = {
+            type = "integer",
+            minimum = 1,
+            maximum = 60000,
+            default = 3000,
+            description = "timeout in milliseconds",
+        },
+        keepalive = {type = "boolean", default = true},
+        keepalive_timeout = {type = "integer", minimum = 1000, default = 
60000},
+        keepalive_pool = {type = "integer", minimum = 1, default = 30},
+        ssl_verify = {type = "boolean", default = true },
+    },
+    required = {"providers", }
+}
+
 _M.chat_request_schema = {
     type = "object",
     properties = {
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index 8052beef6..780340dcb 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -491,6 +491,7 @@ plugins:                           # plugin list (sorted by priority)
   - limit-req                      # priority: 1001
   #- node-status                   # priority: 1000
   - ai-proxy                       # priority: 999
+  - ai-proxy-multi                 # priority: 998
   #- brotli                        # priority: 996
   - gzip                           # priority: 995
   - server-info                    # priority: 990
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index a17a6ae48..c8bf09ca7 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -100,6 +100,7 @@
             "plugins/degraphql",
             "plugins/body-transformer",
             "plugins/ai-proxy",
+            "plugins/ai-proxy-multi",
             "plugins/attach-consumer-label",
             "plugins/ai-rag"
           ]
diff --git a/docs/en/latest/plugins/ai-proxy-multi.md b/docs/en/latest/plugins/ai-proxy-multi.md
new file mode 100644
index 000000000..72d8a9cfa
--- /dev/null
+++ b/docs/en/latest/plugins/ai-proxy-multi.md
@@ -0,0 +1,195 @@
+---
+title: ai-proxy-multi
+keywords:
+  - Apache APISIX
+  - API Gateway
+  - Plugin
+  - ai-proxy-multi
+description: This document contains information about the Apache APISIX ai-proxy-multi Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-proxy-multi` plugin simplifies access to LLM providers and models by defining a standard request format
+that allows key fields in the plugin configuration to be embedded into the request.
+
+This plugin adds features such as load balancing and retries on top of the existing `ai-proxy` plugin.
+
+Proxying requests to OpenAI and DeepSeek is currently supported. Other LLM services will be supported soon.
+
+## Request Format
+
+### OpenAI
+
+- Chat API
+
+| Name               | Type   | Required | Description                                          |
+| ------------------ | ------ | -------- | ---------------------------------------------------- |
+| `messages`         | Array  | Yes      | An array of message objects                          |
+| `messages.role`    | String | Yes      | Role of the message (`system`, `user`, `assistant`)  |
+| `messages.content` | String | Yes      | Content of the message                               |
+
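+For example, a minimal chat request body looks like:
+
+```json
+{
+  "messages": [
+    { "role": "system", "content": "You are a mathematician" },
+    { "role": "user", "content": "What is 1+1?" }
+  ]
+}
+```
+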
+## Plugin Attributes
+
+| **Name**                     | **Required** | **Type** | **Description**                                                                                        | **Default** |
+| ---------------------------- | ------------ | -------- | ------------------------------------------------------------------------------------------------------ | ----------- |
+| providers                    | Yes          | array    | List of AI providers, each following the provider schema.                                              |             |
+| provider.name                | Yes          | string   | Name of the AI service provider. Allowed values: `openai`, `deepseek`.                                 |             |
+| provider.model               | Yes          | string   | Name of the AI model to execute. Example: `gpt-4o`.                                                    |             |
+| provider.priority            | No           | integer  | Priority of the provider for load balancing.                                                           | 0           |
+| provider.weight              | No           | integer  | Load balancing weight.                                                                                 |             |
+| balancer.algorithm           | No           | string   | Load balancing algorithm. Allowed values: `chash`, `roundrobin`.                                       | roundrobin  |
+| balancer.hash_on             | No           | string   | What to hash on for consistent hashing (`vars`, `header`, `cookie`, `consumer`, `vars_combinations`).  | vars        |
+| balancer.key                 | No           | string   | Key for consistent hashing in dynamic load balancing.                                                  |             |
+| provider.auth                | Yes          | object   | Authentication details, including headers and query parameters.                                        |             |
+| provider.auth.header         | No           | object   | Authentication details sent via headers. Header names must match `^[a-zA-Z0-9._-]+$`.                  |             |
+| provider.auth.query          | No           | object   | Authentication details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`.                 |             |
+| provider.options.max_tokens  | No           | integer  | Defines the maximum tokens for chat or completion models.                                              | 256         |
+| provider.options.input_cost  | No           | number   | Cost per 1M tokens in the input prompt. Minimum is 0.                                                  |             |
+| provider.options.output_cost | No           | number   | Cost per 1M tokens in the AI-generated output. Minimum is 0.                                           |             |
+| provider.options.temperature | No           | number   | Defines the model's temperature (0.0 - 5.0) for randomness in responses.                               |             |
+| provider.options.top_p       | No           | number   | Defines the top-p probability mass (0 - 1) for nucleus sampling.                                       |             |
+| provider.options.stream      | No           | boolean  | Enables streaming responses via SSE.                                                                   | false       |
+| provider.override.endpoint   | No           | string   | Custom host override for the AI provider.                                                              |             |
+| passthrough                  | No           | boolean  | If true, requests are forwarded without processing.                                                    | false       |
+| timeout                      | No           | integer  | Request timeout in milliseconds (1-60000).                                                             | 3000        |
+| keepalive                    | No           | boolean  | Enables keepalive connections.                                                                         | true        |
+| keepalive_timeout            | No           | integer  | Timeout for keepalive connections (minimum 1000 ms).                                                   | 60000       |
+| keepalive_pool               | No           | integer  | Maximum number of keepalive connections.                                                               | 30          |
+| ssl_verify                   | No           | boolean  | Enables SSL certificate verification.                                                                  | true        |
+
+## Example usage
+
+Create a route with the `ai-proxy-multi` plugin like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes"; -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "id": "ai-proxy-multi-route",
+    "uri": "/anything",
+    "methods": ["POST"],
+    "plugins": {
+      "ai-proxy-multi": {
+        "providers": [
+          {
+            "name": "openai",
+            "model": "gpt-4",
+            "weight": 1,
+            "priority": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+              }
+            },
+            "options": {
+                "max_tokens": 512,
+                "temperature": 1.0
+            }
+          },
+          {
+            "name": "deepseek",
+            "model": "deepseek-chat",
+            "weight": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+              }
+            },
+            "options": {
+                "max_tokens": 512,
+                "temperature": 1.0
+            }
+          }
+        ],
+        "passthrough": false
+      }
+    },
+    "upstream": {
+      "type": "roundrobin",
+      "nodes": {
+        "httpbin.org": 1
+      }
+    }
+  }'
+```
+
+In the above configuration, requests will be equally balanced between the `openai` and `deepseek` providers.
+
+### Retry and fallback
+
+The `priority` attribute can be adjusted to implement the fallback and retry feature.
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes"; -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "id": "ai-proxy-multi-route",
+    "uri": "/anything",
+    "methods": ["POST"],
+    "plugins": {
+      "ai-proxy-multi": {
+        "providers": [
+          {
+            "name": "openai",
+            "model": "gpt-4",
+            "weight": 1,
+            "priority": 1,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$OPENAI_API_KEY"'"
+              }
+            },
+            "options": {
+                "max_tokens": 512,
+                "temperature": 1.0
+            }
+          },
+          {
+            "name": "deepseek",
+            "model": "deepseek-chat",
+            "weight": 1,
+            "priority": 0,
+            "auth": {
+              "header": {
+                "Authorization": "Bearer '"$DEEPSEEK_API_KEY"'"
+              }
+            },
+            "options": {
+                "max_tokens": 512,
+                "temperature": 1.0
+            }
+          }
+        ],
+        "passthrough": false
+      }
+    },
+    "upstream": {
+      "type": "roundrobin",
+      "nodes": {
+        "httpbin.org": 1
+      }
+    }
+  }'
+```
+
+In the above configuration, the `priority` of the `deepseek` provider is set to `0`. This means that if the `openai` provider is unavailable, the `ai-proxy-multi` plugin will retry the request with `deepseek` on the second attempt.
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index 6c574c2a4..7cb852cbf 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -106,6 +106,7 @@ limit-conn
 limit-count
 limit-req
 ai-proxy
+ai-proxy-multi
 gzip
 server-info
 traffic-split
diff --git a/t/plugin/ai-proxy-multi.balancer.t b/t/plugin/ai-proxy-multi.balancer.t
new file mode 100644
index 000000000..da26957fb
--- /dev/null
+++ b/t/plugin/ai-proxy-multi.balancer.t
@@ -0,0 +1,470 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    my $user_yaml_config = <<_EOC_;
+plugins:
+  - ai-proxy-multi
+_EOC_
+    $block->set_value("extra_yaml_config", $user_yaml_config);
+
+    my $http_config = $block->http_config // <<_EOC_;
+        server {
+            server_name openai;
+            listen 6724;
+
+            default_type 'application/json';
+
+            location /v1/chat/completions {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    body, err = json.decode(body)
+
+                    local header_auth = ngx.req.get_headers()["authorization"]
+                    local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                    if header_auth ~= "Bearer token" and query_auth ~= 
"apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+                    if header_auth == "Bearer token" or query_auth == "apikey" 
then
+                        ngx.req.read_body()
+                        local body, err = ngx.req.get_body_data()
+                        body, err = json.decode(body)
+
+                        if not body.messages or #body.messages < 1 then
+                            ngx.status = 400
+                            ngx.say([[{ "error": "bad request"}]])
+                            return
+                        end
+
+                        ngx.status = 200
+                        ngx.print("openai")
+                        return
+                    end
+
+
+                    ngx.status = 503
+                    ngx.say("reached the end of the test suite")
+                }
+            }
+
+            location /chat/completions {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    body, err = json.decode(body)
+
+                    local header_auth = ngx.req.get_headers()["authorization"]
+                    local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                    if header_auth ~= "Bearer token" and query_auth ~= 
"apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+                    if header_auth == "Bearer token" or query_auth == "apikey" 
then
+                        ngx.req.read_body()
+                        local body, err = ngx.req.get_body_data()
+                        body, err = json.decode(body)
+
+                        if not body.messages or #body.messages < 1 then
+                            ngx.status = 400
+                            ngx.say([[{ "error": "bad request"}]])
+                            return
+                        end
+
+                        ngx.status = 200
+                        ngx.print("deepseek")
+                        return
+                    end
+
+
+                    ngx.status = 503
+                    ngx.say("reached the end of the test suite")
+                }
+            }
+        }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route with roundrobin balancer, weight 4 and 1
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-4",
+                                    "weight": 4,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724";
+                                    }
+                                },
+                                {
+                                    "name": "deepseek",
+                                    "model": "gpt-4",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": 
"http://localhost:6724/chat/completions";
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 2: test
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require "resty.http"
+            local uri = "http://127.0.0.1:"; .. ngx.var.server_port
+                        .. "/anything"
+
+            local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You 
are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+            for i = 1, 10 do
+                local httpc = http.new()
+                local res, err = httpc:request_uri(uri, {method = "POST", body 
= body})
+                if not res then
+                    ngx.say(err)
+                    return
+                end
+                table.insert(restab, res.body)
+            end
+
+            table.sort(restab)
+            ngx.log(ngx.WARN, "test picked providers: ", table.concat(restab, 
"."))
+
+        }
+    }
+--- request
+GET /t
+--- error_log
+deepseek.deepseek.openai.openai.openai.openai.openai.openai.openai.openai
+
+
+
+=== TEST 3: set route with chash balancer, weight 4 and 1
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "balancer": {
+                                "algorithm": "chash",
+                                "hash_on": "vars",
+                                "key": "query_string"
+                            },
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-4",
+                                    "weight": 4,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724";
+                                    }
+                                },
+                                {
+                                    "name": "deepseek",
+                                    "model": "gpt-4",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": 
"http://localhost:6724/chat/completions";
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 4: test
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require "resty.http"
+            local uri = "http://127.0.0.1:"; .. ngx.var.server_port
+                        .. "/anything"
+
+            local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You 
are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+            for i = 1, 10 do
+                local httpc = http.new()
+                local query = {
+                    index = i
+                }
+                local res, err = httpc:request_uri(uri, {method = "POST", body 
= body, query = query})
+                if not res then
+                    ngx.say(err)
+                    return
+                end
+                table.insert(restab, res.body)
+            end
+
+            local count = {}
+            for _, value in ipairs(restab) do
+                count[value] = (count[value] or 0) + 1
+            end
+
+            for p, num in pairs(count) do
+                ngx.log(ngx.WARN, "distribution: ", p, ": ", num)
+            end
+
+        }
+    }
+--- request
+GET /t
+--- timeout: 10
+--- error_log
+distribution: deepseek: 2
+distribution: openai: 8
+
+
+
+=== TEST 5: retry logic with different priorities
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-4",
+                                    "weight": 1,
+                                    "priority": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:9999";
+                                    }
+                                },
+                                {
+                                    "name": "deepseek",
+                                    "model": "gpt-4",
+                                    "priority": 0,
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": 
"http://localhost:6724/chat/completions";
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 6: test
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require "resty.http"
+            local uri = "http://127.0.0.1:"; .. ngx.var.server_port
+                        .. "/anything"
+
+            local restab = {}
+
+            local body = [[{ "messages": [ { "role": "system", "content": "You 
are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }]]
+                local httpc = http.new()
+                local res, err = httpc:request_uri(uri, {method = "POST", body 
= body})
+                if not res then
+                    ngx.say(err)
+                    return
+                end
+                ngx.say(res.body)
+
+        }
+    }
+--- request
+GET /t
+--- response_body
+deepseek
+--- error_log
+failed to send request to LLM: failed to connect to LLM server: connection refused. Retrying...
diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t
new file mode 100644
index 000000000..68eed015d
--- /dev/null
+++ b/t/plugin/ai-proxy-multi.t
@@ -0,0 +1,723 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    my $user_yaml_config = <<_EOC_;
+plugins:
+  - ai-proxy-multi
+_EOC_
+    $block->set_value("extra_yaml_config", $user_yaml_config);
+
+    my $http_config = $block->http_config // <<_EOC_;
+        server {
+            server_name openai;
+            listen 6724;
+
+            default_type 'application/json';
+
+            location /anything {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body = ngx.req.get_body_data()
+
+                    if body ~= "SELECT * FROM STUDENTS" then
+                        ngx.status = 503
+                        ngx.say("passthrough doesn't work")
+                        return
+                    end
+                    ngx.say('{"foo", "bar"}')
+                }
+            }
+
+            location /v1/chat/completions {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", 
ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    body, err = json.decode(body)
+
+                    local test_type = ngx.req.get_headers()["test-type"]
+                    if test_type == "options" then
+                        if body.foo == "bar" then
+                            ngx.status = 200
+                            ngx.say("options works")
+                        else
+                            ngx.status = 500
+                            ngx.say("model options feature doesn't work")
+                        end
+                        return
+                    end
+
+                    local header_auth = ngx.req.get_headers()["authorization"]
+                    local query_auth = ngx.req.get_uri_args()["apikey"]
+
+                    if header_auth ~= "Bearer token" and query_auth ~= 
"apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+                    if header_auth == "Bearer token" or query_auth == "apikey" 
then
+                        ngx.req.read_body()
+                        local body, err = ngx.req.get_body_data()
+                        body, err = json.decode(body)
+
+                        if not body.messages or #body.messages < 1 then
+                            ngx.status = 400
+                            ngx.say([[{ "error": "bad request"}]])
+                            return
+                        end
+
+                        if body.messages[1].content == "write an SQL query to 
get all rows from student table" then
+                            ngx.print("SELECT * FROM STUDENTS")
+                            return
+                        end
+
+                        ngx.status = 200
+                        ngx.say([[$resp]])
+                        return
+                    end
+
+
+                    ngx.status = 503
+                    ngx.say("reached the end of the test suite")
+                }
+            }
+
+            location /random {
+                content_by_lua_block {
+                    ngx.say("path override works")
+                }
+            }
+        }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: minimal viable configuration
+--- config
+    location /t {
+        content_by_lua_block {
+            local plugin = require("apisix.plugins.ai-proxy-multi")
+            local ok, err = plugin.check_schema({
+                providers = {
+                    {
+                        name = "openai",
+                        model = "gpt-4",
+                        weight = 1,
+                        auth = {
+                            header = {
+                                some_header = "some_value"
+                            }
+                        }
+                    }
+                }
+            })
+
+            if not ok then
+                ngx.say(err)
+            else
+                ngx.say("passed")
+            end
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 2: unsupported provider
+--- config
+    location /t {
+        content_by_lua_block {
+            local plugin = require("apisix.plugins.ai-proxy-multi")
+            local ok, err = plugin.check_schema({
+                providers = {
+                    {
+                        name = "some-unique",
+                        model = "gpt-4",
+                        weight = 1,
+                        auth = {
+                            header = {
+                                some_header = "some_value"
+                            }
+                        }
+                    }
+                }
+            })
+
+            if not ok then
+                ngx.say(err)
+            else
+                ngx.say("passed")
+            end
+        }
+    }
+--- response_body eval
+qr/.*provider: some-unique is not supported.*/
+
+
+
+=== TEST 3: set route with wrong auth header
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-4",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer wrongtoken"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 5: set route with right auth header
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-4",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 7: send request with empty body
+--- request
+POST /anything
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+failed to get request body: request body is empty
+
+
+
+=== TEST 8: send request with wrong method (GET) should work
+--- request
+GET /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 9: wrong JSON in request body should give error
+--- request
+GET /anything
+{}"messages": [ { "role": "system", "cont
+--- error_code: 400
+--- response_body
+{"message":"could not get parse JSON request body: Expected the end but found 
T_STRING at character 3"}
+
+
+
+=== TEST 10: content-type should be JSON
+--- request
+POST /anything
+prompt%3Dwhat%2520is%25201%2520%252B%25201
+--- more_headers
+Content-Type: application/x-www-form-urlencoded
+--- error_code: 400
+--- response_body chomp
+unsupported content-type: application/x-www-form-urlencoded
+
+
+
+=== TEST 11: request schema validity check
+--- request
+POST /anything
+{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+request format doesn't match schema: property "messages" is required
+
+
+
+=== TEST 12: model options are merged into the request body
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "some-model",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "foo": "bar",
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+                ngx.say(body)
+                return
+            end
+
+            local code, body, actual_body = t("/anything",
+                ngx.HTTP_POST,
+                [[{
+                    "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+                        { "role": "user", "content": "What is 1+1?" }
+                    ]
+                }]],
+                nil,
+                {
+                    ["test-type"] = "options",
+                    ["Content-Type"] = "application/json",
+                }
+            )
+
+            ngx.status = code
+            ngx.say(actual_body)
+
+        }
+    }
+--- error_code: 200
+--- response_body chomp
+options works
+
+
+
+=== TEST 13: override path
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "some-model",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "foo": "bar",
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724/random"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+                ngx.say(body)
+                return
+            end
+
+            local code, body, actual_body = t("/anything",
+                ngx.HTTP_POST,
+                [[{
+                    "messages": [
+                        { "role": "system", "content": "You are a mathematician" },
+                        { "role": "user", "content": "What is 1+1?" }
+                    ]
+                }]],
+                nil,
+                {
+                    ["test-type"] = "path",
+                    ["Content-Type"] = "application/json",
+                }
+            )
+
+            ngx.status = code
+            ngx.say(actual_body)
+
+        }
+    }
+--- response_body chomp
+path override works
+
+
+
+=== TEST 14: set route with passthrough enabled
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false,
+                            "passthrough": true
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "127.0.0.1:6724": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 15: send request with passthrough enabled
+--- request
+POST /anything
+{ "messages": [ { "role": "user", "content": "write an SQL query to get all 
rows from student table" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body
+{"foo", "bar"}
+
+
+
+=== TEST 16: set route with stream = true (SSE)
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "Bearer token"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0,
+                                        "stream": true
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:7737"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 17: test if SSE works as expected
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require("resty.http")
+            local httpc = http.new()
+            local core = require("apisix.core")
+
+            local ok, err = httpc:connect({
+                scheme = "http",
+                host = "localhost",
+                port = ngx.var.server_port,
+            })
+
+            if not ok then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+
+            local params = {
+                method = "POST",
+                headers = {
+                    ["Content-Type"] = "application/json",
+                },
+                path = "/anything",
+                body = [[{
+                    "messages": [
+                        { "role": "system", "content": "some content" }
+                    ]
+                }]],
+            }
+
+            local res, err = httpc:request(params)
+            if not res then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+
+            local final_res = {}
+            while true do
+                local chunk, err = res.body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                core.table.insert_tail(final_res, chunk)
+            end
+
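+            -- 6 chunks are expected; the sixth is the SSE terminator "data: [DONE]"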
+            ngx.print(#final_res .. final_res[6])
+        }
+    }
+--- response_body_like eval
+qr/6data: \[DONE\]\n\n/
diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t
new file mode 100644
index 000000000..af5c4e880
--- /dev/null
+++ b/t/plugin/ai-proxy-multi2.t
@@ -0,0 +1,361 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+    my ($block) = @_;
+
+    if (!defined $block->request) {
+        $block->set_value("request", "GET /t");
+    }
+
+    my $user_yaml_config = <<_EOC_;
+plugins:
+  - ai-proxy-multi
+_EOC_
+    $block->set_value("extra_yaml_config", $user_yaml_config);
+
+    my $http_config = $block->http_config // <<_EOC_;
+        server {
+            server_name openai;
+            listen 6724;
+
+            default_type 'application/json';
+
+            location /v1/chat/completions {
+                content_by_lua_block {
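+                    -- mock chat completions endpoint that authenticates via the "api_key" query parameter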
+                    local json = require("cjson.safe")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", ngx.req.get_method())
+                    end
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    body, err = json.decode(body)
+
+                    local query_auth = ngx.req.get_uri_args()["api_key"]
+
+                    if query_auth ~= "apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+
+                    ngx.status = 200
+                    ngx.say("passed")
+                }
+            }
+
+
+            location /test/params/in/overridden/endpoint {
+                content_by_lua_block {
+                    local json = require("cjson.safe")
+                    local core = require("apisix.core")
+
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("Unsupported request method: ", ngx.req.get_method())
+                    end
+
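+                    -- log the received query params so TEST 8 can assert on them via error_log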
+                    local query_auth = ngx.req.get_uri_args()["api_key"]
+                    ngx.log(ngx.INFO, "found query params: ", core.json.stably_encode(ngx.req.get_uri_args()))
+
+                    if query_auth ~= "apikey" then
+                        ngx.status = 401
+                        ngx.say("Unauthorized")
+                        return
+                    end
+
+                    ngx.status = 200
+                    ngx.say("passed")
+                }
+            }
+        }
+_EOC_
+
+    $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route with wrong query param
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "query": {
+                                            "api_key": "wrong_key"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 2: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 3: set route with right query param
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "query": {
+                                            "api_key": "apikey"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                        "endpoint": "http://localhost:6724"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- response_body
+passed
+
+
+
+=== TEST 5: set route without overriding the endpoint
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "header": {
+                                            "Authorization": "some-key"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "httpbin.org": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- custom_trusted_cert: /etc/ssl/cert.pem
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+
+
+
+=== TEST 7: query params in override.endpoint should be sent to LLM
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/anything",
+                    "plugins": {
+                        "ai-proxy-multi": {
+                            "providers": [
+                                {
+                                    "name": "openai",
+                                    "model": "gpt-35-turbo-instruct",
+                                    "weight": 1,
+                                    "auth": {
+                                        "query": {
+                                            "api_key": "apikey"
+                                        }
+                                    },
+                                    "options": {
+                                        "max_tokens": 512,
+                                        "temperature": 1.0
+                                    },
+                                    "override": {
+                                       "endpoint": "http://localhost:6724/test/params/in/overridden/endpoint?some_query=yes"
+                                    }
+                                }
+                            ],
+                            "ssl_verify": false
+                        }
+                    },
+                    "upstream": {
+                        "type": "roundrobin",
+                        "nodes": {
+                            "canbeanything.com": 1
+                        }
+                    }
+                 }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+            end
+            ngx.say(body)
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 8: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { 
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- error_log
+found query params: {"api_key":"apikey","some_query":"yes"}
+--- response_body
+passed
