This is an automated email from the ASF dual-hosted git repository.
shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
new d46737fe7 feat: ai-proxy plugin (#11499)
d46737fe7 is described below
commit d46737fe70b6ce332146a9eb322e76997c8fa8ba
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Tue Sep 17 10:08:58 2024 +0545
feat: ai-proxy plugin (#11499)
---
Makefile | 6 +
apisix/cli/config.lua | 1 +
apisix/core/request.lua | 16 +
apisix/plugins/ai-proxy.lua | 138 ++++++
apisix/plugins/ai-proxy/drivers/openai.lua | 85 ++++
apisix/plugins/ai-proxy/schema.lua | 154 +++++++
ci/common.sh | 21 +
ci/linux_openresty_common_runner.sh | 2 +
ci/redhat-ci.sh | 2 +
conf/config.yaml.example | 1 +
docs/en/latest/config.json | 3 +-
docs/en/latest/plugins/ai-proxy.md | 144 ++++++
t/admin/plugins.t | 1 +
t/assets/ai-proxy-response.json | 15 +
t/plugin/ai-proxy.t | 693 +++++++++++++++++++++++++++++
t/plugin/ai-proxy2.t | 200 +++++++++
t/sse_server_example/go.mod | 3 +
t/sse_server_example/main.go | 58 +++
18 files changed, 1542 insertions(+), 1 deletion(-)
diff --git a/Makefile b/Makefile
index 21a238963..545a21e4f 100644
--- a/Makefile
+++ b/Makefile
@@ -374,6 +374,12 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/utils
$(ENV_INSTALL) apisix/utils/*.lua $(ENV_INST_LUADIR)/apisix/utils/
+ $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
+ $(ENV_INSTALL) apisix/plugins/ai-proxy/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy
+
+ $(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+ $(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
+
$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 6ab10c925..f5c5d8dca 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -219,6 +219,7 @@ local _M = {
"proxy-rewrite",
"workflow",
"api-breaker",
+ "ai-proxy",
"limit-conn",
"limit-count",
"limit-req",
diff --git a/apisix/core/request.lua b/apisix/core/request.lua
index c5278b6b8..fef4bf17e 100644
--- a/apisix/core/request.lua
+++ b/apisix/core/request.lua
@@ -21,6 +21,7 @@
local lfs = require("lfs")
local log = require("apisix.core.log")
+local json = require("apisix.core.json")
local io = require("apisix.core.io")
local req_add_header
if ngx.config.subsystem == "http" then
@@ -334,6 +335,21 @@ function _M.get_body(max_size, ctx)
end
+function _M.get_json_request_body_table()
+ local body, err = _M.get_body()
+ if not body then
+ return nil, { message = "could not get body: " .. (err or "request
body is empty") }
+ end
+
+ local body_tab, err = json.decode(body)
+ if not body_tab then
+ return nil, { message = "could not get parse JSON request body: " ..
err }
+ end
+
+ return body_tab
+end
+
+
function _M.get_scheme(ctx)
if not ctx then
ctx = ngx.ctx.api_ctx
diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua
new file mode 100644
index 000000000..8a0d8fa97
--- /dev/null
+++ b/apisix/plugins/ai-proxy.lua
@@ -0,0 +1,138 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local core = require("apisix.core")
+local schema = require("apisix.plugins.ai-proxy.schema")
+local require = require
+local pcall = pcall
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local bad_request = ngx.HTTP_BAD_REQUEST
+local ngx_req = ngx.req
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
+
+local plugin_name = "ai-proxy"
+local _M = {
+ version = 0.5,
+ priority = 999,
+ name = plugin_name,
+ schema = schema,
+}
+
+
+function _M.check_schema(conf)
+ local ai_driver = pcall(require, "apisix.plugins.ai-proxy.drivers." .. conf.model.provider)
+ if not ai_driver then
+ return false, "provider: " .. conf.model.provider .. " is not
supported."
+ end
+ return core.schema.check(schema.plugin_schema, conf)
+end
+
+
+local CONTENT_TYPE_JSON = "application/json"
+
+
+local function keepalive_or_close(conf, httpc)
+ if conf.set_keepalive then
+ httpc:set_keepalive(10000, 100)
+ return
+ end
+ httpc:close()
+end
+
+
+function _M.access(conf, ctx)
+ local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+ if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+ return bad_request, "unsupported content-type: " .. ct
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return bad_request, err
+ end
+
+ local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+ if not ok then
+ return bad_request, "request format doesn't match schema: " .. err
+ end
+
+ if conf.model.name then
+ request_table.model = conf.model.name
+ end
+
+ if core.table.try_read_attr(conf, "model", "options", "stream") then
+ request_table.stream = true
+ end
+
+ local ai_driver = require("apisix.plugins.ai-proxy.drivers." ..
conf.model.provider)
+ local res, err, httpc = ai_driver.request(conf, request_table, ctx)
+ if not res then
+ core.log.error("failed to send request to LLM service: ", err)
+ return internal_server_error
+ end
+
+ local body_reader = res.body_reader
+ if not body_reader then
+ core.log.error("LLM sent no response body")
+ return internal_server_error
+ end
+
+ if conf.passthrough then
+ ngx_req.init_body()
+ while true do
+ local chunk, err = body_reader() -- will read chunk by chunk
+ if err then
+ core.log.error("failed to read response chunk: ", err)
+ break
+ end
+ if not chunk then
+ break
+ end
+ ngx_req.append_body(chunk)
+ end
+ ngx_req.finish_body()
+ keepalive_or_close(conf, httpc)
+ return
+ end
+
+ if request_table.stream then
+ while true do
+ local chunk, err = body_reader() -- will read chunk by chunk
+ if err then
+ core.log.error("failed to read response chunk: ", err)
+ break
+ end
+ if not chunk then
+ break
+ end
+ ngx_print(chunk)
+ ngx_flush(true)
+ end
+ keepalive_or_close(conf, httpc)
+ return
+ else
+ local res_body, err = res:read_body()
+ if not res_body then
+ core.log.error("failed to read response body: ", err)
+ return internal_server_error
+ end
+ keepalive_or_close(conf, httpc)
+ return res.status, res_body
+ end
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy/drivers/openai.lua b/apisix/plugins/ai-proxy/drivers/openai.lua
new file mode 100644
index 000000000..c8f7f4b62
--- /dev/null
+++ b/apisix/plugins/ai-proxy/drivers/openai.lua
@@ -0,0 +1,85 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local _M = {}
+
+local core = require("apisix.core")
+local http = require("resty.http")
+local url = require("socket.url")
+
+local pairs = pairs
+
+-- globals
+local DEFAULT_HOST = "api.openai.com"
+local DEFAULT_PORT = 443
+local DEFAULT_PATH = "/v1/chat/completions"
+
+
+function _M.request(conf, request_table, ctx)
+ local httpc, err = http.new()
+ if not httpc then
+ return nil, "failed to create http client to send request to LLM
server: " .. err
+ end
+ httpc:set_timeout(conf.timeout)
+
+ local endpoint = core.table.try_read_attr(conf, "override", "endpoint")
+ local parsed_url = {}
+ if endpoint then
+ parsed_url = url.parse(endpoint)
+ end
+
+ local ok, err = httpc:connect({
+ scheme = parsed_url.scheme or "https",
+ host = parsed_url.host or DEFAULT_HOST,
+ port = parsed_url.port or DEFAULT_PORT,
+ ssl_verify = conf.ssl_verify,
+ ssl_server_name = parsed_url.host or DEFAULT_HOST,
+ pool_size = conf.keepalive and conf.keepalive_pool,
+ })
+
+ if not ok then
+ return nil, "failed to connect to LLM server: " .. err
+ end
+
+ local path = (parsed_url.path or DEFAULT_PATH)
+
+ local headers = (conf.auth.header or {})
+ headers["Content-Type"] = "application/json"
+ local params = {
+ method = "POST",
+ headers = headers,
+ keepalive = conf.keepalive,
+ ssl_verify = conf.ssl_verify,
+ path = path,
+ query = conf.auth.query
+ }
+
+ if conf.model.options then
+ for opt, val in pairs(conf.model.options) do
+ request_table[opt] = val
+ end
+ end
+ params.body = core.json.encode(request_table)
+
+ local res, err = httpc:request(params)
+ if not res then
+ return nil, err
+ end
+
+ return res, nil, httpc
+end
+
+return _M
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
new file mode 100644
index 000000000..382644dc2
--- /dev/null
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -0,0 +1,154 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+local _M = {}
+
+local auth_item_schema = {
+ type = "object",
+ patternProperties = {
+ ["^[a-zA-Z0-9._-]+$"] = {
+ type = "string"
+ }
+ }
+}
+
+local auth_schema = {
+ type = "object",
+ patternProperties = {
+ header = auth_item_schema,
+ query = auth_item_schema,
+ },
+ additionalProperties = false,
+}
+
+local model_options_schema = {
+ description = "Key/value settings for the model",
+ type = "object",
+ properties = {
+ max_tokens = {
+ type = "integer",
+ description = "Defines the max_tokens, if using chat or completion
models.",
+ default = 256
+
+ },
+ input_cost = {
+ type = "number",
+ description = "Defines the cost per 1M tokens in your prompt.",
+ minimum = 0
+
+ },
+ output_cost = {
+ type = "number",
+ description = "Defines the cost per 1M tokens in the output of the
AI.",
+ minimum = 0
+
+ },
+ temperature = {
+ type = "number",
+ description = "Defines the matching temperature, if using chat or
completion models.",
+ minimum = 0.0,
+ maximum = 5.0,
+
+ },
+ top_p = {
+ type = "number",
+ description = "Defines the top-p probability mass, if supported.",
+ minimum = 0,
+ maximum = 1,
+
+ },
+ stream = {
+ description = "Stream response by SSE",
+ type = "boolean",
+ default = false,
+ }
+ }
+}
+
+local model_schema = {
+ type = "object",
+ properties = {
+ provider = {
+ type = "string",
+ description = "Name of the AI service provider.",
+ oneOf = { "openai" }, -- add more providers later
+
+ },
+ name = {
+ type = "string",
+ description = "Model name to execute.",
+ },
+ options = model_options_schema,
+ override = {
+ type = "object",
+ properties = {
+ endpoint = {
+ type = "string",
+ description = "To be specified to override the host of the
AI provider",
+ },
+ }
+ }
+ },
+ required = {"provider", "name"}
+}
+
+_M.plugin_schema = {
+ type = "object",
+ properties = {
+ auth = auth_schema,
+ model = model_schema,
+ passthrough = { type = "boolean", default = false },
+ timeout = {
+ type = "integer",
+ minimum = 1,
+ maximum = 60000,
+ default = 3000,
+ description = "timeout in milliseconds",
+ },
+ keepalive = {type = "boolean", default = true},
+ keepalive_timeout = {type = "integer", minimum = 1000, default =
60000},
+ keepalive_pool = {type = "integer", minimum = 1, default = 30},
+ ssl_verify = {type = "boolean", default = true },
+ },
+ required = {"model", "auth"}
+}
+
+_M.chat_request_schema = {
+ type = "object",
+ properties = {
+ messages = {
+ type = "array",
+ minItems = 1,
+ items = {
+ properties = {
+ role = {
+ type = "string",
+ enum = {"system", "user", "assistant"}
+ },
+ content = {
+ type = "string",
+ minLength = "1",
+ },
+ },
+ additionalProperties = false,
+ required = {"role", "content"},
+ },
+ }
+ },
+ required = {"messages"}
+}
+
+return _M
diff --git a/ci/common.sh b/ci/common.sh
index 146b7aa50..ae5d12b2b 100644
--- a/ci/common.sh
+++ b/ci/common.sh
@@ -203,3 +203,24 @@ function start_grpc_server_example() {
ss -lntp | grep 10051 | grep grpc_server && break
done
}
+
+
+function start_sse_server_example() {
+ # build sse_server_example
+ pushd t/sse_server_example
+ go build
+ ./sse_server_example 7737 2>&1 &
+
+ for (( i = 0; i <= 10; i++ )); do
+ sleep 0.5
+ SSE_PROC=`ps -ef | grep sse_server_example | grep -v grep || echo "none"`
+ if [[ $SSE_PROC == "none" || "$i" -eq 10 ]]; then
+ echo "failed to start sse_server_example"
+ ss -antp | grep 7737 || echo "no proc listen port 7737"
+ exit 1
+ else
+ break
+ fi
+ done
+ popd
+}
diff --git a/ci/linux_openresty_common_runner.sh b/ci/linux_openresty_common_runner.sh
index ea2e8b41c..1b73ceec9 100755
--- a/ci/linux_openresty_common_runner.sh
+++ b/ci/linux_openresty_common_runner.sh
@@ -77,6 +77,8 @@ script() {
start_grpc_server_example
+ start_sse_server_example
+
# APISIX_ENABLE_LUACOV=1 PERL5LIB=.:$PERL5LIB prove -Itest-nginx/lib -r t
FLUSH_ETCD=1 TEST_EVENTS_MODULE=$TEST_EVENTS_MODULE prove --timer -Itest-nginx/lib -I./ -r $TEST_FILE_SUB_DIR | tee /tmp/test.result
rerun_flaky_tests /tmp/test.result
diff --git a/ci/redhat-ci.sh b/ci/redhat-ci.sh
index 3cad10b59..da9839d4e 100755
--- a/ci/redhat-ci.sh
+++ b/ci/redhat-ci.sh
@@ -77,6 +77,8 @@ install_dependencies() {
yum install -y iproute procps
start_grpc_server_example
+ start_sse_server_example
+
# installing grpcurl
install_grpcurl
diff --git a/conf/config.yaml.example b/conf/config.yaml.example
index da125f77d..bd741b2f7 100644
--- a/conf/config.yaml.example
+++ b/conf/config.yaml.example
@@ -486,6 +486,7 @@ plugins: # plugin list (sorted by priority)
- limit-count # priority: 1002
- limit-req # priority: 1001
#- node-status # priority: 1000
+ - ai-proxy # priority: 999
#- brotli # priority: 996
- gzip # priority: 995
- server-info # priority: 990
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index 2195688a3..ad9c1e051 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -96,7 +96,8 @@
"plugins/fault-injection",
"plugins/mocking",
"plugins/degraphql",
- "plugins/body-transformer"
+ "plugins/body-transformer",
+ "plugins/ai-proxy"
]
},
{
diff --git a/docs/en/latest/plugins/ai-proxy.md b/docs/en/latest/plugins/ai-proxy.md
new file mode 100644
index 000000000..a6a4e3542
--- /dev/null
+++ b/docs/en/latest/plugins/ai-proxy.md
@@ -0,0 +1,144 @@
+---
+title: ai-proxy
+keywords:
+ - Apache APISIX
+ - API Gateway
+ - Plugin
+ - ai-proxy
+description: This document contains information about the Apache APISIX ai-proxy Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The `ai-proxy` plugin simplifies access to LLM providers and models by defining a standard request format
+that allows key fields in plugin configuration to be embedded into the request.
+
+Proxying requests to OpenAI is currently supported; other LLM services will be supported soon.
+
+## Request Format
+
+### OpenAI
+
+- Chat API
+
+| Name               | Type   | Required | Description                                          |
+| ------------------ | ------ | -------- | ---------------------------------------------------- |
+| `messages`         | Array  | Yes      | An array of message objects                          |
+| `messages.role`    | String | Yes      | Role of the message (`system`, `user`, `assistant`)  |
+| `messages.content` | String | Yes      | Content of the message                               |
+
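+A minimal request body in this format looks like the following (a sketch for illustration; the message contents are arbitrary):
+
+```json
+{
+  "messages": [
+    { "role": "system", "content": "You are a mathematician" },
+    { "role": "user", "content": "What is 1+1?" }
+  ]
+}
+```
+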
+## Plugin Attributes
+
+| **Field**                 | **Required** | **Type** | **Description**                                                                       |
+| ------------------------- | ------------ | -------- | ------------------------------------------------------------------------------------- |
+| auth                      | Yes          | Object   | Authentication configuration                                                          |
+| auth.header               | No           | Object   | Authentication headers. Key must match pattern `^[a-zA-Z0-9._-]+$`.                   |
+| auth.query                | No           | Object   | Authentication query parameters. Key must match pattern `^[a-zA-Z0-9._-]+$`.          |
+| model.provider            | Yes          | String   | Name of the AI service provider (`openai`).                                           |
+| model.name                | Yes          | String   | Model name to execute.                                                                |
+| model.options             | No           | Object   | Key/value settings for the model                                                      |
+| model.options.max_tokens  | No           | Integer  | Defines the max tokens if using chat or completion models. Default: 256               |
+| model.options.input_cost  | No           | Number   | Cost per 1M tokens in your prompt. Minimum: 0                                         |
+| model.options.output_cost | No           | Number   | Cost per 1M tokens in the output of the AI. Minimum: 0                                |
+| model.options.temperature | No           | Number   | Matching temperature for models. Range: 0.0 - 5.0                                     |
+| model.options.top_p       | No           | Number   | Top-p probability mass. Range: 0 - 1                                                  |
+| model.options.stream      | No           | Boolean  | Stream response by SSE. Default: false                                                |
+| model.override.endpoint   | No           | String   | Override the endpoint of the AI provider                                              |
+| passthrough               | No           | Boolean  | If enabled, the response from LLM will be sent to the upstream. Default: false        |
+| timeout                   | No           | Integer  | Timeout in milliseconds for requests to LLM. Range: 1 - 60000. Default: 3000          |
+| keepalive                 | No           | Boolean  | Enable keepalive for requests to LLM. Default: true                                   |
+| keepalive_timeout         | No           | Integer  | Keepalive timeout in milliseconds for requests to LLM. Minimum: 1000. Default: 60000  |
+| keepalive_pool            | No           | Integer  | Keepalive pool size for requests to LLM. Minimum: 1. Default: 30                      |
+| ssl_verify                | No           | Boolean  | SSL verification for requests to LLM. Default: true                                   |
+
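+Keys under `auth.header` are sent as request headers, and keys under `auth.query` are appended as query parameters when the plugin contacts the provider. A sketch of query-based authentication (the `api_key` parameter name is illustrative; use whatever your provider expects):
+
+```json
+{
+  "ai-proxy": {
+    "auth": {
+      "query": {
+        "api_key": "<your-api-key>"
+      }
+    },
+    "model": {
+      "provider": "openai",
+      "name": "gpt-4"
+    }
+  }
+}
+```
+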
+## Example usage
+
+Create a route with the `ai-proxy` plugin like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \
+ -H "X-API-KEY: ${ADMIN_API_KEY}" \
+ -d '{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer <some-token>"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-4",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ }
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "somerandom.com:443": 1
+ },
+ "scheme": "https",
+ "pass_host": "node"
+ }
+ }'
+```
+
+Since `passthrough` is not enabled, the upstream node can be any arbitrary value because it won't be contacted.
+
+Now send a request:
+
+```shell
+curl http://127.0.0.1:9080/anything -i -XPOST -H 'Content-Type: application/json' -d '{
+ "messages": [
+ { "role": "system", "content": "You are a mathematician" },
+ { "role": "user", "a": 1, "content": "What is 1+1?" }
+ ]
+ }'
+```
+
+You will receive a response like this:
+
+```json
+{
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": "The sum of \\(1 + 1\\) is \\(2\\).",
+ "role": "assistant"
+ }
+ }
+ ],
+ "created": 1723777034,
+ "id": "chatcmpl-9whRKFodKl5sGhOgHIjWltdeB8sr7",
+ "model": "gpt-4o-2024-05-13",
+ "object": "chat.completion",
+ "system_fingerprint": "fp_abc28019ad",
+ "usage": { "completion_tokens": 15, "prompt_tokens": 23, "total_tokens": 38 }
+}
+```
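+
+To stream responses instead, set `model.options.stream` to `true`; the plugin then relays the provider's Server-Sent Events chunks to the client as they arrive. A sketch of consuming such a stream (assuming a route configured as above but with `"stream": true` in `model.options`; `-N` disables curl's output buffering):
+
+```shell
+curl http://127.0.0.1:9080/anything -N -XPOST -H 'Content-Type: application/json' -d '{
+    "messages": [
+      { "role": "user", "content": "What is 1+1?" }
+    ]
+  }'
+```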
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index ef43ea9f3..bf3d485e8 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -102,6 +102,7 @@ api-breaker
limit-conn
limit-count
limit-req
+ai-proxy
gzip
server-info
traffic-split
diff --git a/t/assets/ai-proxy-response.json b/t/assets/ai-proxy-response.json
new file mode 100644
index 000000000..94665e5ea
--- /dev/null
+++ b/t/assets/ai-proxy-response.json
@@ -0,0 +1,15 @@
+{
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": { "content": "1 + 1 = 2.", "role": "assistant" }
+ }
+ ],
+ "created": 1723780938,
+ "id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P",
+ "model": "gpt-4o-2024-05-13",
+ "object": "chat.completion",
+ "system_fingerprint": "fp_abc28019ad",
+ "usage": { "completion_tokens": 8, "prompt_tokens": 23, "total_tokens": 31 }
+}
diff --git a/t/plugin/ai-proxy.t b/t/plugin/ai-proxy.t
new file mode 100644
index 000000000..445e406f6
--- /dev/null
+++ b/t/plugin/ai-proxy.t
@@ -0,0 +1,693 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+ my ($block) = @_;
+
+ if (!defined $block->request) {
+ $block->set_value("request", "GET /t");
+ }
+
+ my $http_config = $block->http_config // <<_EOC_;
+ server {
+ server_name openai;
+ listen 6724;
+
+ default_type 'application/json';
+
+ location /anything {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+ ngx.say("Unsupported request method: ",
ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body = ngx.req.get_body_data()
+
+ if body ~= "SELECT * FROM STUDENTS" then
+ ngx.status = 503
+ ngx.say("passthrough doesn't work")
+ return
+ end
+ ngx.say('{"foo", "bar"}')
+ }
+ }
+
+ location /v1/chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+ ngx.say("Unsupported request method: ",
ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local test_type = ngx.req.get_headers()["test-type"]
+ if test_type == "options" then
+ if body.foo == "bar" then
+ ngx.status = 200
+ ngx.say("options works")
+ else
+ ngx.status = 500
+ ngx.say("model options feature doesn't work")
+ end
+ return
+ end
+
+ local header_auth = ngx.req.get_headers()["authorization"]
+ local query_auth = ngx.req.get_uri_args()["apikey"]
+
+ if header_auth ~= "Bearer token" and query_auth ~=
"apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+ if header_auth == "Bearer token" or query_auth == "apikey"
then
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ if not body.messages or #body.messages < 1 then
+ ngx.status = 400
+ ngx.say([[{ "error": "bad request"}]])
+ return
+ end
+
+ if body.messages[1].content == "write an SQL query to
get all rows from student table" then
+ ngx.print("SELECT * FROM STUDENTS")
+ return
+ end
+
+ ngx.status = 200
+ ngx.say([[$resp]])
+ return
+ end
+
+
+ ngx.status = 503
+ ngx.say("reached the end of the test suite")
+ }
+ }
+
+ location /random {
+ content_by_lua_block {
+ ngx.say("path override works")
+ }
+ }
+ }
+_EOC_
+
+ $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: minimal viable configuration
+--- config
+ location /t {
+ content_by_lua_block {
+ local plugin = require("apisix.plugins.ai-proxy")
+ local ok, err = plugin.check_schema({
+ model = {
+ provider = "openai",
+ name = "gpt-4",
+ },
+ auth = {
+ header = {
+ some_header = "some_value"
+ }
+ }
+ })
+
+ if not ok then
+ ngx.say(err)
+ else
+ ngx.say("passed")
+ end
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 2: unsupported provider
+--- config
+ location /t {
+ content_by_lua_block {
+ local plugin = require("apisix.plugins.ai-proxy")
+ local ok, err = plugin.check_schema({
+ model = {
+ provider = "some-unique",
+ name = "gpt-4",
+ },
+ auth = {
+ header = {
+ some_header = "some_value"
+ }
+ }
+ })
+
+ if not ok then
+ ngx.say(err)
+ else
+ ngx.say("passed")
+ end
+ }
+ }
+--- response_body eval
+qr/.*provider: some-unique is not supported.*/
+
+
+
+=== TEST 3: set route with wrong auth header
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer wrongtoken"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, {
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 5: set route with right auth header
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 6: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, {
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 7: send request with empty body
+--- request
+POST /anything
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+{"message":"could not get body: request body is empty"}
+
+
+
+=== TEST 8: send request with wrong method (GET) should work
+--- request
+GET /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, {
"role": "user", "content": "What is 1+1?"} ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body eval
+qr/\{ "content": "1 \+ 1 = 2\.", "role": "assistant" \}/
+
+
+
+=== TEST 9: wrong JSON in request body should give error
+--- request
+GET /anything
+{}"messages": [ { "role": "system", "cont
+--- error_code: 400
+--- response_body
+{"message":"could not get parse JSON request body: Expected the end but found
T_STRING at character 3"}
+
+
+
+=== TEST 10: content-type should be JSON
+--- request
+POST /anything
+prompt%3Dwhat%2520is%25201%2520%252B%25201
+--- more_headers
+Content-Type: application/x-www-form-urlencoded
+--- error_code: 400
+--- response_body chomp
+unsupported content-type: application/x-www-form-urlencoded
+
+
+
+=== TEST 11: request schema validity check
+--- request
+POST /anything
+{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 400
+--- response_body chomp
+request format doesn't match schema: property "messages" is required
+
+
+
+=== TEST 12: model options being merged to request body
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "some-model",
+ "options": {
+ "foo": "bar",
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+
+ local code, body, actual_body = t("/anything",
+ ngx.HTTP_POST,
+ [[{
+ "messages": [
+ { "role": "system", "content": "You are a
mathematician" },
+ { "role": "user", "content": "What is 1+1?" }
+ ]
+ }]],
+ nil,
+ {
+ ["test-type"] = "options",
+ ["Content-Type"] = "application/json",
+ }
+ )
+
+ ngx.status = code
+ ngx.say(actual_body)
+
+ }
+ }
+--- error_code: 200
+--- response_body chomp
+options works
+
+
+
+=== TEST 13: override path
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "some-model",
+ "options": {
+ "foo": "bar",
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724/random"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+
+ local code, body, actual_body = t("/anything",
+ ngx.HTTP_POST,
+ [[{
+ "messages": [
+ { "role": "system", "content": "You are a
mathematician" },
+ { "role": "user", "content": "What is 1+1?" }
+ ]
+ }]],
+ nil,
+ {
+ ["test-type"] = "path",
+ ["Content-Type"] = "application/json",
+ }
+ )
+
+ ngx.status = code
+ ngx.say(actual_body)
+
+ }
+ }
+--- response_body chomp
+path override works
+
+
+
+=== TEST 14: set route with passthrough enabled
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false,
+ "passthrough": true
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "127.0.0.1:6724": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 15: send request to passthrough route
+--- request
+POST /anything
+{ "messages": [ { "role": "user", "content": "write an SQL query to get all
rows from student table" } ] }
+--- more_headers
+Authorization: Bearer token
+--- error_code: 200
+--- response_body
+{"foo", "bar"}
+
+
+
+=== TEST 16: set route with stream = true (SSE)
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0,
+ "stream": true
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:7737"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 17: test if SSE works as expected
+--- config
+ location /t {
+ content_by_lua_block {
+ local http = require("resty.http")
+ local httpc = http.new()
+ local core = require("apisix.core")
+
+ local ok, err = httpc:connect({
+ scheme = "http",
+ host = "localhost",
+ port = ngx.var.server_port,
+ })
+
+ if not ok then
+ ngx.status = 500
+ ngx.say(err)
+ return
+ end
+
+ local params = {
+ method = "POST",
+ headers = {
+ ["Content-Type"] = "application/json",
+ },
+ path = "/anything",
+ body = [[{
+ "messages": [
+ { "role": "system", "content": "some content" }
+ ]
+ }]],
+ }
+
+ local res, err = httpc:request(params)
+ if not res then
+ ngx.status = 500
+ ngx.say(err)
+ return
+ end
+
+ local final_res = {}
+ while true do
local chunk, err = res.body_reader() -- will read chunk by chunk
+ if err then
+ core.log.error("failed to read response chunk: ", err)
+ break
+ end
+ if not chunk then
+ break
+ end
+ core.table.insert_tail(final_res, chunk)
+ end
+
+ ngx.print(#final_res .. final_res[6])
+ }
+ }
+--- response_body_like eval
+qr/6data: \[DONE\]\n\n/
diff --git a/t/plugin/ai-proxy2.t b/t/plugin/ai-proxy2.t
new file mode 100644
index 000000000..6e398e566
--- /dev/null
+++ b/t/plugin/ai-proxy2.t
@@ -0,0 +1,200 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+use t::APISIX 'no_plan';
+
+log_level("info");
+repeat_each(1);
+no_long_string();
+no_root_location();
+
+
+my $resp_file = 't/assets/ai-proxy-response.json';
+open(my $fh, '<', $resp_file) or die "Could not open file '$resp_file' $!";
+my $resp = do { local $/; <$fh> };
+close($fh);
+
+print "Hello, World!\n";
+print $resp;
+
+
+add_block_preprocessor(sub {
+ my ($block) = @_;
+
+ if (!defined $block->request) {
+ $block->set_value("request", "GET /t");
+ }
+
+ my $http_config = $block->http_config // <<_EOC_;
+ server {
+ server_name openai;
+ listen 6724;
+
+ default_type 'application/json';
+
+ location /v1/chat/completions {
+ content_by_lua_block {
+ local json = require("cjson.safe")
+
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+ ngx.say("Unsupported request method: ",
ngx.req.get_method())
+ end
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ body, err = json.decode(body)
+
+ local query_auth = ngx.req.get_uri_args()["api_key"]
+
+ if query_auth ~= "apikey" then
+ ngx.status = 401
+ ngx.say("Unauthorized")
+ return
+ end
+
+
+ ngx.status = 200
+ ngx.say("passed")
+ }
+ }
+ }
+_EOC_
+
+ $block->set_value("http_config", $http_config);
+});
+
+run_tests();
+
+__DATA__
+
+=== TEST 1: set route with wrong query param
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "query": {
+ "api_key": "wrong_key"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 2: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, {
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 401
+--- response_body
+Unauthorized
+
+
+
+=== TEST 3: set route with right query param
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/anything",
+ "plugins": {
+ "ai-proxy": {
+ "auth": {
+ "query": {
+ "api_key": "apikey"
+ }
+ },
+ "model": {
+ "provider": "openai",
+ "name": "gpt-35-turbo-instruct",
+ "options": {
+ "max_tokens": 512,
+ "temperature": 1.0
+ }
+ },
+ "override": {
+ "endpoint": "http://localhost:6724"
+ },
+ "ssl_verify": false
+ }
+ },
+ "upstream": {
+ "type": "roundrobin",
+ "nodes": {
+ "canbeanything.com": 1
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ end
+ ngx.say(body)
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 4: send request
+--- request
+POST /anything
+{ "messages": [ { "role": "system", "content": "You are a mathematician" }, {
"role": "user", "content": "What is 1+1?"} ] }
+--- error_code: 200
+--- response_body
+passed
diff --git a/t/sse_server_example/go.mod b/t/sse_server_example/go.mod
new file mode 100644
index 000000000..9cc909d03
--- /dev/null
+++ b/t/sse_server_example/go.mod
@@ -0,0 +1,3 @@
+module foo.bar/apache/sse_server_example
+
+go 1.17
diff --git a/t/sse_server_example/main.go b/t/sse_server_example/main.go
new file mode 100644
index 000000000..ab976c860
--- /dev/null
+++ b/t/sse_server_example/main.go
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package main
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "time"
+)
+
+func sseHandler(w http.ResponseWriter, r *http.Request) {
+ // Set the headers for SSE
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+
+ f, ok := w.(http.Flusher)
+ if !ok {
+ fmt.Fprintf(w, "[ERROR]")
+ return
+ }
+ // A simple loop that sends a message every 500ms
+ for i := 0; i < 5; i++ {
+ // Create a message to send to the client
+ fmt.Fprintf(w, "data: %s\n\n", time.Now().Format(time.RFC3339))
+
+ // Flush the data immediately to the client
+ f.Flush()
+ time.Sleep(500 * time.Millisecond)
+ }
+ fmt.Fprintf(w, "data: %s\n\n", "[DONE]")
+}
+
+func main() {
+ // Create a simple route
+ http.HandleFunc("/v1/chat/completions", sseHandler)
+ port := os.Args[1]
+ // Start the server
+ log.Println("Starting server on :", port)
+ log.Fatal(http.ListenAndServe(":" + port, nil))
+}