This is an automated email from the ASF dual-hosted git repository.
shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
new 0efdb8e5a feat(ai-proxy): support embeddings API (#12062)
0efdb8e5a is described below
commit 0efdb8e5af47b11c4d4947f3b2430274660e26de
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Wed Mar 19 11:42:04 2025 +0545
feat(ai-proxy): support embeddings API (#12062)
---
apisix/plugins/ai-drivers/openai-base.lua | 6 --
apisix/plugins/ai-proxy/schema.lua | 37 +-------
docs/en/latest/plugins/ai-proxy-multi.md | 6 --
docs/en/latest/plugins/ai-proxy.md | 6 --
t/plugin/ai-proxy-multi.t | 20 +----
t/plugin/ai-proxy.t | 135 ++++++++++++++++++++++++++----
6 files changed, 124 insertions(+), 86 deletions(-)
diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua
index a34261202..a4f061fe4 100644
--- a/apisix/plugins/ai-drivers/openai-base.lua
+++ b/apisix/plugins/ai-drivers/openai-base.lua
@@ -25,7 +25,6 @@ local CONTENT_TYPE_JSON = "application/json"
local core = require("apisix.core")
local http = require("resty.http")
local url = require("socket.url")
-local schema = require("apisix.plugins.ai-drivers.schema")
local ngx_re = require("ngx.re")
local ngx_print = ngx.print
@@ -59,11 +58,6 @@ function _M.validate_request(ctx)
return nil, err
end
- local ok, err = core.schema.check(schema.chat_request_schema,
request_table)
- if not ok then
- return nil, "request format doesn't match schema: " .. err
- end
-
return request_table, nil
end
diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua
index 1b9d07b1c..1bd44da04 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -42,43 +42,8 @@ local model_options_schema = {
type = "string",
description = "Model to execute.",
},
- max_tokens = {
- type = "integer",
- description = "Defines the max_tokens, if using chat or completion
models.",
- default = 256
-
- },
- input_cost = {
- type = "number",
- description = "Defines the cost per 1M tokens in your prompt.",
- minimum = 0
-
- },
- output_cost = {
- type = "number",
- description = "Defines the cost per 1M tokens in the output of the
AI.",
- minimum = 0
-
- },
- temperature = {
- type = "number",
- description = "Defines the matching temperature, if using chat or
completion models.",
- minimum = 0.0,
- maximum = 5.0,
-
- },
- top_p = {
- type = "number",
- description = "Defines the top-p probability mass, if supported.",
- minimum = 0,
- maximum = 1,
-
- },
- stream = {
- description = "Stream response by SSE",
- type = "boolean",
- }
},
+ additionalProperties = true,
}
local ai_instance_schema = {
diff --git a/docs/en/latest/plugins/ai-proxy-multi.md b/docs/en/latest/plugins/ai-proxy-multi.md
index f977f85db..a23eccb55 100644
--- a/docs/en/latest/plugins/ai-proxy-multi.md
+++ b/docs/en/latest/plugins/ai-proxy-multi.md
@@ -63,12 +63,6 @@ Proxying requests to OpenAI is supported now. Other LLM services will be support
| provider.auth | Yes | object | Authentication
details, including headers and query parameters.
| |
| provider.auth.header | No | object | Authentication
details sent via headers. Header name must match `^[a-zA-Z0-9._-]+$`.
| |
| provider.auth.query | No | object | Authentication
details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`.
| |
-| provider.options.max_tokens | No | integer | Defines the maximum
tokens for chat or completion models.
| 256 |
-| provider.options.input_cost | No | number | Cost per 1M tokens
in the input prompt. Minimum is 0.
| |
-| provider.options.output_cost | No | number | Cost per 1M tokens
in the AI-generated output. Minimum is 0.
| |
-| provider.options.temperature | No | number | Defines the model's
temperature (0.0 - 5.0) for randomness in responses.
| |
-| provider.options.top_p | No | number | Defines the top-p
probability mass (0 - 1) for nucleus sampling.
| |
-| provider.options.stream | No | boolean | Enables streaming
responses via SSE.
| |
| provider.override.endpoint | No | string | Custom host
override for the AI provider.
| |
| timeout | No | integer | Request timeout in
milliseconds (1-60000).
| 30000 |
| keepalive | No | boolean | Enables keepalive
connections.
| true |
diff --git a/docs/en/latest/plugins/ai-proxy.md b/docs/en/latest/plugins/ai-proxy.md
index 66740cf7e..fe6bb2f6b 100644
--- a/docs/en/latest/plugins/ai-proxy.md
+++ b/docs/en/latest/plugins/ai-proxy.md
@@ -56,12 +56,6 @@ Proxying requests to OpenAI is supported now. Other LLM services will be support
| model.provider | Yes | String | Name of the AI service
provider (`openai`). |
| model.name | Yes | String | Model name to execute.
|
| model.options | No | Object | Key/value settings for
the model |
-| model.options.max_tokens | No | Integer | Defines the max tokens
if using chat or completion models. Default: 256 |
-| model.options.input_cost | No | Number | Cost per 1M tokens in
your prompt. Minimum: 0 |
-| model.options.output_cost | No | Number | Cost per 1M tokens in
the output of the AI. Minimum: 0 |
-| model.options.temperature | No | Number | Matching temperature
for models. Range: 0.0 - 5.0 |
-| model.options.top_p | No | Number | Top-p probability
mass. Range: 0 - 1 |
-| model.options.stream | No | Boolean | Stream response by
SSE. |
| override.endpoint | No | String | Override the endpoint
of the AI provider |
| timeout | No | Integer | Timeout in
milliseconds for requests to LLM. Range: 1 - 60000. Default: 30000 |
| keepalive | No | Boolean | Enable keepalive for
requests to LLM. Default: true |
diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t
index a70392e60..83f3444cf 100644
--- a/t/plugin/ai-proxy-multi.t
+++ b/t/plugin/ai-proxy-multi.t
@@ -360,19 +360,7 @@ unsupported content-type: application/x-www-form-urlencoded, only application/js
-=== TEST 11: request schema validity check
---- request
-POST /anything
-{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
---- more_headers
-Authorization: Bearer token
---- error_code: 400
---- response_body chomp
-request format doesn't match schema: property "messages" is required
-
-
-
-=== TEST 12: model options being merged to request body
+=== TEST 11: model options being merged to request body
--- config
location /t {
content_by_lua_block {
@@ -441,7 +429,7 @@ options_works
-=== TEST 13: override path
+=== TEST 12: override path
--- config
location /t {
content_by_lua_block {
@@ -509,7 +497,7 @@ path override works
-=== TEST 14: set route with stream = true (SSE)
+=== TEST 13: set route with stream = true (SSE)
--- config
location /t {
content_by_lua_block {
@@ -558,7 +546,7 @@ passed
-=== TEST 15: test is SSE works as expected
+=== TEST 14: test is SSE works as expected
--- config
location /t {
content_by_lua_block {
diff --git a/t/plugin/ai-proxy.t b/t/plugin/ai-proxy.t
index c5696a2fb..c99a6c11e 100644
--- a/t/plugin/ai-proxy.t
+++ b/t/plugin/ai-proxy.t
@@ -106,6 +106,66 @@ add_block_preprocessor(sub {
}
}
+ location /v1/embeddings {
+ content_by_lua_block {
+ if ngx.req.get_method() ~= "POST" then
+ ngx.status = 400
+ ngx.say("unsupported request method: ",
ngx.req.get_method())
+ end
+
+ local header_auth = ngx.req.get_headers()["authorization"]
+ if header_auth ~= "Bearer token" then
+ ngx.status = 401
+ ngx.say("unauthorized")
+ return
+ end
+
+ ngx.req.read_body()
+ local body, err = ngx.req.get_body_data()
+ local json = require("cjson.safe")
+ body, err = json.decode(body)
+ if err then
+ ngx.status = 400
+ ngx.say("failed to get request body: ", err)
+ end
+
+ if body.model ~= "text-embedding-ada-002" then
+ ngx.status = 400
+ ngx.say("unsupported model: ", body.model)
+ return
+ end
+
+ if body.encoding_format ~= "float" then
+ ngx.status = 400
+ ngx.say("unsupported encoding format: ",
body.encoding_format)
+ return
+ end
+
+ ngx.status = 200
+ ngx.say([[
+ {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "embedding": [
+ 0.0023064255,
+ -0.009327292,
+ -0.0028842222
+ ],
+ "index": 0
+ }
+ ],
+ "model": "text-embedding-ada-002",
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8
+ }
+ }
+ ]])
+ }
+ }
+
location /random {
content_by_lua_block {
ngx.say("path override works")
@@ -330,19 +390,7 @@ unsupported content-type: application/x-www-form-urlencoded, only application/js
-=== TEST 11: request schema validity check
---- request
-POST /anything
-{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
---- more_headers
-Authorization: Bearer token
---- error_code: 400
---- response_body chomp
-request format doesn't match schema: property "messages" is required
-
-
-
-=== TEST 12: model options being merged to request body
+=== TEST 11: model options being merged to request body
--- config
location /t {
content_by_lua_block {
@@ -405,7 +453,7 @@ options_works
-=== TEST 13: override path
+=== TEST 12: override path
--- config
location /t {
content_by_lua_block {
@@ -467,7 +515,7 @@ path override works
-=== TEST 14: set route with stream = true (SSE)
+=== TEST 13: set route with stream = true (SSE)
--- config
location /t {
content_by_lua_block {
@@ -510,7 +558,7 @@ passed
-=== TEST 15: test is SSE works as expected
+=== TEST 14: test is SSE works as expected
--- config
location /t {
content_by_lua_block {
@@ -568,3 +616,58 @@ passed
}
--- response_body_like eval
qr/6data: \[DONE\]\n\n/
+
+
+
+=== TEST 15: proxy embedding endpoint
+--- config
+ location /t {
+ content_by_lua_block {
+ local t = require("lib.test_admin").test
+ local code, body = t('/apisix/admin/routes/1',
+ ngx.HTTP_PUT,
+ [[{
+ "uri": "/embeddings",
+ "plugins": {
+ "ai-proxy": {
+ "provider": "openai",
+ "auth": {
+ "header": {
+ "Authorization": "Bearer token"
+ }
+ },
+ "options": {
+ "model": "text-embedding-ada-002",
+ "encoding_format": "float"
+ },
+ "override": {
+ "endpoint":
"http://localhost:6724/v1/embeddings"
+ }
+ }
+ }
+ }]]
+ )
+
+ if code >= 300 then
+ ngx.status = code
+ ngx.say(body)
+ return
+ end
+
+ ngx.say("passed")
+ }
+ }
+--- response_body
+passed
+
+
+
+=== TEST 16: send request to embedding api
+--- request
+POST /embeddings
+{
+ "input": "The food was delicious and the waiter..."
+}
+--- error_code: 200
+--- response_body_like eval
+qr/.*text-embedding-ada-002*/