Copilot commented on code in PR #12924:
URL: https://github.com/apache/apisix/pull/12924#discussion_r2706594441
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
Review Comment:
There is an extra space before the closing parenthesis in the require
statement. This should be 'require("resty.http")' without the extra space.
```suggestion
local http = require("resty.http")
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
Review Comment:
There is an extra space before the closing parenthesis. This is inconsistent
with the rest of the codebase.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
+
+ if not res then
+ core.log.error("failed to send request to LLM server: ", err)
+ return self:handle_error(err)
+ end
+
+ -- 4. 处理响应
+ local is_stream = request_table.stream
+ local content_type = res.headers["Content-Type"]
+
+ if is_stream and core.string.find(content_type, "text/event-stream") then
+ -- 流式响应处理
+ return self:handle_stream_response(ctx, res, conf)
+ else
+ -- 非流式响应处理
+ return self:handle_non_stream_response(ctx, res, conf)
+ end
Review Comment:
The request method doesn't handle retry scenarios for 429 (rate limit) or
5xx errors. The original openai-base.lua implementation returns the status code
for these errors to allow the ai-proxy base.lua to retry with another instance.
Without this, the retry mechanism in ai-proxy-multi won't work properly.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
+
+ if not res then
+ core.log.error("failed to send request to LLM server: ", err)
+ return self:handle_error(err)
+ end
+
+ -- 4. 处理响应
+ local is_stream = request_table.stream
+ local content_type = res.headers["Content-Type"]
+
+ if is_stream and core.string.find(content_type, "text/event-stream") then
+ -- 流式响应处理
+ return self:handle_stream_response(ctx, res, conf)
+ else
+ -- 非流式响应处理
+ return self:handle_non_stream_response(ctx, res, conf)
+ end
+end
+
+-- 处理非流式响应
+function _M.handle_non_stream_response(self, ctx, res, conf)
+ local raw_res_body = res:read_body()
+ if not raw_res_body then
+ core.log.warn("failed to read response body: ", res.err)
+ return self:handle_error(res.err)
+ end
+
+ -- 协议转换(如果驱动提供了 transform_response)
+ if self.transform_response then
+ raw_res_body = self.transform_response(raw_res_body)
+ end
+
+ -- 设置响应头和状态码
+ core.response.set_header(ctx, "Content-Type", "application/json")
+ core.response.set_status(ctx, res.status)
+ core.response.set_body(ctx, raw_res_body)
+ core.response.send_response(ctx)
+end
+
+-- 处理流式响应
+function _M.handle_stream_response(self, ctx, res, conf)
+ core.response.set_header(ctx, "Content-Type", "text/event-stream")
+ core.response.set_status(ctx, res.status)
+ core.response.send_http_header(ctx )
+
+ local body_reader = res.body_reader
+ local chunk
+ while true do
+ chunk, err = body_reader()
+ if not chunk then
+ break
+ end
+
+ -- 协议转换(如果驱动提供了 process_sse_chunk)
+ if self.process_sse_chunk then
+ chunk = self.process_sse_chunk(chunk)
+ end
+
+ core.response.write(ctx, chunk)
+ end
+
+ if err then
+ core.log.error("failed to read stream body: ", err)
+ end
+
+ core.response.close(ctx)
+end
Review Comment:
The ai-driver-base.lua doesn't set any of the context variables that are
used for metrics and logging, such as ctx.llm_request_start_time,
ctx.var.llm_time_to_first_token, ctx.var.apisix_upstream_response_time,
ctx.ai_token_usage, etc. These are essential for the logging and prometheus
metrics functionality. The original implementation properly tracks all these
metrics.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
Review Comment:
There is an extra space before the closing parenthesis. This is inconsistent
with the rest of the codebase.
```suggestion
local http = require("resty.http")
local url = require("socket.url")
-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
local sse = require("apisix.plugins.ai-drivers.sse")
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
Review Comment:
The validate_request function doesn't handle the case where the Content-Type
header is nil. The code will fail when trying to call has_prefix on a nil
value. The code should check whether ct is nil first, and either fall back to a
default or return an error.
```suggestion
if not ct or not core.string.has_prefix(ct, "application/json") then
return nil, "unsupported content-type: " .. tostring(ct) .. ", only
application/json is supported"
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
+
+ if not res then
+ core.log.error("failed to send request to LLM server: ", err)
+ return self:handle_error(err)
+ end
+
+ -- 4. 处理响应
+ local is_stream = request_table.stream
+ local content_type = res.headers["Content-Type"]
+
+ if is_stream and core.string.find(content_type, "text/event-stream") then
+ -- 流式响应处理
+ return self:handle_stream_response(ctx, res, conf)
+ else
+ -- 非流式响应处理
+ return self:handle_non_stream_response(ctx, res, conf)
+ end
+end
+
+-- 处理非流式响应
+function _M.handle_non_stream_response(self, ctx, res, conf)
+ local raw_res_body = res:read_body()
+ if not raw_res_body then
+ core.log.warn("failed to read response body: ", res.err)
+ return self:handle_error(res.err)
+ end
+
+ -- 协议转换(如果驱动提供了 transform_response)
+ if self.transform_response then
+ raw_res_body = self.transform_response(raw_res_body)
+ end
+
+ -- 设置响应头和状态码
+ core.response.set_header(ctx, "Content-Type", "application/json")
+ core.response.set_status(ctx, res.status)
+ core.response.set_body(ctx, raw_res_body)
+ core.response.send_response(ctx)
+end
+
+-- 处理流式响应
+function _M.handle_stream_response(self, ctx, res, conf)
+ core.response.set_header(ctx, "Content-Type", "text/event-stream")
+ core.response.set_status(ctx, res.status)
+ core.response.send_http_header(ctx )
+
+ local body_reader = res.body_reader
+ local chunk
+ while true do
+ chunk, err = body_reader()
+ if not chunk then
+ break
+ end
+
+ -- 协议转换(如果驱动提供了 process_sse_chunk)
+ if self.process_sse_chunk then
+ chunk = self.process_sse_chunk(chunk)
+ end
+
+ core.response.write(ctx, chunk)
+ end
+
+ if err then
+ core.log.error("failed to read stream body: ", err)
+ end
+
+ core.response.close(ctx)
+end
Review Comment:
The handle_stream_response function uses non-existent APISIX API methods.
Methods like core.response.send_http_header(), core.response.write(), and
core.response.close() don't exist in the APISIX core.response module. The
original implementation uses plugin.lua_response_filter and reads the
body_reader correctly. This needs to follow existing APISIX patterns.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
+
+ if not res then
+ core.log.error("failed to send request to LLM server: ", err)
+ return self:handle_error(err)
+ end
+
+ -- 4. 处理响应
+ local is_stream = request_table.stream
+ local content_type = res.headers["Content-Type"]
+
+ if is_stream and core.string.find(content_type, "text/event-stream") then
+ -- 流式响应处理
+ return self:handle_stream_response(ctx, res, conf)
+ else
+ -- 非流式响应处理
+ return self:handle_non_stream_response(ctx, res, conf)
+ end
Review Comment:
The ai-driver-base.lua doesn't implement connection keepalive, which is
important for performance. The original openai-base.lua implementation sets
keepalive based on conf.keepalive settings. Without this, each request will
create a new connection, significantly impacting performance.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
Review Comment:
The ai-driver-base.lua request method doesn't handle extra_opts.endpoint,
which allows overriding the API endpoint, and doesn't handle
extra_opts.query_params. The original openai-base.lua implementation supports
these via parsed_url and query parameter merging. This breaks functionality for
users who need custom endpoints or query parameters.
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
+ core.log.warn("--- 转换后的 Anthropic 请求体开始 ---")
+ core.log.warn(core.json.encode(anthropic_request))
+ core.log.warn("--- 转换后的 Anthropic 请求体结束 ---")
+
+ return anthropic_request
+end
+
+-- 处理流式响应的 SSE Chunk 转换
+function _M.process_sse_chunk(chunk)
+ local events = sse.decode(chunk)
+ local out = {}
+
+ for _, e in ipairs(events) do
+ if e.type == "message" then
+ local d = core.json.decode(e.data)
+ if d.type == "content_block_delta" then
+ -- 转换为 OpenAI 兼容的流式格式
+ table.insert(out, "data: " .. core.json.encode({
+ choices = {
+ {
+ delta = {
+ content = d.delta.text
+ }
+ }
+ }
+ }) .. "\n")
+ elseif d.type == "message_stop" then
+ table.insert(out, "data: [DONE]\n")
+ end
+ end
+ end
+
+ return table.concat(out)
+end
+
+-- 将 Anthropic 原生响应转换为 OpenAI 兼容响应
+function _M.transform_response(body)
+ local d = core.json.decode(body)
+ return core.json.encode({
+ choices = {
+ {
+ message = {
+ content = d.content[1].text
+ }
+ }
+ }
+ })
+end
+
+-- 导出驱动实例
+return driver_base.new({
+ host = "api.anthropic.com",
+ port = 443,
+ path = "/v1/messages",
+ transform_request = _M.transform_request,
+ transform_response = _M.transform_response,
+ process_sse_chunk = _M.process_sse_chunk
+})
Review Comment:
The Anthropic driver doesn't implement parse_token_usage function or handle
token usage from Anthropic responses. Anthropic's response includes usage data
in a different format than OpenAI (with input_tokens and output_tokens).
Without proper token usage parsing, metrics and logging will not include token
counts for Anthropic requests.
##########
apisix/plugins/ai-drivers/openai-base.lua:
##########
@@ -1,293 +1,74 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-local _M = {}
-
-local mt = {
- __index = _M
-}
-
-local CONTENT_TYPE_JSON = "application/json"
+-- apisix/plugins/ai-drivers/openai-base.lua (重构后)
local core = require("apisix.core")
-local plugin = require("apisix.plugin")
-local http = require("resty.http")
-local url = require("socket.url")
-local sse = require("apisix.plugins.ai-drivers.sse")
-local ngx = ngx
-local ngx_now = ngx.now
-
-local table = table
-local pairs = pairs
-local type = type
-local math = math
-local ipairs = ipairs
-local setmetatable = setmetatable
-
-local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
-local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+local _M = {}
+-- OpenAI 驱动的构造函数
function _M.new(opts)
-
- local self = {
+ -- 继承通用基类,并传入 OpenAI 的 API 信息和自定义处理函数
+ local self = driver_base.new({
host = opts.host,
port = opts.port,
path = opts.path,
- remove_model = opts.options and opts.options.remove_model
- }
- return setmetatable(self, mt)
+ scheme = opts.scheme or "https",
+ -- OpenAI 特有的处理函数
+ process_sse_chunk = _M.process_sse_chunk,
+ parse_token_usage = _M.parse_token_usage,
+ -- transform_request 和 transform_response 在 OpenAI 兼容层中通常不需要
+ } )
Review Comment:
There is an extra space before the closing parenthesis. This is inconsistent
with the rest of the codebase.
```suggestion
})
```
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
+ core.log.warn("--- 转换后的 Anthropic 请求体开始 ---")
+ core.log.warn(core.json.encode(anthropic_request))
+ core.log.warn("--- 转换后的 Anthropic 请求体结束 ---")
+
+ return anthropic_request
+end
+
+-- 处理流式响应的 SSE Chunk 转换
+function _M.process_sse_chunk(chunk)
+ local events = sse.decode(chunk)
+ local out = {}
+
+ for _, e in ipairs(events) do
+ if e.type == "message" then
+ local d = core.json.decode(e.data)
+ if d.type == "content_block_delta" then
+ -- 转换为 OpenAI 兼容的流式格式
+ table.insert(out, "data: " .. core.json.encode({
+ choices = {
+ {
+ delta = {
+ content = d.delta.text
+ }
+ }
+ }
+ }) .. "\n")
+ elseif d.type == "message_stop" then
+ table.insert(out, "data: [DONE]\n")
+ end
+ end
+ end
+
+ return table.concat(out)
+end
+
+-- 将 Anthropic 原生响应转换为 OpenAI 兼容响应
+function _M.transform_response(body)
+ local d = core.json.decode(body)
+ return core.json.encode({
+ choices = {
+ {
+ message = {
+ content = d.content[1].text
+ }
+ }
+ }
+ })
+end
Review Comment:
The transform_response function assumes the response body structure without
error handling. If d.content is nil or d.content[1] is nil or d.content[1].text
is nil, this will cause a runtime error. Proper validation and error handling
should be added to handle malformed responses.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
Review Comment:
There is an extra space before the closing parenthesis. This is inconsistent
with the rest of the codebase.
##########
apisix/plugins/ai-drivers/openai-base.lua:
##########
@@ -1,293 +1,74 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-local _M = {}
-
-local mt = {
- __index = _M
-}
-
-local CONTENT_TYPE_JSON = "application/json"
+-- apisix/plugins/ai-drivers/openai-base.lua (重构后)
local core = require("apisix.core")
-local plugin = require("apisix.plugin")
-local http = require("resty.http")
-local url = require("socket.url")
-local sse = require("apisix.plugins.ai-drivers.sse")
-local ngx = ngx
-local ngx_now = ngx.now
-
-local table = table
-local pairs = pairs
-local type = type
-local math = math
-local ipairs = ipairs
-local setmetatable = setmetatable
-
-local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
-local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+local _M = {}
+-- OpenAI 驱动的构造函数
function _M.new(opts)
-
- local self = {
+ -- 继承通用基类,并传入 OpenAI 的 API 信息和自定义处理函数
+ local self = driver_base.new({
host = opts.host,
port = opts.port,
path = opts.path,
- remove_model = opts.options and opts.options.remove_model
- }
- return setmetatable(self, mt)
+ scheme = opts.scheme or "https",
+ -- OpenAI 特有的处理函数
+ process_sse_chunk = _M.process_sse_chunk,
+ parse_token_usage = _M.parse_token_usage,
+ -- transform_request 和 transform_response 在 OpenAI 兼容层中通常不需要
+ } )
+
+ return self
end
-
-function _M.validate_request(ctx)
- local ct = core.request.header(ctx, "Content-Type") or
CONTENT_TYPE_JSON
- if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
- return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
- end
-
- local request_table, err = core.request.get_json_request_body_table()
- if not request_table then
- return nil, err
- end
-
- return request_table, nil
-end
-
-
-local function handle_error(err)
- if core.string.find(err, "timeout") then
- return HTTP_GATEWAY_TIMEOUT
- end
- return HTTP_INTERNAL_SERVER_ERROR
-end
-
-
-local function read_response(ctx, res)
- local body_reader = res.body_reader
- if not body_reader then
- core.log.warn("AI service sent no response body")
- return HTTP_INTERNAL_SERVER_ERROR
- end
-
- local content_type = res.headers["Content-Type"]
- core.response.set_header("Content-Type", content_type)
-
- if content_type and core.string.find(content_type, "text/event-stream")
then
- local contents = {}
- while true do
- local chunk, err = body_reader() -- will read chunk by chunk
- ctx.var.apisix_upstream_response_time = math.floor((ngx_now() -
- ctx.llm_request_start_time) *
1000)
- if err then
- core.log.warn("failed to read response chunk: ", err)
- return handle_error(err)
- end
- if not chunk then
- return
+-- 将 OpenAI 原生流式响应块转换为 APISIX 兼容格式(主要用于 token 统计和错误处理)
+function _M.process_sse_chunk(chunk)
+ local events = sse.decode(chunk)
+ local contents = {}
+
+ for _, event in ipairs(events) do
+ if event.type == "message" then
+ local data, err = core.json.decode(event.data)
+ if not data then
+ core.log.warn("failed to decode SSE data: ", err)
+ goto continue
end
- if ctx.var.llm_time_to_first_token == "0" then
- ctx.var.llm_time_to_first_token = math.floor(
- (ngx_now() -
ctx.llm_request_start_time) * 1000)
+ -- 提取 token usage (仅在非流式或流式结束时出现)
+ if data.usage and type(data.usage) == "table" then
+ -- 实际 APISIX 实现中,这部分逻辑可能在 ai-proxy 插件的 response 阶段
+ -- 这里仅作示意,实际应依赖 APISIX 内部机制
end
- local events = sse.decode(chunk)
- ctx.llm_response_contents_in_chunk = {}
- for _, event in ipairs(events) do
- if event.type == "message" then
- local data, err = core.json.decode(event.data)
- if not data then
- core.log.warn("failed to decode SSE data: ", err)
- goto CONTINUE
- end
-
- if data and type(data.choices) == "table" and
#data.choices > 0 then
- for _, choice in ipairs(data.choices) do
- if type(choice) == "table"
- and type(choice.delta) == "table"
- and type(choice.delta.content) == "string"
then
- core.table.insert(contents,
choice.delta.content)
-
core.table.insert(ctx.llm_response_contents_in_chunk,
- choice.delta.content)
- end
- end
- end
-
-
- -- usage field is null for non-last events, null is parsed
as userdata type
- if data and type(data.usage) == "table" then
- core.log.info("got token usage from ai service: ",
- core.json.delay_encode(data.usage))
- ctx.llm_raw_usage = data.usage
- ctx.ai_token_usage = {
- prompt_tokens = data.usage.prompt_tokens or 0,
- completion_tokens = data.usage.completion_tokens
or 0,
- total_tokens = data.usage.total_tokens or 0,
- }
- ctx.var.llm_prompt_tokens =
ctx.ai_token_usage.prompt_tokens
- ctx.var.llm_completion_tokens =
ctx.ai_token_usage.completion_tokens
- ctx.var.llm_response_text = table.concat(contents, "")
+ -- 提取内容
+ if data.choices and type(data.choices) == "table" and
#data.choices > 0 then
+ for _, choice in ipairs(data.choices) do
+ if type(choice) == "table" and type(choice.delta) ==
"table" and type(choice.delta.content) == "string" then
+ table.insert(contents, choice.delta.content)
end
- elseif event.type == "done" then
- ctx.var.llm_request_done = true
end
-
- ::CONTINUE::
end
-
- plugin.lua_response_filter(ctx, res.headers, chunk)
end
+ ::continue::
end
- local raw_res_body, err = res:read_body()
- if not raw_res_body then
- core.log.warn("failed to read response body: ", err)
- return handle_error(err)
- end
- ngx.status = res.status
- ctx.var.llm_time_to_first_token = math.floor((ngx_now() -
ctx.llm_request_start_time) * 1000)
- ctx.var.apisix_upstream_response_time = ctx.var.llm_time_to_first_token
- local res_body, err = core.json.decode(raw_res_body)
- if err then
- core.log.warn("invalid response body from ai service: ", raw_res_body,
" err: ", err,
- ", it will cause token usage not available")
- else
- core.log.info("got token usage from ai service: ",
core.json.delay_encode(res_body.usage))
- ctx.ai_token_usage = {}
- if type(res_body.usage) == "table" then
- ctx.llm_raw_usage = res_body.usage
- ctx.ai_token_usage.prompt_tokens = res_body.usage.prompt_tokens or 0
- ctx.ai_token_usage.completion_tokens =
res_body.usage.completion_tokens or 0
- ctx.ai_token_usage.total_tokens = res_body.usage.total_tokens or 0
- end
- ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens or 0
- ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens
or 0
- if type(res_body.choices) == "table" and #res_body.choices > 0 then
- local contents = {}
- for _, choice in ipairs(res_body.choices) do
- if type(choice) == "table"
- and type(choice.message) == "table"
- and type(choice.message.content) == "string" then
- core.table.insert(contents, choice.message.content)
- end
- end
- local content_to_check = table.concat(contents, " ")
- ctx.var.llm_response_text = content_to_check
- end
- end
- plugin.lua_response_filter(ctx, res.headers, raw_res_body)
+ -- 返回原始 chunk,因为 OpenAI 兼容层不需要对 chunk 本身进行格式转换
+ return chunk
end
-
-function _M.request(self, ctx, conf, request_table, extra_opts)
- local httpc, err = http.new()
- if not httpc then
- core.log.error("failed to create http client to send request to LLM
server: ", err)
- return HTTP_INTERNAL_SERVER_ERROR
+-- 解析 OpenAI 的 token usage
+function _M.parse_token_usage(openai_usage)
+ if not openai_usage then
+ return nil
end
- httpc:set_timeout(conf.timeout)
-
- local endpoint = extra_opts and extra_opts.endpoint
- local parsed_url
- if endpoint then
- parsed_url = url.parse(endpoint)
- end
-
- local scheme = parsed_url and parsed_url.scheme or "https"
- local host = parsed_url and parsed_url.host or self.host
- local port = parsed_url and parsed_url.port
- if not port then
- if scheme == "https" then
- port = 443
- else
- port = 80
- end
- end
- local ok, err = httpc:connect({
- scheme = scheme,
- host = host,
- port = port,
- ssl_verify = conf.ssl_verify,
- ssl_server_name = parsed_url and parsed_url.host or self.host,
- })
-
- if not ok then
- core.log.warn("failed to connect to LLM server: ", err)
- return handle_error(err)
- end
-
- local query_params = extra_opts.query_params
-
- if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query
> 0 then
- local args_tab = core.string.decode_args(parsed_url.query)
- if type(args_tab) == "table" then
- core.table.merge(query_params, args_tab)
- end
- end
-
- local path = (parsed_url and parsed_url.path or self.path)
- local headers = extra_opts.headers
- headers["Content-Type"] = "application/json"
- local params = {
- method = "POST",
- headers = headers,
- ssl_verify = conf.ssl_verify,
- path = path,
- query = query_params
+ return {
+ prompt_tokens = openai_usage.prompt_tokens or 0,
+ completion_tokens = openai_usage.completion_tokens or 0,
+ total_tokens = openai_usage.total_tokens or 0
}
-
- if extra_opts.model_options then
- for opt, val in pairs(extra_opts.model_options) do
- request_table[opt] = val
- end
- end
- if self.remove_model then
- request_table.model = nil
- end
- local req_json, err = core.json.encode(request_table)
- if not req_json then
- return nil, err
- end
-
- params.body = req_json
-
- local res, err = httpc:request(params)
- if not res then
- core.log.warn("failed to send request to LLM server: ", err)
- return handle_error(err)
- end
-
- -- handling this error separately is needed for retries
- if res.status == 429 or (res.status >= 500 and res.status < 600 )then
- return res.status
- end
-
- local code, body = read_response(ctx, res)
-
- if conf.keepalive then
- local ok, err = httpc:set_keepalive(conf.keepalive_timeout,
conf.keepalive_pool)
- if not ok then
- core.log.warn("failed to keepalive connection: ", err)
- end
- end
-
- return code, body
end
-
-return _M
+return _M.new({})
Review Comment:
The refactored openai-base.lua returns _M.new({}) at the end, which means
every OpenAI-compatible driver that requires this module will get the same
singleton instance with empty opts. This breaks the pattern where each driver
(openai.lua, deepseek.lua, etc.) calls .new() with their specific
host/port/path configuration. This line should return just _M, not _M.new({}).
```suggestion
return _M
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
Review Comment:
  The handle_error function calls core.response.exit(), which terminates request
  processing immediately (via ngx.exit) rather than returning control to the
  caller — so callers' `return self:handle_error(err)` never actually returns a
  value, and any subsequent code is dead. The function should return a plain
  status code instead of calling core.response.exit(), to maintain consistency
  with the original implementation and keep retry logic workable.
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
Review Comment:
The variable 'core' is already required at line 1, but is required again at
line 39. This redundant require statement should be removed.
```suggestion
```
##########
apisix/plugins/ai-drivers/openai-base.lua:
##########
@@ -1,293 +1,74 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-local _M = {}
-
-local mt = {
- __index = _M
-}
-
-local CONTENT_TYPE_JSON = "application/json"
+-- apisix/plugins/ai-drivers/openai-base.lua (重构后)
local core = require("apisix.core")
-local plugin = require("apisix.plugin")
-local http = require("resty.http")
-local url = require("socket.url")
-local sse = require("apisix.plugins.ai-drivers.sse")
-local ngx = ngx
-local ngx_now = ngx.now
-
-local table = table
-local pairs = pairs
-local type = type
-local math = math
-local ipairs = ipairs
-local setmetatable = setmetatable
-
-local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
-local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+local _M = {}
+-- OpenAI 驱动的构造函数
function _M.new(opts)
-
- local self = {
+ -- 继承通用基类,并传入 OpenAI 的 API 信息和自定义处理函数
+ local self = driver_base.new({
host = opts.host,
port = opts.port,
path = opts.path,
- remove_model = opts.options and opts.options.remove_model
- }
- return setmetatable(self, mt)
+ scheme = opts.scheme or "https",
+ -- OpenAI 特有的处理函数
+ process_sse_chunk = _M.process_sse_chunk,
+ parse_token_usage = _M.parse_token_usage,
+ -- transform_request 和 transform_response 在 OpenAI 兼容层中通常不需要
+ } )
+
+ return self
end
Review Comment:
The refactored openai-base.lua is missing the validate_request method which
is called by ai-proxy/base.lua at line 56. This will cause a runtime error when
any OpenAI-compatible driver is used. The validate_request method needs to be
added back to this module.
##########
apisix/plugins/ai-drivers/openai-base.lua:
##########
@@ -1,293 +1,74 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-local _M = {}
-
-local mt = {
- __index = _M
-}
-
-local CONTENT_TYPE_JSON = "application/json"
+-- apisix/plugins/ai-drivers/openai-base.lua (重构后)
local core = require("apisix.core")
-local plugin = require("apisix.plugin")
-local http = require("resty.http")
-local url = require("socket.url")
-local sse = require("apisix.plugins.ai-drivers.sse")
-local ngx = ngx
-local ngx_now = ngx.now
-
-local table = table
-local pairs = pairs
-local type = type
-local math = math
-local ipairs = ipairs
-local setmetatable = setmetatable
-
-local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
-local HTTP_GATEWAY_TIMEOUT = ngx.HTTP_GATEWAY_TIMEOUT
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+local _M = {}
+-- OpenAI 驱动的构造函数
function _M.new(opts)
-
- local self = {
+ -- 继承通用基类,并传入 OpenAI 的 API 信息和自定义处理函数
+ local self = driver_base.new({
host = opts.host,
port = opts.port,
path = opts.path,
- remove_model = opts.options and opts.options.remove_model
- }
- return setmetatable(self, mt)
+ scheme = opts.scheme or "https",
+ -- OpenAI 特有的处理函数
+ process_sse_chunk = _M.process_sse_chunk,
+ parse_token_usage = _M.parse_token_usage,
+ -- transform_request 和 transform_response 在 OpenAI 兼容层中通常不需要
+ } )
+
+ return self
end
Review Comment:
The refactored openai-base.lua is missing the request method which is called
by ai-proxy/base.lua at line 87. This will cause a runtime error when any
OpenAI-compatible driver is used. The request method needs to be added back to
this module or properly inherited from the base driver.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
--- Create a new driver instance carrying the shared upstream configuration.
-- Besides host/port/path/scheme, opts may supply driver-specific hook
-- functions that specialise the base behaviour.
function _M.new(opts)
    local instance = {
        host = opts.host,
        port = opts.port,
        path = opts.path,
        scheme = opts.scheme or "https",
    }

    -- abstract hook placeholders: nil unless the concrete driver provides them
    instance.transform_request = opts.transform_request
    instance.transform_response = opts.transform_response
    instance.process_sse_chunk = opts.process_sse_chunk
    instance.parse_token_usage = opts.parse_token_usage

    return setmetatable(instance, mt)
