This is an automated email from the ASF dual-hosted git repository.

shreemaanabhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git


The following commit(s) were added to refs/heads/master by this push:
     new 0efdb8e5a feat(ai-proxy): support embeddings API (#12062)
0efdb8e5a is described below

commit 0efdb8e5af47b11c4d4947f3b2430274660e26de
Author: Shreemaan Abhishek <[email protected]>
AuthorDate: Wed Mar 19 11:42:04 2025 +0545

    feat(ai-proxy): support embeddings API (#12062)
---
 apisix/plugins/ai-drivers/openai-base.lua |   6 --
 apisix/plugins/ai-proxy/schema.lua        |  37 +-------
 docs/en/latest/plugins/ai-proxy-multi.md  |   6 --
 docs/en/latest/plugins/ai-proxy.md        |   6 --
 t/plugin/ai-proxy-multi.t                 |  20 +----
 t/plugin/ai-proxy.t                       | 135 ++++++++++++++++++++++++++----
 6 files changed, 124 insertions(+), 86 deletions(-)

diff --git a/apisix/plugins/ai-drivers/openai-base.lua 
b/apisix/plugins/ai-drivers/openai-base.lua
index a34261202..a4f061fe4 100644
--- a/apisix/plugins/ai-drivers/openai-base.lua
+++ b/apisix/plugins/ai-drivers/openai-base.lua
@@ -25,7 +25,6 @@ local CONTENT_TYPE_JSON = "application/json"
 local core = require("apisix.core")
 local http = require("resty.http")
 local url  = require("socket.url")
-local schema = require("apisix.plugins.ai-drivers.schema")
 local ngx_re = require("ngx.re")
 
 local ngx_print = ngx.print
@@ -59,11 +58,6 @@ function _M.validate_request(ctx)
             return nil, err
         end
 
-        local ok, err = core.schema.check(schema.chat_request_schema, 
request_table)
-        if not ok then
-            return nil, "request format doesn't match schema: " .. err
-        end
-
         return request_table, nil
 end
 
diff --git a/apisix/plugins/ai-proxy/schema.lua 
b/apisix/plugins/ai-proxy/schema.lua
index 1b9d07b1c..1bd44da04 100644
--- a/apisix/plugins/ai-proxy/schema.lua
+++ b/apisix/plugins/ai-proxy/schema.lua
@@ -42,43 +42,8 @@ local model_options_schema = {
             type = "string",
             description = "Model to execute.",
         },
-        max_tokens = {
-            type = "integer",
-            description = "Defines the max_tokens, if using chat or completion 
models.",
-            default = 256
-
-        },
-        input_cost = {
-            type = "number",
-            description = "Defines the cost per 1M tokens in your prompt.",
-            minimum = 0
-
-        },
-        output_cost = {
-            type = "number",
-            description = "Defines the cost per 1M tokens in the output of the 
AI.",
-            minimum = 0
-
-        },
-        temperature = {
-            type = "number",
-            description = "Defines the matching temperature, if using chat or 
completion models.",
-            minimum = 0.0,
-            maximum = 5.0,
-
-        },
-        top_p = {
-            type = "number",
-            description = "Defines the top-p probability mass, if supported.",
-            minimum = 0,
-            maximum = 1,
-
-        },
-        stream = {
-            description = "Stream response by SSE",
-            type = "boolean",
-        }
     },
+    additionalProperties = true,
 }
 
 local ai_instance_schema = {
diff --git a/docs/en/latest/plugins/ai-proxy-multi.md 
b/docs/en/latest/plugins/ai-proxy-multi.md
index f977f85db..a23eccb55 100644
--- a/docs/en/latest/plugins/ai-proxy-multi.md
+++ b/docs/en/latest/plugins/ai-proxy-multi.md
@@ -63,12 +63,6 @@ Proxying requests to OpenAI is supported now. Other LLM 
services will be support
 | provider.auth                | Yes          | object   | Authentication 
details, including headers and query parameters.                                
               |             |
 | provider.auth.header         | No           | object   | Authentication 
details sent via headers. Header name must match `^[a-zA-Z0-9._-]+$`.           
               |             |
 | provider.auth.query          | No           | object   | Authentication 
details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`.         
               |             |
-| provider.options.max_tokens  | No           | integer  | Defines the maximum 
tokens for chat or completion models.                                           
          | 256         |
-| provider.options.input_cost  | No           | number   | Cost per 1M tokens 
in the input prompt. Minimum is 0.                                              
           |             |
-| provider.options.output_cost | No           | number   | Cost per 1M tokens 
in the AI-generated output. Minimum is 0.                                       
           |             |
-| provider.options.temperature | No           | number   | Defines the model's 
temperature (0.0 - 5.0) for randomness in responses.                            
          |             |
-| provider.options.top_p       | No           | number   | Defines the top-p 
probability mass (0 - 1) for nucleus sampling.                                  
            |             |
-| provider.options.stream      | No           | boolean  | Enables streaming 
responses via SSE.                                                              
            |             |
 | provider.override.endpoint   | No           | string   | Custom host 
override for the AI provider.                                                   
                  |             |
 | timeout                      | No           | integer  | Request timeout in 
milliseconds (1-60000).                                                         
           | 30000        |
 | keepalive                    | No           | boolean  | Enables keepalive 
connections.                                                                    
            | true        |
diff --git a/docs/en/latest/plugins/ai-proxy.md 
b/docs/en/latest/plugins/ai-proxy.md
index 66740cf7e..fe6bb2f6b 100644
--- a/docs/en/latest/plugins/ai-proxy.md
+++ b/docs/en/latest/plugins/ai-proxy.md
@@ -56,12 +56,6 @@ Proxying requests to OpenAI is supported now. Other LLM 
services will be support
 | model.provider            | Yes          | String   | Name of the AI service 
provider (`openai`).                                          |
 | model.name                | Yes          | String   | Model name to execute. 
                                                              |
 | model.options             | No           | Object   | Key/value settings for 
the model                                                     |
-| model.options.max_tokens  | No           | Integer  | Defines the max tokens 
if using chat or completion models. Default: 256              |
-| model.options.input_cost  | No           | Number   | Cost per 1M tokens in 
your prompt. Minimum: 0                                        |
-| model.options.output_cost | No           | Number   | Cost per 1M tokens in 
the output of the AI. Minimum: 0                               |
-| model.options.temperature | No           | Number   | Matching temperature 
for models. Range: 0.0 - 5.0                                    |
-| model.options.top_p       | No           | Number   | Top-p probability 
mass. Range: 0 - 1                                                 |
-| model.options.stream      | No           | Boolean  | Stream response by 
SSE.                                                              |
 | override.endpoint         | No           | String   | Override the endpoint 
of the AI provider                                             |
 | timeout                   | No           | Integer  | Timeout in 
milliseconds for requests to LLM. Range: 1 - 60000. Default: 30000         |
 | keepalive                 | No           | Boolean  | Enable keepalive for 
requests to LLM. Default: true                                  |
diff --git a/t/plugin/ai-proxy-multi.t b/t/plugin/ai-proxy-multi.t
index a70392e60..83f3444cf 100644
--- a/t/plugin/ai-proxy-multi.t
+++ b/t/plugin/ai-proxy-multi.t
@@ -360,19 +360,7 @@ unsupported content-type: 
application/x-www-form-urlencoded, only application/js
 
 
 
-=== TEST 11: request schema validity check
---- request
-POST /anything
-{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
---- more_headers
-Authorization: Bearer token
---- error_code: 400
---- response_body chomp
-request format doesn't match schema: property "messages" is required
-
-
-
-=== TEST 12: model options being merged to request body
+=== TEST 11: model options being merged to request body
 --- config
     location /t {
         content_by_lua_block {
@@ -441,7 +429,7 @@ options_works
 
 
 
-=== TEST 13: override path
+=== TEST 12: override path
 --- config
     location /t {
         content_by_lua_block {
@@ -509,7 +497,7 @@ path override works
 
 
 
-=== TEST 14: set route with stream = true (SSE)
+=== TEST 13: set route with stream = true (SSE)
 --- config
     location /t {
         content_by_lua_block {
@@ -558,7 +546,7 @@ passed
 
 
 
-=== TEST 15: test is SSE works as expected
+=== TEST 14: test is SSE works as expected
 --- config
     location /t {
         content_by_lua_block {
diff --git a/t/plugin/ai-proxy.t b/t/plugin/ai-proxy.t
index c5696a2fb..c99a6c11e 100644
--- a/t/plugin/ai-proxy.t
+++ b/t/plugin/ai-proxy.t
@@ -106,6 +106,66 @@ add_block_preprocessor(sub {
                 }
             }
 
+            location /v1/embeddings {
+                content_by_lua_block {
+                    if ngx.req.get_method() ~= "POST" then
+                        ngx.status = 400
+                        ngx.say("unsupported request method: ", 
ngx.req.get_method())
+                    end
+
+                    local header_auth = ngx.req.get_headers()["authorization"]
+                    if header_auth ~= "Bearer token" then
+                        ngx.status = 401
+                        ngx.say("unauthorized")
+                        return
+                    end
+
+                    ngx.req.read_body()
+                    local body, err = ngx.req.get_body_data()
+                    local json = require("cjson.safe")
+                    body, err = json.decode(body)
+                    if err then
+                        ngx.status = 400
+                        ngx.say("failed to get request body: ", err)
+                    end
+
+                    if body.model ~= "text-embedding-ada-002" then
+                        ngx.status = 400
+                        ngx.say("unsupported model: ", body.model)
+                        return
+                    end
+
+                    if body.encoding_format ~= "float" then
+                        ngx.status = 400
+                        ngx.say("unsupported encoding format: ", 
body.encoding_format)
+                        return
+                    end
+
+                    ngx.status = 200
+                    ngx.say([[
+                        {
+                          "object": "list",
+                          "data": [
+                            {
+                              "object": "embedding",
+                              "embedding": [
+                                0.0023064255,
+                                -0.009327292,
+                                -0.0028842222
+                              ],
+                              "index": 0
+                            }
+                          ],
+                          "model": "text-embedding-ada-002",
+                          "usage": {
+                            "prompt_tokens": 8,
+                            "total_tokens": 8
+                          }
+                        }
+                    ]])
+                }
+            }
+
             location /random {
                 content_by_lua_block {
                     ngx.say("path override works")
@@ -330,19 +390,7 @@ unsupported content-type: 
application/x-www-form-urlencoded, only application/js
 
 
 
-=== TEST 11: request schema validity check
---- request
-POST /anything
-{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
---- more_headers
-Authorization: Bearer token
---- error_code: 400
---- response_body chomp
-request format doesn't match schema: property "messages" is required
-
-
-
-=== TEST 12: model options being merged to request body
+=== TEST 11: model options being merged to request body
 --- config
     location /t {
         content_by_lua_block {
@@ -405,7 +453,7 @@ options_works
 
 
 
-=== TEST 13: override path
+=== TEST 12: override path
 --- config
     location /t {
         content_by_lua_block {
@@ -467,7 +515,7 @@ path override works
 
 
 
-=== TEST 14: set route with stream = true (SSE)
+=== TEST 13: set route with stream = true (SSE)
 --- config
     location /t {
         content_by_lua_block {
@@ -510,7 +558,7 @@ passed
 
 
 
-=== TEST 15: test is SSE works as expected
+=== TEST 14: test is SSE works as expected
 --- config
     location /t {
         content_by_lua_block {
@@ -568,3 +616,58 @@ passed
     }
 --- response_body_like eval
 qr/6data: \[DONE\]\n\n/
+
+
+
+=== TEST 15: proxy embedding endpoint
+--- config
+    location /t {
+        content_by_lua_block {
+            local t = require("lib.test_admin").test
+            local code, body = t('/apisix/admin/routes/1',
+                 ngx.HTTP_PUT,
+                 [[{
+                    "uri": "/embeddings",
+                    "plugins": {
+                        "ai-proxy": {
+                            "provider": "openai",
+                            "auth": {
+                                "header": {
+                                    "Authorization": "Bearer token"
+                                }
+                            },
+                            "options": {
+                                "model": "text-embedding-ada-002",
+                                "encoding_format": "float"
+                            },
+                            "override": {
+                                "endpoint": 
"http://localhost:6724/v1/embeddings"
+                            }
+                        }
+                    }
+                }]]
+            )
+
+            if code >= 300 then
+                ngx.status = code
+                ngx.say(body)
+                return
+            end
+
+            ngx.say("passed")
+        }
+    }
+--- response_body
+passed
+
+
+
+=== TEST 16: send request to embedding api
+--- request
+POST /embeddings
+{
+    "input": "The food was delicious and the waiter..."
+}
+--- error_code: 200
+--- response_body_like eval
+qr/.*text-embedding-ada-002*/

Reply via email to