end
+
-- Common request validation: check the Content-Type and parse the JSON body.
-- Returns the decoded request table, or nil plus an error message.
function _M.validate_request(self, ctx)
    -- default to application/json when the header is absent, matching the
    -- original openai-base behaviour and avoiding a nil-concatenation error
    local ct = core.request.header(ctx, "Content-Type") or "application/json"
    if not core.string.has_prefix(ct, "application/json") then
        return nil, "unsupported content-type: " .. ct
                    .. ", only application/json is supported"
    end

    local request_table, err = core.request.get_json_request_body_table()
    if not request_table then
        return nil, err
    end

    return request_table, nil
end
+
-- Map a transport-level error to an HTTP status code.
-- Returns the status code instead of calling core.response.exit(), so the
-- caller (and any retry logic above it) stays in control of the response;
-- core.response.exit() would terminate the request immediately.
function _M.handle_error(self, err)
    -- err may be nil when the underlying library reports failure without a
    -- message; core.string.find on nil would raise
    if err and core.string.find(err, "timeout") then
        return 504 -- HTTP_GATEWAY_TIMEOUT
    end
    return 500 -- HTTP_INTERNAL_SERVER_ERROR
end
+
-- Core request method: optionally transform the request, send it to the
-- upstream LLM service and dispatch the response to the streaming or
-- non-streaming handler.
-- Returns an HTTP status code on failure, otherwise whatever the response
-- handlers return.
function _M.request(self, ctx, conf, request_table, extra_opts)
    -- 1. protocol conversion, when the concrete driver provides one
    if self.transform_request then
        request_table = self.transform_request(request_table)
    end

    -- 2. build the upstream request headers: merge caller-supplied headers
    -- (authentication etc. arrive via extra_opts.headers) instead of
    -- hard-coding them, otherwise upstream authentication would always fail
    local headers = {}
    if extra_opts and type(extra_opts.headers) == "table" then
        for k, v in pairs(extra_opts.headers) do
            headers[k] = v
        end
    end
    headers["Host"] = headers["Host"] or self.host
    headers["Content-Type"] = headers["Content-Type"] or "application/json"

    local body, err = core.json.encode(request_table)
    if not body then
        core.log.error("failed to encode request body: ", err)
        return 500
    end

    -- 3. send the request
    local httpc, err2 = http.new()
    if not httpc then
        core.log.error("failed to create http client: ", err2)
        return 500
    end
    -- timeouts must be configured on the client before connecting;
    -- lua-resty-http ignores a `timeout` field in the request params
    httpc:set_timeout(conf.timeout or 60000)

    -- lua-resty-http's request() requires an established connection;
    -- honour conf.ssl_verify instead of hard-coding ssl_verify = false
    local ok, conn_err = httpc:connect({
        scheme = self.scheme,
        host = self.host,
        port = self.port,
        ssl_verify = conf.ssl_verify,
        ssl_server_name = self.host,
    })
    if not ok then
        core.log.warn("failed to connect to LLM server: ", conn_err)
        return self:handle_error(conn_err)
    end

    local res, req_err = httpc:request({
        method = "POST",
        path = self.path,
        headers = headers,
        body = body,
    })
    if not res then
        core.log.warn("failed to send request to LLM server: ", req_err)
        return self:handle_error(req_err)
    end

    -- 4. dispatch on the response type; content_type may be absent, so
    -- guard before searching it
    local is_stream = request_table.stream
    local content_type = res.headers["Content-Type"]

    if is_stream and content_type
            and core.string.find(content_type, "text/event-stream") then
        return self:handle_stream_response(ctx, res, conf)
    end
    return self:handle_non_stream_response(ctx, res, conf)
end
+
-- Handle a non-streaming upstream response: read the whole body, apply the
-- driver's response transformation (if any) and hand the result back to the
-- caller as (status, body).
function _M.handle_non_stream_response(self, ctx, res, conf)
    -- res:read_body() returns (body, err); the response object has no `err`
    -- field, so capture the error from the call itself
    local raw_res_body, err = res:read_body()
    if not raw_res_body then
        core.log.warn("failed to read response body: ", err)
        return self:handle_error(err)
    end

    -- protocol conversion, when the concrete driver provides one
    if self.transform_response then
        local transformed, terr = self.transform_response(raw_res_body)
        if not transformed then
            core.log.warn("failed to transform response: ", terr)
            return 500
        end
        raw_res_body = transformed
    end

    -- core.response has no set_status/set_body/send_response helpers;
    -- set the header, propagate the upstream status via ngx.status and
    -- return the body to the caller, matching the existing APISIX pattern
    core.response.set_header("Content-Type", "application/json")
    ngx.status = res.status
    return res.status, raw_res_body
end
Review Comment:
The handle_non_stream_response function uses non-existent APISIX API
methods. Methods like core.response.set_status(), core.response.set_body(), and
core.response.send_response() don't exist in the APISIX core.response module.
The original implementation directly sets ngx.status and uses
plugin.lua_response_filter. This needs to follow the existing APISIX patterns.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
Review Comment:
  The timeout is only passed as a field in the request() params table, but
  lua-resty-http's request() does not accept a `timeout` parameter — the field
  is silently ignored. Timeouts must be configured on the client itself via
  httpc:set_timeout() (or set_timeouts()) before the request is made, which is
  what the original implementation did with httpc:set_timeout(conf.timeout).
```suggestion
httpc:set_timeout(conf.timeout or 60000)
local res, err = httpc:request({
method = "POST",
url = upstream_url,
headers = headers,
body = core.json.encode(request_table ),
ssl_verify = false, -- 生产环境应为 true
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only
application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port
.. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
Review Comment:
The ai-driver-base.lua request method doesn't include authentication headers
that are passed via extra_opts.headers from the ai-proxy base. The headers at
line 64-68 are hardcoded and don't merge in the authentication headers from
conf or extra_opts, which means API calls will fail with authentication errors.
The implementation needs to properly merge headers from extra_opts similar to
the original openai-base.lua implementation.
```suggestion
-- 从 conf.headers 和 extra_opts.headers 合并请求头,以保留上游传入的认证信息
local headers = {}
if conf and conf.headers then
for k, v in pairs(conf.headers) do
headers[k] = v
end
end
if extra_opts and extra_opts.headers then
for k, v in pairs(extra_opts.headers) do
headers[k] = v
end
end
-- 默认 Host 与 Content-Type,如未显式提供则补充
if headers["Host"] == nil then
headers["Host"] = self.host
end
if headers["Content-Type"] == nil then
headers["Content-Type"] = "application/json"
end
-- 认证头由具体驱动在 transform_request 中添加或在 conf/extra_opts 中获取
```
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port .. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
+
+ if not res then
+ core.log.error("failed to send request to LLM server: ", err)
+ return self:handle_error(err)
+ end
+
+ -- 4. 处理响应
+ local is_stream = request_table.stream
+ local content_type = res.headers["Content-Type"]
+
+ if is_stream and core.string.find(content_type, "text/event-stream") then
+ -- 流式响应处理
+ return self:handle_stream_response(ctx, res, conf)
+ else
+ -- 非流式响应处理
+ return self:handle_non_stream_response(ctx, res, conf)
+ end
+end
+
+-- 处理非流式响应
+function _M.handle_non_stream_response(self, ctx, res, conf)
+ local raw_res_body = res:read_body()
+ if not raw_res_body then
+ core.log.warn("failed to read response body: ", res.err)
+ return self:handle_error(res.err)
+ end
+
+ -- 协议转换(如果驱动提供了 transform_response)
+ if self.transform_response then
+ raw_res_body = self.transform_response(raw_res_body)
+ end
+
+ -- 设置响应头和状态码
+ core.response.set_header(ctx, "Content-Type", "application/json")
+ core.response.set_status(ctx, res.status)
+ core.response.set_body(ctx, raw_res_body)
+ core.response.send_response(ctx)
+end
+
+-- 处理流式响应
+function _M.handle_stream_response(self, ctx, res, conf)
+ core.response.set_header(ctx, "Content-Type", "text/event-stream")
+ core.response.set_status(ctx, res.status)
+ core.response.send_http_header(ctx )
Review Comment:
There is an extra space before the closing parenthesis. This is inconsistent
with the rest of the codebase.
```suggestion
core.response.send_http_header(ctx)
```
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
+ core.log.warn("--- 转换后的 Anthropic 请求体开始 ---")
+ core.log.warn(core.json.encode(anthropic_request))
+ core.log.warn("--- 转换后的 Anthropic 请求体结束 ---")
+
+ return anthropic_request
+end
Review Comment:
Anthropic API requires specific headers including 'x-api-key' for
authentication and 'anthropic-version' header. The transform_request function
doesn't handle these required headers. While authentication might be added via
extra_opts.headers, the anthropic-version header (required by Anthropic API) is
not being set. This will cause API calls to fail.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port .. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
Review Comment:
The request method is using the wrong HTTP client API. The resty.http
library's request() method expects separate parameters for scheme, host, port,
etc., not a single URL parameter. This code will fail at runtime. The
implementation should follow the pattern used in the original openai-base.lua
with connect() and request() methods.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port .. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
+ timeout = conf.timeout or 60000,
+ })
Review Comment:
The ai-driver-base.lua request method doesn't handle
extra_opts.model_options which allows setting provider-specific options on the
request. The original openai-base.lua merges model_options into the
request_table. Without this, configuration options like temperature, top_p,
etc. cannot be set via the plugin configuration.
##########
apisix/plugins/ai-drivers/ai-driver-base.lua:
##########
@@ -0,0 +1,148 @@
+-- apisix/plugins/ai-drivers/ai-driver-base.lua
+
+local core = require("apisix.core")
+local plugin = require("apisix.plugin")
+local http = require("resty.http" )
+local url = require("socket.url")
+-- 假设 sse 模块存在于 apisix.plugins.ai-drivers.sse
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+local mt = {
+ __index = _M
+}
+
+-- 构造函数,用于初始化驱动的通用配置
+function _M.new(opts)
+ local self = {
+ host = opts.host,
+ port = opts.port,
+ path = opts.path,
+ scheme = opts.scheme or "https",
+ -- 抽象方法占位符 ,由具体驱动实现
+ transform_request = opts.transform_request,
+ transform_response = opts.transform_response,
+ process_sse_chunk = opts.process_sse_chunk,
+ parse_token_usage = opts.parse_token_usage,
+ }
+
+ return setmetatable(self, mt)
+end
+
+-- 通用请求验证:检查 Content-Type 并解析 JSON
+function _M.validate_request(self, ctx)
+ local ct = core.request.header(ctx, "Content-Type")
+ if not core.string.has_prefix(ct, "application/json") then
+ return nil, "unsupported content-type: " .. ct .. ", only application/json is supported"
+ end
+
+ local request_table, err = core.request.get_json_request_body_table()
+ if not request_table then
+ return nil, err
+ end
+
+ return request_table, nil
+end
+
+-- 通用错误处理
+function _M.handle_error(self, err)
+ if core.string.find(err, "timeout") then
+ return core.response.exit(504) -- HTTP_GATEWAY_TIMEOUT
+ end
+ return core.response.exit(500) -- HTTP_INTERNAL_SERVER_ERROR
+end
+
+-- 核心请求方法
+function _M.request(self, ctx, conf, request_table, extra_opts)
+ -- 1. 协议转换(如果驱动提供了 transform_request)
+ if self.transform_request then
+ request_table = self.transform_request(request_table)
+ end
+
+ -- 2. 构造上游请求
+ local upstream_url = self.scheme .. "://" .. self.host .. ":" .. self.port .. self.path
+ local headers = {
+ ["Host"] = self.host,
+ ["Content-Type"] = "application/json",
+ -- 认证头由具体驱动在 transform_request 中添加或在 conf 中获取
+ }
+
+ -- 3. 发送请求
+ local httpc = http.new( )
+ local res, err = httpc:request({
+ method = "POST",
+ url = upstream_url,
+ headers = headers,
+ body = core.json.encode(request_table ),
+ ssl_verify = false, -- 生产环境应为 true
Review Comment:
The ssl_verify is hardcoded to false in the request method. This is a
security risk as it disables SSL certificate verification. The original
implementation uses conf.ssl_verify to allow proper SSL verification. This
should respect the ssl_verify setting from conf.
```suggestion
        ssl_verify = conf.ssl_verify == nil or conf.ssl_verify, -- default true; note `x ~= nil and x or true` would wrongly turn an explicit false into true
```
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
+ core.log.warn("--- 转换后的 Anthropic 请求体开始 ---")
+ core.log.warn(core.json.encode(anthropic_request))
+ core.log.warn("--- 转换后的 Anthropic 请求体结束 ---")
+
+ return anthropic_request
+end
+
+-- 处理流式响应的 SSE Chunk 转换
+function _M.process_sse_chunk(chunk)
+ local events = sse.decode(chunk)
+ local out = {}
+
+ for _, e in ipairs(events) do
+ if e.type == "message" then
+ local d = core.json.decode(e.data)
+ if d.type == "content_block_delta" then
+ -- 转换为 OpenAI 兼容的流式格式
+ table.insert(out, "data: " .. core.json.encode({
+ choices = {
+ {
+ delta = {
+ content = d.delta.text
+ }
+ }
+ }
+ }) .. "\n")
+ elseif d.type == "message_stop" then
+ table.insert(out, "data: [DONE]\n")
+ end
+ end
+ end
+
+ return table.concat(out)
+end
+
+-- 将 Anthropic 原生响应转换为 OpenAI 兼容响应
+function _M.transform_response(body)
+ local d = core.json.decode(body)
+ return core.json.encode({
+ choices = {
+ {
+ message = {
+ content = d.content[1].text
+ }
+ }
+ }
+ })
+end
+
+-- 导出驱动实例
+return driver_base.new({
+ host = "api.anthropic.com",
+ port = 443,
+ path = "/v1/messages",
+ transform_request = _M.transform_request,
+ transform_response = _M.transform_response,
+ process_sse_chunk = _M.process_sse_chunk
+})
Review Comment:
The Anthropic driver is missing the validate_request method which is called
by ai-proxy/base.lua. This will cause a runtime error when the Anthropic
provider is used. The validate_request method needs to be implemented or the
driver needs to properly inherit it from the base driver.
##########
apisix/plugins/ai-drivers/anthropic.lua:
##########
@@ -1,24 +1,97 @@
---
--- Licensed to the Apache Software Foundation (ASF) under one or more
--- contributor license agreements. See the NOTICE file distributed with
--- this work for additional information regarding copyright ownership.
--- The ASF licenses this file to You under the Apache License, Version 2.0
--- (the "License"); you may not use this file except in compliance with
--- the License. You may obtain a copy of the License at
---
--- http://www.apache.org/licenses/LICENSE-2.0
---
--- Unless required by applicable law or agreed to in writing, software
--- distributed under the License is distributed on an "AS IS" BASIS,
--- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
--- See the License for the specific language governing permissions and
--- limitations under the License.
---
-
-return require("apisix.plugins.ai-drivers.openai-base").new(
- {
- host = "api.anthropic.com",
- path = "/v1/chat/completions",
- port = 443
+local core = require("apisix.core")
+local driver_base = require("apisix.plugins.ai-drivers.ai-driver-base")
+local sse = require("apisix.plugins.ai-drivers.sse")
+
+local _M = {}
+
+-- 将 OpenAI 兼容请求转换为 Anthropic 原生请求
+function _M.transform_request(request_table)
+ local anthropic_request = {
+ model = request_table.model,
+ max_tokens = request_table.max_tokens or 1024,
+ stream = request_table.stream,
}
-)
+
+ local messages = request_table.messages
+ local system_prompt = nil
+ local new_messages = {}
+
+ for _, msg in ipairs(messages) do
+ if msg.role == "system" then
+ -- Anthropic Messages API 支持 system 字段
+ system_prompt = msg.content
+ elseif msg.role == "user" or msg.role == "assistant" then
+ -- 角色映射:OpenAI 的 user/assistant 对应 Anthropic 的 user/assistant
+ table.insert(new_messages, {
+ role = msg.role,
+ content = msg.content
+ })
+ end
+ end
+
+ if system_prompt then
+ anthropic_request.system = system_prompt
+ end
+
+ anthropic_request.messages = new_messages
+
+ -- 【添加日志打印】在返回前打印转换后的请求体,方便我们验证逻辑
+ local core = require("apisix.core")
+ core.log.warn("--- 转换后的 Anthropic 请求体开始 ---")
+ core.log.warn(core.json.encode(anthropic_request))
+ core.log.warn("--- 转换后的 Anthropic 请求体结束 ---")
+
+ return anthropic_request
+end
+
+-- 处理流式响应的 SSE Chunk 转换
+function _M.process_sse_chunk(chunk)
+ local events = sse.decode(chunk)
+ local out = {}
+
+ for _, e in ipairs(events) do
+ if e.type == "message" then
+ local d = core.json.decode(e.data)
+ if d.type == "content_block_delta" then
+ -- 转换为 OpenAI 兼容的流式格式
+ table.insert(out, "data: " .. core.json.encode({
+ choices = {
+ {
+ delta = {
+ content = d.delta.text
+ }
+ }
+ }
+ }) .. "\n")
+ elseif d.type == "message_stop" then
+ table.insert(out, "data: [DONE]\n")
+ end
+ end
Review Comment:
The process_sse_chunk function doesn't handle JSON decode errors properly.
If core.json.decode(e.data) fails at line 54, the variable 'd' will be nil but
the code continues to access d.type without checking, which will cause a
runtime error.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]