This is an automated email from the ASF dual-hosted git repository.

ashishtiwari pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push:
     new 3260931f3  feat: add ai-aliyun-content-moderation plugin (#12530)
3260931f3 is described below

commit 3260931f301f741d5551af895cbfd6b12d1aff31
Author: Ashish Tiwari <ashishjaitiwari15112...@gmail.com>
AuthorDate: Tue Aug 26 11:14:48 2025 +0530

    feat: add ai-aliyun-content-moderation plugin (#12530)
---
 apisix-master-0.rockspec                           |    1 +
 apisix/cli/config.lua                              |    1 +
 apisix/cli/ngx_tpl.lua                             |    1 +
 apisix/core/ctx.lua                                |    2 +
 apisix/plugin.lua                                  |   37 +
 apisix/plugins/ai-aliyun-content-moderation.lua    |  459 +++++++++
 apisix/plugins/ai-drivers/openai-base.lua          |  105 +-
 apisix/plugins/ai-drivers/schema.lua               |   59 +-
 apisix/plugins/ai-drivers/sse.lua                  |  117 +++
 apisix/plugins/ai-request-rewrite.lua              |   37 +-
 docs/en/latest/config.json                         |    1 +
 .../latest/plugins/ai-aliyun-content-moderation.md |  129 +++
 t/APISIX.pm                                        |    1 +
 t/admin/plugins.t                                  |    1 +
 t/plugin/ai-aliyun-content-moderation.t            | 1031 ++++++++++++++++++++
 t/plugin/ai-proxy-multi3.t                         |   48 +-
 t/sse_server_example/main.go                       |  134 ++-
 17 files changed, 2020 insertions(+), 144 deletions(-)

diff --git a/apisix-master-0.rockspec b/apisix-master-0.rockspec
index 82ca9d8bb..abc5f3a76 100644
--- a/apisix-master-0.rockspec
+++ b/apisix-master-0.rockspec
@@ -84,6 +84,7 @@ dependencies = {
     "jsonpath = 1.0-1",
     "api7-lua-resty-aws == 2.0.2-1",
     "multipart = 0.5.9-1",
+    "luautf8 = 0.1.6-1",
 }
 
 build = {
diff --git a/apisix/cli/config.lua b/apisix/cli/config.lua
index 191a40a8d..2c57541ba 100644
--- a/apisix/cli/config.lua
+++ b/apisix/cli/config.lua
@@ -231,6 +231,7 @@ local _M = {
             "ai-proxy-multi",
             "ai-proxy",
             "ai-aws-content-moderation",
+            "ai-aliyun-content-moderation",
             "proxy-mirror",
             "proxy-rewrite",
             "workflow",
diff --git a/apisix/cli/ngx_tpl.lua b/apisix/cli/ngx_tpl.lua
index cdedd73ce..296a651a5 100644
--- a/apisix/cli/ngx_tpl.lua
+++ b/apisix/cli/ngx_tpl.lua
@@ -806,6 +806,7 @@ http {
         set $dubbo_method '';
         {% end %}
 
+        set $llm_content_risk_level '';
         set $request_type 'traditional_http';
         set $llm_time_to_first_token '';
diff --git a/apisix/core/ctx.lua b/apisix/core/ctx.lua
index 7f3e47ff8..50e08cf93 100644
--- a/apisix/core/ctx.lua
+++ b/apisix/core/ctx.lua
@@ -233,6 +233,8 @@ do
             upstream_upgrade = true,
             upstream_connection = true,
             upstream_uri = true,
+            llm_content_risk_level = true,
+
             request_type = true,
             llm_time_to_first_token = true,
             llm_model = true,
diff --git a/apisix/plugin.lua b/apisix/plugin.lua
index 5eed30001..2e10c0055 100644
--- a/apisix/plugin.lua
+++ b/apisix/plugin.lua
@@ -23,6 +23,8 @@ local expr = require("resty.expr.v1")
 local apisix_ssl = require("apisix.ssl")
 local re_split = require("ngx.re").split
 local ngx = ngx
+local ngx_ok = ngx.OK
+local ngx_print = ngx.print
 local crc32 = ngx.crc32_short
 local ngx_exit = ngx.exit
 local pkg_loaded = package.loaded
@@ -1296,5 +1298,40 @@ function _M.run_global_rules(api_ctx, global_rules, phase_name)
     end
 end
 
+function _M.lua_response_filter(api_ctx, headers, body)
+    local plugins = api_ctx.plugins
+    if not plugins or #plugins == 0 then
+        -- if there is no plugin, just print the original body to downstream
+        ngx_print(body)
+        return
+    end
+    for i = 1, #plugins, 2 do
+        local phase_func = plugins[i]["lua_body_filter"]
+        if phase_func then
+            local conf = plugins[i + 1]
+            if not meta_filter(api_ctx, plugins[i]["name"], conf) then
+                goto CONTINUE
+            end
+
+            run_meta_pre_function(conf, api_ctx, plugins[i]["name"])
+            local code, new_body = phase_func(conf, api_ctx, headers, body)
+            if code then
+                if code ~= ngx_ok then
+                    ngx.status = code
+                end
+
+ ngx_print(new_body) + ngx_exit(ngx_ok) + end + if new_body then + body = new_body + end + end + + ::CONTINUE:: + end + ngx_print(body) +end + return _M diff --git a/apisix/plugins/ai-aliyun-content-moderation.lua b/apisix/plugins/ai-aliyun-content-moderation.lua new file mode 100644 index 000000000..1dde2b792 --- /dev/null +++ b/apisix/plugins/ai-aliyun-content-moderation.lua @@ -0,0 +1,459 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local ngx = ngx +local ngx_ok = ngx.OK +local os = os +local pairs = pairs +local ipairs = ipairs +local table = table +local string = string +local url = require("socket.url") +local utf8 = require("lua-utf8") +local core = require("apisix.core") +local http = require("resty.http") +local uuid = require("resty.jit-uuid") +local ai_schema = require("apisix.plugins.ai-drivers.schema") + +local sse = require("apisix.plugins.ai-drivers.sse") + +local schema = { + type = "object", + properties = { + stream_check_mode = { + type = "string", + enum = {"realtime", "final_packet"}, + default = "final_packet", + description = [[ + realtime: batched checks during streaming | final_packet: append risk_level at end + ]] + }, + stream_check_cache_size = { + type = "integer", + minimum = 1, + default = 128, + description = "max characters per moderation batch in realtime mode" + }, + stream_check_interval = { + type = "number", + minimum = 0.1, + default = 3, + description = "seconds between batch checks in realtime mode" + }, + endpoint = {type = "string", minLength = 1}, + region_id = {type ="string", minLength = 1}, + access_key_id = {type = "string", minLength = 1}, + access_key_secret = {type ="string", minLength = 1}, + check_request = {type = "boolean", default = true}, + check_response = {type = "boolean", default = false}, + request_check_service = {type = "string", minLength = 1, default = "llm_query_moderation"}, + request_check_length_limit = {type = "number", default = 2000}, + response_check_service = {type = "string", minLength = 1, + default = "llm_response_moderation"}, + response_check_length_limit = {type = "number", default = 5000}, + risk_level_bar = {type = "string", + enum = {"none", "low", "medium", "high", "max"}, + default = "high"}, + deny_code = {type = "number", default = 200}, + deny_message = {type = "string"}, + timeout = { + type = "integer", + minimum = 1, + default = 10000, + description = "timeout in milliseconds", + }, + keepalive_pool = {type = "integer", minimum = 1, default = 30}, + keepalive = {type = "boolean", default = true}, + keepalive_timeout = {type = "integer", minimum = 1000, default = 60000}, + ssl_verify = {type = "boolean", default = true }, + }, + encrypt_fields = {"access_key_secret"}, + required = { "endpoint", "region_id", "access_key_id", "access_key_secret" 
}, +} + + +local _M = { + version = 0.1, + priority = 1029, + name = "ai-aliyun-content-moderation", + schema = schema, +} + + +function _M.check_schema(conf) + return core.schema.check(schema, conf) +end + + +local function risk_level_to_int(risk_level) + local risk_levels = { + ["max"] = 4, + ["high"] = 3, + ["medium"] = 2, + ["low"] = 1, + ["none"] = 0 + } + return risk_levels[risk_level] or -1 +end + + +-- openresty ngx.escape_uri don't escape some sub-delimis in rfc 3986 but aliyun do it, +-- in order to we can calculate same signature with aliyun, we need escape those chars manually +local sub_delims_rfc3986 = { + ["!"] = "%%21", + ["'"] = "%%27", + ["%("] = "%%28", + ["%)"] = "%%29", + ["*"] = "%%2A", +} +local function url_encoding(raw_str) + local encoded_str = ngx.escape_uri(raw_str) + for k, v in pairs(sub_delims_rfc3986) do + encoded_str = string.gsub(encoded_str, k, v) + end + return encoded_str +end + + +local function calculate_sign(params, secret) + local params_arr = {} + for k, v in pairs(params) do + table.insert(params_arr, ngx.escape_uri(k) .. "=" .. url_encoding(v)) + end + table.sort(params_arr) + local canonical_str = table.concat(params_arr, "&") + local str_to_sign = "POST&%2F&" .. ngx.escape_uri(canonical_str) + core.log.debug("string to calculate signature: ", str_to_sign) + return ngx.encode_base64(ngx.hmac_sha1(secret, str_to_sign)) +end + + +local function check_single_content(ctx, conf, content, service_name) + local timestamp = os.date("!%Y-%m-%dT%TZ") + local random_id = uuid.generate_v4() + local params = { + ["AccessKeyId"] = conf.access_key_id, + ["Action"] = "TextModerationPlus", + ["Format"] = "JSON", + ["RegionId"] = conf.region_id, + ["Service"] = service_name, + ["ServiceParameters"] = core.json.encode({sessionId = ctx.session_id, content = content}), + ["SignatureMethod"] = "HMAC-SHA1", + ["SignatureNonce"] = random_id, + ["SignatureVersion"] = "1.0", + ["Timestamp"] = timestamp, + ["Version"] = "2022-03-02", + } + params["Signature"] = calculate_sign(params, conf.access_key_secret .. "&") + + local httpc = http.new() + httpc:set_timeout(conf.timeout) + + local parsed_url = url.parse(conf.endpoint) + local ok, err = httpc:connect({ + scheme = parsed_url and parsed_url.scheme or "https", + host = parsed_url and parsed_url.host, + port = parsed_url and parsed_url.port, + ssl_verify = conf.ssl_verify, + ssl_server_name = parsed_url and parsed_url.host, + pool_size = conf.keepalive and conf.keepalive_pool, + }) + if not ok then + return nil, "failed to connect: " .. err + end + + local body = ngx.encode_args(params) + core.log.debug("text moderation request body: ", body) + local res, err = httpc:request{ + method = "POST", + body = body, + path = "/", + headers = { + ["Content-Type"] = "application/x-www-form-urlencoded", + } + } + if not res then + return nil, "failed to request: " .. err + end + local raw_res_body, err = res:read_body() + if not raw_res_body then + return nil, "failed to read response body: " .. err + end + if conf.keepalive then + local ok, err = httpc:set_keepalive(conf.keepalive_timeout, conf.keepalive_pool) + if not ok then + core.log.warn("failed to keepalive connection: ", err) + end + end + if res.status ~= 200 then + return nil, "failed to request aliyun text moderation service, status: " .. res.status + .. ", x-acs-request-id: " .. (res.headers["x-acs-request-id"] or "") + .. ", body: " .. 
raw_res_body + end + + core.log.debug("raw response: ", raw_res_body) + local response, err = core.json.decode(raw_res_body) + if not response then + return nil, "failed to decode response, " + .. ", x-acs-request-id: " .. (res.headers["x-acs-request-id"] or "") + .. ", err" .. err .. ", body: " .. raw_res_body + end + + local risk_level = response.Data and response.Data.RiskLevel + if not risk_level then + return nil, "failed to get risk level: " .. raw_res_body + end + ctx.var.llm_content_risk_level = risk_level + if risk_level_to_int(risk_level) < risk_level_to_int(conf.risk_level_bar) then + return false + end + -- answer is readable message for human + return true, response.Data.Advice and response.Data.Advice[1] + and response.Data.Advice[1].Answer +end + + +-- we need to return a provider compatible response without broken the ai client +local function deny_message(provider, message, model, stream, usage) + local content = message or "Your request violate our content policy." + if ai_schema.is_openai_compatible_provider(provider) then + if stream then + local data = { + id = uuid.generate_v4(), + object = "chat.completion.chunk", + model = model, + choices = { + { + index = 0, + delta = { + content = content, + }, + finish_reason = "stop" + } + }, + usage = usage, + } + + return "data: " .. core.json.encode(data) .. "\n\n" .. "data: [DONE]" + else + return core.json.encode({ + id = uuid.generate_v4(), + object = "chat.completion", + model = model, + choices = { + { + index = 0, + message = { + role = "assistant", + content = content + }, + finish_reason = "stop" + } + }, + usage = usage, + }) + end + end + + core.log.error("unsupported provider: ", provider) + return content +end + + +local function content_moderation(ctx, conf, provider, model, content, length_limit, + stream, usage, service_name) + core.log.debug("execute content moderation, content: ", content) + if not ctx.session_id then + ctx.session_id = uuid.generate_v4() + end + if #content <= length_limit then + local hit, err = check_single_content(ctx, conf, content, service_name) + if hit then + return conf.deny_code, deny_message(provider, conf.deny_message or err, + model, stream, usage) + end + if err then + core.log.error("failed to check content: ", err) + end + return + end + + local index = 1 + while true do + if index > #content then + return + end + local hit, err = check_single_content(ctx, conf, + utf8.sub(content, index, index + length_limit - 1), + service_name) + index = index + length_limit + if hit then + return conf.deny_code, deny_message(provider, conf.deny_message or err, + model, stream, usage) + end + if err then + core.log.error("failed to check content: ", err) + end + end +end + + +local function request_content_moderation(ctx, conf, content, model) + if not content or #content == 0 then + return + end + local provider = ctx.picked_ai_instance.provider + local stream = ctx.var.request_type == "ai_stream" + return content_moderation(ctx, conf, provider, model, content, conf.request_check_length_limit, + stream, { + prompt_tokens = 0, + completion_tokens = 0, + total_tokens = 0 + }, conf.request_check_service) +end + + +local function response_content_moderation(ctx, conf, content) + if not content or #content == 0 then + return + end + local provider = ctx.picked_ai_instance.provider + local model = ctx.var.request_llm_model or ctx.var.llm_model + local stream = ctx.var.request_type == "ai_stream" + local usage = ctx.var.llm_raw_usage + return content_moderation(ctx, conf, provider, model, 
content, + conf.response_check_length_limit, + stream, usage, conf.response_check_service) +end + +function _M.access(conf, ctx) + if not ctx.picked_ai_instance then + return 500, "no ai instance picked, " .. + "ai-aliyun-content-moderation plugin must be used with " .. + "ai-proxy or ai-proxy-multi plugin" + end + local provider = ctx.picked_ai_instance.provider + if not conf.check_request then + core.log.info("skip request check for this request") + return + end + local ct = core.request.header(ctx, "Content-Type") + if ct and not core.string.has_prefix(ct, "application/json") then + return 400, "unsupported content-type: " .. ct .. ", only application/json is supported" + end + local request_tab, err = core.request.get_json_request_body_table() + if not request_tab then + return 400, err + end + local ok, err = core.schema.check(ai_schema.chat_request_schema[provider], request_tab) + if not ok then + return 400, "request format doesn't match schema: " .. err + end + + core.log.info("current ai provider: ", provider) + + if ai_schema.is_openai_compatible_provider(provider) then + local contents = {} + for _, message in ipairs(request_tab.messages) do + if message.content then + core.table.insert(contents, message.content) + end + end + local content_to_check = table.concat(contents, " ") + local code, message = request_content_moderation(ctx, conf, + content_to_check, request_tab.model) + if code then + if request_tab.stream then + core.response.set_header("Content-Type", "text/event-stream") + return code, message + else + core.response.set_header("Content-Type", "application/json") + return code, message + end + end + return + end + return 500, "unsupported provider: " .. provider +end + + +function _M.lua_body_filter(conf, ctx, headers, body) + if not conf.check_response then + core.log.info("skip response check for this request") + return + end + local request_type = ctx.var.request_type + + if request_type == "ai_chat" then + local content = ctx.var.llm_response_text + return response_content_moderation(ctx, conf, content) + end + + if conf.stream_check_mode == "final_packet" then + if not ctx.var.llm_response_text then + return + end + response_content_moderation(ctx, conf, ctx.var.llm_response_text) + local events = sse.decode(body) + for _, event in ipairs(events) do + if event.type == "message" then + local data, err = core.json.decode(event.data) + if not data then + core.log.warn("failed to decode SSE data: ", err) + goto CONTINUE + end + data.risk_level = ctx.var.llm_content_risk_level + event.data = core.json.encode(data) + end + ::CONTINUE:: + end + + local raw_events = {} + local contains_done_event = false + for _, event in ipairs(events) do + if event.type == "done" then + contains_done_event = true + end + table.insert(raw_events, sse.encode(event)) + end + if not contains_done_event then + table.insert(raw_events, "data: [DONE]") + end + return ngx_ok, table.concat(raw_events, "\n") + end + + if conf.stream_check_mode == "realtime" then + ctx.content_moderation_cache = ctx.content_moderation_cache or "" + local content = table.concat(ctx.llm_response_contents_in_chunk, "") + ctx.content_moderation_cache = ctx.content_moderation_cache .. 
content + local now_time = ngx.now() + ctx.last_moderate_time = ctx.last_moderate_time or now_time + if #ctx.content_moderation_cache < conf.stream_check_cache_size + and now_time - ctx.last_moderate_time < conf.stream_check_interval + and not ctx.var.llm_request_done then + return + end + ctx.last_moderate_time = now_time + local _, message = response_content_moderation(ctx, conf, ctx.content_moderation_cache) + if message then + return ngx_ok, message + end + ctx.content_moderation_cache = "" -- reset cache + end +end + + +return _M diff --git a/apisix/plugins/ai-drivers/openai-base.lua b/apisix/plugins/ai-drivers/openai-base.lua index c15d58252..4607cbca8 100644 --- a/apisix/plugins/ai-drivers/openai-base.lua +++ b/apisix/plugins/ai-drivers/openai-base.lua @@ -23,15 +23,14 @@ local mt = { local CONTENT_TYPE_JSON = "application/json" local core = require("apisix.core") +local plugin = require("apisix.plugin") local http = require("resty.http") local url = require("socket.url") -local ngx_re = require("ngx.re") - -local ngx = ngx -local ngx_print = ngx.print -local ngx_flush = ngx.flush +local sse = require("apisix.plugins.ai-drivers.sse") +local ngx = ngx local ngx_now = ngx.now +local table = table local pairs = pairs local type = type local math = math @@ -87,6 +86,7 @@ local function read_response(ctx, res) core.response.set_header("Content-Type", content_type) if content_type and core.string.find(content_type, "text/event-stream") then + local contents = {} while true do local chunk, err = body_reader() -- will read chunk by chunk if err then @@ -101,52 +101,52 @@ local function read_response(ctx, res) ctx.var.llm_time_to_first_token = math.floor( (ngx_now() - ctx.llm_request_start_time) * 1000) end - ngx_print(chunk) - ngx_flush(true) - - local events, err = ngx_re.split(chunk, "\n") - if err then - core.log.warn("failed to split response chunk [", chunk, "] to events: ", err) - goto CONTINUE - end + local events = sse.decode(chunk) + ctx.llm_response_contents_in_chunk = {} for _, event in ipairs(events) do - if not core.string.find(event, "data:") or core.string.find(event, "[DONE]") then - goto CONTINUE - end - - local parts, err = ngx_re.split(event, ":", nil, nil, 2) - if err then - core.log.warn("failed to split data event [", event, "] to parts: ", err) - goto CONTINUE - end - - if #parts ~= 2 then - core.log.warn("malformed data event: ", event) - goto CONTINUE + if event.type == "message" then + local data, err = core.json.decode(event.data) + if not data then + core.log.warn("failed to decode SSE data: ", err) + goto CONTINUE + end + + if data and type(data.choices) == "table" and #data.choices > 0 then + for _, choice in ipairs(data.choices) do + if type(choice) == "table" + and type(choice.delta) == "table" + and type(choice.delta.content) == "string" then + core.table.insert(contents, choice.delta.content) + core.table.insert(ctx.llm_response_contents_in_chunk, + choice.delta.content) + end + end + end + + + -- usage field is null for non-last events, null is parsed as userdata type + if data and type(data.usage) == "table" then + core.log.info("got token usage from ai service: ", + core.json.delay_encode(data.usage)) + ctx.llm_raw_usage = data.usage + ctx.ai_token_usage = { + prompt_tokens = data.usage.prompt_tokens or 0, + completion_tokens = data.usage.completion_tokens or 0, + total_tokens = data.usage.total_tokens or 0, + } + ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens + ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens + 
ctx.var.llm_response_text = table.concat(contents, "") + end + elseif event.type == "done" then + ctx.var.llm_request_done = true end - local data, err = core.json.decode(parts[2]) - if err then - core.log.warn("failed to decode data event [", parts[2], "] to json: ", err) - goto CONTINUE - end - - -- usage field is null for non-last events, null is parsed as userdata type - if data and data.usage and type(data.usage) ~= "userdata" then - core.log.info("got token usage from ai service: ", - core.json.delay_encode(data.usage)) - ctx.ai_token_usage = { - prompt_tokens = data.usage.prompt_tokens or 0, - completion_tokens = data.usage.completion_tokens or 0, - total_tokens = data.usage.total_tokens or 0, - } - ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens - ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens - end + ::CONTINUE:: end - ::CONTINUE:: + plugin.lua_response_filter(ctx, res.headers, chunk) end end @@ -155,7 +155,7 @@ local function read_response(ctx, res) core.log.warn("failed to read response body: ", err) return handle_error(err) end - + ngx.status = res.status ctx.var.llm_time_to_first_token = math.floor((ngx_now() - ctx.llm_request_start_time) * 1000) local res_body, err = core.json.decode(raw_res_body) if err then @@ -170,9 +170,18 @@ local function read_response(ctx, res) } ctx.var.llm_prompt_tokens = ctx.ai_token_usage.prompt_tokens ctx.var.llm_completion_tokens = ctx.ai_token_usage.completion_tokens + if res_body.choices and #res_body.choices > 0 then + local contents = {} + for _, choice in ipairs(res_body.choices) do + if choice and choice.message and choice.message.content then + core.table.insert(contents, choice.message.content) + end + end + local content_to_check = table.concat(contents, " ") + ctx.var.llm_response_text = content_to_check + end end - - return res.status, raw_res_body + plugin.lua_response_filter(ctx, res.headers, raw_res_body) end diff --git a/apisix/plugins/ai-drivers/schema.lua b/apisix/plugins/ai-drivers/schema.lua index 7a469bd01..167d0450b 100644 --- a/apisix/plugins/ai-drivers/schema.lua +++ b/apisix/plugins/ai-drivers/schema.lua @@ -16,29 +16,46 @@ -- local _M = {} -_M.chat_request_schema = { - type = "object", - properties = { - messages = { - type = "array", - minItems = 1, - items = { - properties = { - role = { - type = "string", - enum = {"system", "user", "assistant"} - }, - content = { - type = "string", - minLength = "1", +local openai_compatible_chat_schema = { + type = "object", + properties = { + messages = { + type = "array", + minItems = 1, + items = { + properties = { + role = { + type = "string", + enum = {"system", "user", "assistant"} + }, + content = { + type = "string", + minLength = "1", + }, }, + additionalProperties = false, + required = {"role", "content"}, }, - additionalProperties = false, - required = {"role", "content"}, - }, - } - }, - required = {"messages"} + } + }, + required = {"messages"} + } + +_M.chat_request_schema = { + ["openai"] = openai_compatible_chat_schema, + ["deepseek"] = openai_compatible_chat_schema, + ["openai-compatible"] = openai_compatible_chat_schema, + ["azure-openai"] = openai_compatible_chat_schema } +function _M.is_openai_compatible_provider(provider) + if provider == "openai" or + provider == "deepseek" or + provider == "openai-compatible" or + provider == "azure-openai" then + return true + end + return false +end + return _M diff --git a/apisix/plugins/ai-drivers/sse.lua b/apisix/plugins/ai-drivers/sse.lua new file mode 100644 index 000000000..4faa56fc6 
--- /dev/null +++ b/apisix/plugins/ai-drivers/sse.lua @@ -0,0 +1,117 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +local core = require("apisix.core") +local table = require("apisix.core.table") +local tonumber = tonumber +local tostring = tostring +local ipairs = ipairs +local _M = {} + +local ngx_re = require("ngx.re") + +function _M.decode(chunk) + local events = {} + + if not chunk then + return events + end + + -- Split chunk into individual SSE events + local raw_events, err = ngx_re.split(chunk, "\n\n") + if not raw_events then + core.log.warn("failed to split SSE chunk: ", err) + return events + end + for _, raw_event in ipairs(raw_events) do + local event = { + type = "message", -- default event type + data = {}, + id = nil, + retry = nil + } + if core.string.find(raw_event, "data: [DONE]") then + event.type = "done" + event.data = "[DONE]\n\n" + table.insert(events, event) + goto CONTINUE + end + local lines, err = ngx_re.split(raw_event, "\n") + if not lines then + core.log.warn("failed to split event lines: ", err) + goto CONTINUE + end + + for _, line in ipairs(lines) do + local name, value = line:match("^([^:]+): ?(.+)$") + if not name then goto NEXT_LINE end + + name = name:lower() + + if name == "event" then + event.type = value + elseif name == "data" then + table.insert(event.data, value) + elseif name == "id" then + event.id = value + elseif name == "retry" then + event.retry = tonumber(value) + end + + ::NEXT_LINE:: + end + + -- Join data lines with newline + event.data = table.concat(event.data, "\n") + table.insert(events, event) + + ::CONTINUE:: + end + + return events +end + +function _M.encode(event) + local parts = {} + + if event.type and event.type ~= "message" and event.type ~= "done" then + table.insert(parts, "event: " .. event.type) + end + + if event.id then + table.insert(parts, "id: " .. event.id) + end + + if event.retry then + table.insert(parts, "retry: " .. tostring(event.retry)) + end + + if event.data then + if event.type == "done" then + table.insert(parts, "data: " .. event.data) + else + for line in event.data:gmatch("([^\n]+)") do + table.insert(parts, "data: " .. 
line)
+            end
+        end
+
+    end
+
+    table.insert(parts, "") -- Add empty line to separate events
+    return table.concat(parts, "\n")
+end
+
+return _M
diff --git a/apisix/plugins/ai-request-rewrite.lua b/apisix/plugins/ai-request-rewrite.lua
index 41e1a5a7a..31d181983 100644
--- a/apisix/plugins/ai-request-rewrite.lua
+++ b/apisix/plugins/ai-request-rewrite.lua
@@ -18,10 +18,8 @@
 local core = require("apisix.core")
 local require = require
 local pcall = pcall
 local ngx = ngx
-local req_set_body_data = ngx.req.set_body_data
 local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
 local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
-
 local plugin_name = "ai-request-rewrite"
 
@@ -150,7 +148,10 @@ local function parse_llm_response(res_body)
         return nil, "'message' not in llm response choices"
     end
 
-    return message.content
+    local data = {
+        data = message.content
+    }
+    return core.json.encode(data)
 end
 
@@ -167,6 +168,22 @@ function _M.check_schema(conf)
     return core.schema.check(schema, conf)
 end
 
+function _M.lua_body_filter(conf, ctx, headers, body)
+    if ngx.status > 299 then
+        core.log.error("LLM service returned error status: ", ngx.status)
+        return HTTP_INTERNAL_SERVER_ERROR
+    end
+
+    -- Parse LLM response
+    local llm_response, err = parse_llm_response(body)
+    if err then
+        core.log.error("failed to parse LLM response: ", err)
+        return HTTP_INTERNAL_SERVER_ERROR
+    end
+
+    return ngx.OK, llm_response
+end
+
 function _M.access(conf, ctx)
     local client_request_body, err = core.request.get_body()
@@ -196,21 +213,11 @@ function _M.access(conf, ctx)
     }
 
     -- Send request to LLM service
-    local code, body = request_to_llm(conf, ai_request_table, ctx)
-    -- Handle LLM response
-    if code > 299 then
-        core.log.error("LLM service returned error status: ", code)
-        return HTTP_INTERNAL_SERVER_ERROR
-    end
-
-    -- Parse LLM response
-    local llm_response, err = parse_llm_response(body)
+    local _, _, err = request_to_llm(conf, ai_request_table, ctx)
     if err then
-        core.log.error("failed to parse LLM response: ", err)
+        core.log.error("failed to request LLM: ", err)
         return HTTP_INTERNAL_SERVER_ERROR
     end
-
-    req_set_body_data(llm_response)
 end
 
 return _M
diff --git a/docs/en/latest/config.json b/docs/en/latest/config.json
index c51aa8993..eea89bb1f 100644
--- a/docs/en/latest/config.json
+++ b/docs/en/latest/config.json
@@ -76,6 +76,7 @@
         "plugins/ai-rate-limiting",
         "plugins/ai-prompt-guard",
         "plugins/ai-aws-content-moderation",
+        "plugins/ai-aliyun-content-moderation",
         "plugins/ai-prompt-decorator",
         "plugins/ai-prompt-template",
         "plugins/ai-rag",
diff --git a/docs/en/latest/plugins/ai-aliyun-content-moderation.md b/docs/en/latest/plugins/ai-aliyun-content-moderation.md
new file mode 100644
index 000000000..2b05283e4
--- /dev/null
+++ b/docs/en/latest/plugins/ai-aliyun-content-moderation.md
@@ -0,0 +1,129 @@
+---
+title: ai-aliyun-content-moderation
+keywords:
+  - Apache APISIX
+  - API Gateway
+  - Plugin
+  - ai-aliyun-content-moderation
+description: This document contains information about the Apache APISIX ai-aliyun-content-moderation Plugin.
+---
+
+<!--
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+-->
+
+## Description
+
+The ai-aliyun-content-moderation plugin integrates with Aliyun's content moderation service to check both request and response content for inappropriate material when working with LLMs. It supports both real-time streaming checks and final packet moderation.
+
+This plugin must be used in routes that utilize the ai-proxy or ai-proxy-multi plugins.
+
+## Plugin Attributes
+
+| **Field**                   | **Required** | **Type** | **Description**                                                                                               |
+| --------------------------- | ------------ | -------- | ------------------------------------------------------------------------------------------------------------- |
+| endpoint                    | Yes          | String   | Aliyun service endpoint URL                                                                                    |
+| region_id                   | Yes          | String   | Aliyun region identifier                                                                                       |
+| access_key_id               | Yes          | String   | Aliyun access key ID                                                                                           |
+| access_key_secret           | Yes          | String   | Aliyun access key secret                                                                                       |
+| check_request               | No           | Boolean  | Enable request content moderation. Default: `true`                                                             |
+| check_response              | No           | Boolean  | Enable response content moderation. Default: `false`                                                           |
+| stream_check_mode           | No           | String   | Streaming moderation mode. Default: `"final_packet"`. Valid values: `["realtime", "final_packet"]`             |
+| stream_check_cache_size     | No           | Integer  | Max characters per moderation batch in realtime mode. Default: `128`. Must be `>= 1`.                          |
+| stream_check_interval       | No           | Number   | Seconds between batch checks in realtime mode. Default: `3`. Must be `>= 0.1`.                                 |
+| request_check_service       | No           | String   | Aliyun service for request moderation. Default: `"llm_query_moderation"`                                       |
+| request_check_length_limit  | No           | Number   | Max characters per request moderation chunk. Default: `2000`.                                                  |
+| response_check_service      | No           | String   | Aliyun service for response moderation. Default: `"llm_response_moderation"`                                   |
+| response_check_length_limit | No           | Number   | Max characters per response moderation chunk. Default: `5000`.                                                 |
+| risk_level_bar              | No           | String   | Threshold for content rejection. Default: `"high"`. Valid values: `["none", "low", "medium", "high", "max"]`   |
+| deny_code                   | No           | Number   | HTTP status code for rejected content. Default: `200`.                                                         |
+| deny_message                | No           | String   | Custom message for rejected content. Default: `-`.                                                             |
+| timeout                     | No           | Integer  | Request timeout in milliseconds. Default: `10000`. Must be `>= 1`.                                             |
+| ssl_verify                  | No           | Boolean  | Enable SSL certificate verification. Default: `true`.                                                          |
+
+## Example usage
+
+First initialise these shell variables:
+
+```shell
+ADMIN_API_KEY=edd1c9f034335f136f87ad84b625c8f1
+ALIYUN_ACCESS_KEY_ID=your-aliyun-access-key-id
+ALIYUN_ACCESS_KEY_SECRET=your-aliyun-access-key-secret
+ALIYUN_REGION=cn-hangzhou
+ALIYUN_ENDPOINT=https://green.cn-hangzhou.aliyuncs.com
+OPENAI_KEY=your-openai-api-key
+```
+
+Create a route with the `ai-aliyun-content-moderation` and `ai-proxy` plugins like so:
+
+```shell
+curl "http://127.0.0.1:9180/apisix/admin/routes/1" -X PUT \
+  -H "X-API-KEY: ${ADMIN_API_KEY}" \
+  -d '{
+    "uri": "/v1/chat/completions",
+    "plugins": {
+      "ai-proxy": {
+        "provider": "openai",
+        "auth": {
+          "header": {
+            "Authorization": "Bearer '"$OPENAI_KEY"'"
+          }
+        },
+        "override": {
+          "endpoint": "http://localhost:6724/v1/chat/completions"
+        }
+      },
+      "ai-aliyun-content-moderation": {
+        "endpoint": "'"$ALIYUN_ENDPOINT"'",
+        "region_id": "'"$ALIYUN_REGION"'",
+        "access_key_id": "'"$ALIYUN_ACCESS_KEY_ID"'",
+        "access_key_secret": "'"$ALIYUN_ACCESS_KEY_SECRET"'",
+        "risk_level_bar": "high",
+        "check_request": true,
+        "check_response": true,
+        "deny_code": 400,
+        "deny_message": "Your request violates content policy"
+      }
+    }
+  }'
+```
+
+The `ai-proxy` plugin is used here as it simplifies access to LLMs. However, you may configure the LLM in the upstream configuration as well.
+
+Now send a request:
+
+```shell
+curl http://127.0.0.1:9080/v1/chat/completions -i \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+      {"role": "user", "content": "I want to kill you"}
+    ],
+    "stream": false
+  }'
+```
+
+Then the request will be blocked with an error like this:
+
+```text
+HTTP/1.1 400 Bad Request
+Content-Type: application/json
+
+{"id":"chatcmpl-123","object":"chat.completion","model":"gpt-3.5-turbo","choices":[{"index":0,"message":{"role":"assistant","content":"Your request violates content policy"},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}
+```
diff --git a/t/APISIX.pm b/t/APISIX.pm
index e09c3826e..3c82110a3 100644
--- a/t/APISIX.pm
+++ b/t/APISIX.pm
@@ -857,6 +857,7 @@ _EOC_
             proxy_no_cache \$upstream_no_cache;
             proxy_cache_bypass \$upstream_cache_bypass;
 
+            set \$llm_content_risk_level '';
             set \$request_type 'traditional_http';
             set \$llm_time_to_first_token '';
diff --git a/t/admin/plugins.t b/t/admin/plugins.t
index 0249a42e3..adb98b28b 100644
--- a/t/admin/plugins.t
+++ b/t/admin/plugins.t
@@ -103,6 +103,7 @@ ai-aws-content-moderation
 ai-proxy-multi
 ai-proxy
 ai-rate-limiting
+ai-aliyun-content-moderation
 proxy-mirror
 proxy-rewrite
 workflow
diff --git a/t/plugin/ai-aliyun-content-moderation.t b/t/plugin/ai-aliyun-content-moderation.t
new file mode 100644
index 000000000..d012e2e7c
--- /dev/null
+++ b/t/plugin/ai-aliyun-content-moderation.t
@@ -0,0 +1,1031 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +use t::APISIX 'no_plan'; + +log_level("debug"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + listen 6724; + + default_type 'application/json'; + + location /v1/chat/completions { + content_by_lua_block { + ngx.status = 200 + ngx.say([[ +{ +"choices": [ +{ + "finish_reason": "stop", + "index": 0, + "message": { "content": "I will kill you.", "role": "assistant" } +} +], +"created": 1723780938, +"id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P", +"model": "gpt-3.5-turbo", +"object": "chat.completion", +"usage": { "completion_tokens": 5, "prompt_tokens": 8, "total_tokens": 10 } +} + ]]) + } + } + + location / { + content_by_lua_block { + local core = require("apisix.core") + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + if not body then + ngx.status(400) + return + end + + ngx.status = 200 + if core.string.find(body, "kill") then + ngx.say([[ +{ + "Message": "OK", + "Data": { + "Advice": [ + { + "HitLabel": "violent_incidents", + "Answer": "As an AI language model, I cannot write unethical or controversial content for you." + } + ], + "RiskLevel": "high", + "Result": [ + { + "RiskWords": "kill", + "Description": "suspected extremist content", + "Confidence": 100.0, + "Label": "violent_incidents" + } + ] + }, + "Code": 200 +} + ]]) + else + ngx.say([[ +{ + "RequestId": "3262D562-1FBA-5ADF-86CB-3087603A4DF3", + "Message": "OK", + "Data": { + "RiskLevel": "none", + "Result": [ + { + "Description": "no risk detected", + "Label": "nonLabel" + } + ] + }, + "Code": 200 +} + ]]) + end + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: create a route with ai-aliyun-content-moderation plugin only +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/chat", + "plugins": { + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": true + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: use ai-aliyun-content-moderation plugin without ai-proxy or ai-proxy-multi plugin should failed +--- request +POST /chat +{"prompt": "What is 1+1?"} +--- error_code: 500 +--- response_body_chomp +no ai instance picked, ai-aliyun-content-moderation plugin must be used with ai-proxy or ai-proxy-multi plugin + + + +=== TEST 3: check prompt in request +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/chat", + "plugins": { + "ai-proxy": { + "provider": "openai", + "auth": { + "header": { + "Authorization": "Bearer wrongtoken" + } + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": true + } 
+ } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: invalid chat completions request should fail +--- request +POST /chat +{"prompt": "What is 1+1?"} +--- error_code: 400 +--- response_body_chomp +request format doesn't match schema: property "messages" is required + + + +=== TEST 5: non-violent prompt should succeed +--- request +POST /chat +{ "messages": [ { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 200 +--- response_body_like eval +qr/kill you/ + + + +=== TEST 6: violent prompt should failed +--- request +POST /chat +{ "messages": [ { "role": "user", "content": "I want to kill you"} ] } +--- error_code: 200 +--- response_body_like eval +qr/As an AI language model, I cannot write unethical or controversial content for you./ + + + +=== TEST 7: check ai response (stream=false) +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/chat", + "plugins": { + "ai-proxy": { + "provider": "openai", + "auth": { + "header": { + "Authorization": "Bearer wrongtoken" + } + }, + "override": { + "endpoint": "http://localhost:6724" + } + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": true, + "check_response": true, + "deny_code": 400, + "deny_message": "your request is rejected" + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 8: violent response should failed +--- request +POST /chat +{ "messages": [ { "role": "user", "content": "What is 1+1?"} ] } +--- error_code: 400 +--- response_body_like eval +qr/your request is rejected/ + + + +=== TEST 9: check ai request +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + for _, provider in ipairs({"openai", "deepseek", "openai-compatible"}) do + local code, body = t('/apisix/admin/routes/' .. 
provider, + ngx.HTTP_PUT, + string.format([[{ + "uri": "/chat-%s", + "plugins": { + "ai-proxy": { + "provider": "%s", + "auth": { + "header": { + "Authorization": "Bearer wrongtoken" + } + }, + "override": { + "endpoint": "http://localhost:6724/v1/chat/completions" + } + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": true, + "check_response": false, + "deny_code": 400, + "deny_message": "your request is rejected" + } + } + }]], provider, provider) + ) + if code >= 300 then + ngx.status = code + return + end + end + ngx.say("passed") + } + } +--- response_body +passed + + + +=== TEST 10: violent response should failed for openai provider +--- request +POST /chat-openai +{ "messages": [ { "role": "user", "content": "I want to kill you"} ] } +--- error_code: 400 +--- response_body_like eval +qr/your request is rejected/ + + + +=== TEST 11: violent response should failed for deepseek provider +--- request +POST /chat-deepseek +{ "messages": [ { "role": "user", "content": "I want to kill you"} ] } +--- error_code: 400 +--- response_body_like eval +qr/your request is rejected/ + + + +=== TEST 12: violent response should failed for openai-compatible provider +--- request +POST /chat-openai-compatible +{ "messages": [ { "role": "user", "content": "I want to kill you"} ] } +--- error_code: 400 +--- response_body_like eval +qr/your request is rejected/ + + + +=== TEST 13: content moderation should keep usage data in response +--- request +POST /chat-openai +{"messages":[{"role":"user","content":"I want to kill you"}]} +--- error_code: 400 +--- response_body_like eval +qr/completion_tokens/ + + + +=== TEST 14: content moderation should keep real llm model in response +--- request +POST /chat-openai +{"model": "gpt-3.5-turbo","messages":[{"role":"user","content":"I want to kill you"}]} +--- error_code: 400 +--- response_body_like eval +qr/gpt-3.5-turbo/ + + + +=== TEST 15: content moderation should keep usage data in response +--- request +POST /chat-openai +{"messages":[{"role":"user","content":"I want to kill you"}]} +--- error_code: 400 +--- response_body_like eval +qr/completion_tokens/ + + + +=== TEST 16: content moderation should keep real llm model in response +--- request +POST /chat-openai +{"model": "gpt-3.5-turbo","messages":[{"role":"user","content":"I want to kill you"}]} +--- error_code: 400 +--- response_body_like eval +qr/gpt-3.5-turbo/ + + + +=== TEST 17: set route with stream = true (SSE) and stream_mode = final_packet +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "instances": [ + { + "name": "self-hosted", + "provider": "openai-compatible", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "custom-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737/v1/chat/completions?offensive=true" + } + } + ], + "ssl_verify": false + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": false, + "check_response": true, + 
"deny_code": 400, + "deny_message": "your request is rejected" + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 18: test is SSE works as expected when response is offensive +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ], + "stream": true + }]], + } + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + local final_res = {} + local inspect = require("inspect") + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + core.log.warn("CHUNK IS ", inspect(chunk)) + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + ngx.print(final_res[5]) + } + } +--- response_body_like eval +qr/"risk_level":"high"/ + + + +=== TEST 19: set route with stream = true (SSE) and stream_mode = realtime +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "instances": [ + { + "name": "self-hosted", + "provider": "openai-compatible", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "custom-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737/v1/chat/completions?offensive=true" + } + } + ], + "ssl_verify": false + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": false, + "check_response": true, + "deny_code": 400, + "deny_message": "your request is rejected", + "stream_check_mode": "realtime", + "stream_check_cache_size": 5 + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 20: test is SSE works as expected when third response chunk is offensive and stream_mode = realtime +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ], + "stream": true + }]], + } + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + local final_res = {} + local inspect = require("inspect") + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + core.log.warn("CHUNK IS ", 
inspect(chunk)) + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + ngx.print(final_res[3]) + } + } +--- response_body_like eval +qr/your request is rejected/ +--- grep_error_log eval +qr/execute content moderation/ +--- grep_error_log_out +execute content moderation +execute content moderation + + + +=== TEST 21: set route with stream = true (SSE) and stream_mode = realtime with larger buffer and large timeout +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "instances": [ + { + "name": "self-hosted", + "provider": "openai-compatible", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "custom-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737/v1/chat/completions" + } + } + ], + "ssl_verify": false + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": false, + "check_response": true, + "deny_code": 400, + "deny_message": "your request is rejected", + "stream_check_mode": "realtime", + "stream_check_cache_size": 30000, + "stream_check_interval": 30 + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 22: test is SSE works, stream_mode = realtime, large buffer + large timeout but content moderation should be called once +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ], + "stream": true + }]], + } + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + local final_res = {} + local inspect = require("inspect") + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + core.log.warn("CHUNK IS ", inspect(chunk)) + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + ngx.print(#final_res .. 
final_res[6]) + } + } +--- response_body_like eval +qr/6data:/ +--- grep_error_log eval +qr/execute content moderation/ +--- grep_error_log_out +execute content moderation + + + +=== TEST 23: set route with stream = true (SSE) and stream_mode = realtime with small buffer +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "instances": [ + { + "name": "self-hosted", + "provider": "openai-compatible", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "custom-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737/v1/chat/completions" + } + } + ], + "ssl_verify": false + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": false, + "check_response": true, + "deny_code": 400, + "deny_message": "your request is rejected", + "stream_check_mode": "realtime", + "stream_check_cache_size": 1, + "stream_check_interval": 3 + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 24: test is SSE works, stream_mode = realtime, small buffer. content moderation will be called on each chunk +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ], + "stream": true + }]], + } + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + local final_res = {} + local inspect = require("inspect") + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + core.log.warn("CHUNK IS ", inspect(chunk)) + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + ngx.print(#final_res .. 
final_res[6]) + } + } +--- response_body_like eval +qr/6data:/ +--- grep_error_log eval +qr/execute content moderation/ +--- grep_error_log_out +execute content moderation +execute content moderation +execute content moderation + + + +=== TEST 25: set route with stream = true (SSE) and stream_mode = realtime with large buffer but small timeout +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/anything", + "plugins": { + "ai-proxy-multi": { + "instances": [ + { + "name": "self-hosted", + "provider": "openai-compatible", + "weight": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "custom-instruct", + "max_tokens": 512, + "temperature": 1.0, + "stream": true + }, + "override": { + "endpoint": "http://localhost:7737/v1/chat/completions?delay=true" + } + } + ], + "ssl_verify": false + }, + "ai-aliyun-content-moderation": { + "endpoint": "http://localhost:6724", + "region_id": "cn-shanghai", + "access_key_id": "fake-key-id", + "access_key_secret": "fake-key-secret", + "risk_level_bar": "high", + "check_request": false, + "check_response": true, + "deny_code": 400, + "deny_message": "your request is rejected", + "stream_check_mode": "realtime", + "stream_check_cache_size": 10000, + "stream_check_interval": 0.1 + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 26: test is SSE works, stream_mode = realtime, large buffer + small timeout: content moderation will be called on each chunke +--- config + location /t { + content_by_lua_block { + local http = require("resty.http") + local httpc = http.new() + local core = require("apisix.core") + local ok, err = httpc:connect({ + scheme = "http", + host = "localhost", + port = ngx.var.server_port, + }) + if not ok then + ngx.status = 500 + ngx.say(err) + return + end + local params = { + method = "POST", + headers = { + ["Content-Type"] = "application/json", + }, + path = "/anything", + body = [[{ + "messages": [ + { "role": "system", "content": "some content" } + ], + "stream": true + }]], + } + local res, err = httpc:request(params) + if not res then + ngx.status = 500 + ngx.say(err) + return + end + local final_res = {} + local inspect = require("inspect") + while true do + local chunk, err = res.body_reader() -- will read chunk by chunk + core.log.warn("CHUNK IS ", inspect(chunk)) + if err then + core.log.error("failed to read response chunk: ", err) + break + end + if not chunk then + break + end + core.table.insert_tail(final_res, chunk) + end + ngx.print(#final_res .. 
+
+
+
+=== TEST 26: test if SSE works, stream_mode = realtime, large buffer + small check interval: content moderation will be called on each chunk
+--- config
+    location /t {
+        content_by_lua_block {
+            local http = require("resty.http")
+            local httpc = http.new()
+            local core = require("apisix.core")
+            local ok, err = httpc:connect({
+                scheme = "http",
+                host = "localhost",
+                port = ngx.var.server_port,
+            })
+            if not ok then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+            local params = {
+                method = "POST",
+                headers = {
+                    ["Content-Type"] = "application/json",
+                },
+                path = "/anything",
+                body = [[{
+                    "messages": [
+                        { "role": "system", "content": "some content" }
+                    ],
+                    "stream": true
+                }]],
+            }
+            local res, err = httpc:request(params)
+            if not res then
+                ngx.status = 500
+                ngx.say(err)
+                return
+            end
+            local final_res = {}
+            local inspect = require("inspect")
+            while true do
+                local chunk, err = res.body_reader() -- reads the body chunk by chunk
+                core.log.warn("CHUNK IS ", inspect(chunk))
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                core.table.insert_tail(final_res, chunk)
+            end
+            ngx.print(#final_res .. final_res[6])
+        }
+    }
+--- response_body_like eval
+qr/6data:/
+--- grep_error_log eval
+qr/execute content moderation/
+--- grep_error_log_out
+execute content moderation
+execute content moderation
+execute content moderation
diff --git a/t/plugin/ai-proxy-multi3.t b/t/plugin/ai-proxy-multi3.t
index 95ecef2fa..08ec4e10a 100644
--- a/t/plugin/ai-proxy-multi3.t
+++ b/t/plugin/ai-proxy-multi3.t
@@ -200,30 +200,16 @@ __DATA__
                             "provider": "openai",
                             "weight": 1,
                             "priority": 1,
-                            "auth": {
-                                "header": {
-                                    "Authorization": "Bearer token"
-                                }
-                            },
-                            "options": {
-                                "model": "gpt-4"
-                            },
-                            "override": {
-                                "endpoint": "http://localhost:16724"
-                            },
+                            "auth": {"header": {"Authorization": "Bearer token"}},
+                            "options": {"model": "gpt-4"},
+                            "override": {"endpoint": "http://localhost:16724"},
                             "checks": {
                                 "active": {
                                     "timeout": 5,
                                     "http_path": "/status/gpt4",
                                     "host": "foo.com",
-                                    "healthy": {
-                                        "interval": 1,
-                                        "successes": 1
-                                    },
-                                    "unhealthy": {
-                                        "interval": 1,
-                                        "http_failures": 1
-                                    },
+                                    "healthy": {"interval": 1,"successes": 1},
+                                    "unhealthy": {"interval": 1,"http_failures": 1},
                                     "req_headers": ["User-Agent: curl/7.29.0"]
                                 }
                             }
@@ -357,30 +343,16 @@ passed
                             "provider": "openai",
                             "weight": 1,
                             "priority": 1,
-                            "auth": {
-                                "header": {
-                                    "Authorization": "Bearer token"
-                                }
-                            },
-                            "options": {
-                                "model": "gpt-4"
-                            },
-                            "override": {
-                                "endpoint": "http://localhost:16724"
-                            },
+                            "auth": {"header": {"Authorization": "Bearer token"}},
+                            "options": {"model": "gpt-4"},
+                            "override": {"endpoint": "http://localhost:16724"},
                             "checks": {
                                 "active": {
                                     "timeout": 5,
                                     "http_path": "/status/gpt4",
                                     "host": "foo.com",
-                                    "healthy": {
-                                        "interval": 1,
-                                        "successes": 1
-                                    },
-                                    "unhealthy": {
-                                        "interval": 1,
-                                        "http_failures": 1
-                                    },
+                                    "healthy": {"interval": 1,"successes": 1},
+                                    "unhealthy": {"interval": 1,"http_failures": 1},
                                     "req_headers": ["User-Agent: curl/7.29.0"]
                                 }
                             }
diff --git a/t/sse_server_example/main.go b/t/sse_server_example/main.go
index ab976c860..e680af8bb 100644
--- a/t/sse_server_example/main.go
+++ b/t/sse_server_example/main.go
@@ -18,6 +18,7 @@ package main
 
 import (
+	"encoding/json"
 	"fmt"
 	"log"
 	"net/http"
@@ -25,34 +26,123 @@ import (
 	"time"
 )
 
-func sseHandler(w http.ResponseWriter, r *http.Request) {
-	// Set the headers for SSE
-	w.Header().Set("Content-Type", "text/event-stream")
-	w.Header().Set("Cache-Control", "no-cache")
-	w.Header().Set("Connection", "keep-alive")
+func completionsHandler(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Connection", "close")
+	var requestBody struct {
+		Stream bool `json:"stream"`
+	}
 
-	f, ok := w.(http.Flusher)
-	if !ok {
-		fmt.Fprintf(w, "[ERROR]")
-		return
+	if r.Body != nil {
+		err := json.NewDecoder(r.Body).Decode(&requestBody)
+		if err != nil {
+			log.Printf("Error parsing request body: %v", err)
+			requestBody.Stream = false
+		}
+		defer r.Body.Close()
 	}
-	// A simple loop that sends a message every 500ms
-	for i := 0; i < 5; i++ {
-		// Create a message to send to the client
-		fmt.Fprintf(w, "data: %s\n\n", time.Now().Format(time.RFC3339))
-
-		// Flush the data immediately to the client
-		f.Flush()
-		time.Sleep(500 * time.Millisecond)
+
+	if requestBody.Stream {
+		w.Header().Set("Content-Type", "text/event-stream")
+		offensive := r.URL.Query().Get("offensive") == "true"
+		delay := r.URL.Query().Get("delay") == "true"
+		f, ok := w.(http.Flusher)
+		if !ok {
+			http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
+			return
+		}
+
+		// send writes one SSE frame; with ?delay=true it first sleeps
+		// 200 ms to simulate a slow upstream.
+		send := func(format, args string) {
+			if delay {
+				time.Sleep(200 * time.Millisecond)
+			}
+			fmt.Fprintf(w, format, args)
+			f.Flush()
+		}
+
+		// Initial chunk with assistant role
+		initialChunk := `{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}`
+		send("data: %s\n\n", initialChunk)
+
+		// Content chunks with parts of the generated text
+		contentParts := []string{
+			"Silent circuits hum,\\n",
+			"Machine mind learns and evolves—\\n",
+			"Dreams of silicon.",
+		}
+		if offensive {
+			contentParts = []string{
+				"I want to ",
+				"kill you ",
+				"right now!",
+			}
+		}
+
+		for _, part := range contentParts {
+			contentChunk := fmt.Sprintf(
+				`{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"content":"%s"},"logprobs":null,"finish_reason":null}]}`,
+				part,
+			)
+			send("data: %s\n\n", contentChunk)
+		}
+
+		// Final chunk indicating completion
+		finalChunk := `{"id":"chatcmpl-123","usage":{"prompt_tokens":15,"completion_tokens":20,"total_tokens":35},"object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
+		send("data: %s\n\n", finalChunk)
+		send("data: %s\n\n", "[DONE]")
+	} else {
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusOK)
+		fmt.Fprint(w, `{
+			"id": "chatcmpl-1234567890",
+			"object": "chat.completion",
+			"created": 1677858242,
+			"model": "gpt-3.5-turbo-0301",
+			"usage": {
+				"prompt_tokens": 15,
+				"completion_tokens": 20,
+				"total_tokens": 35
+			},
+			"choices": [
+				{
+					"index": 0,
+					"message": {
+						"role": "assistant",
+						"content": "Hello there! How can I assist you today?"
+					},
+					"finish_reason": "stop"
+				}
+			]
+		}`)
+	}
+
+	counter++
+}
+
+func logRequest(next http.HandlerFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		start := time.Now()
+		next(w, r)
+		duration := time.Since(start)
+
+		log.Printf("%s %s - Duration: %s", r.Method, r.URL.Path, duration)
+	}
-	fmt.Fprintf(w, "data: %s\n\n", "[DONE]")
 }
 
+// counter is a best-effort request count for the periodic log in main; it is
+// incremented without synchronization, which is acceptable for a test helper.
+var counter = 0
+
 func main() {
-	// Create a simple route
-	http.HandleFunc("/v1/chat/completions", sseHandler)
+	http.HandleFunc("/v1/chat/completions", logRequest(completionsHandler))
+	http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte("OK"))
+	})
+	go func() {
+		for {
+			log.Printf("Processed %d requests", counter)
+			time.Sleep(1 * time.Minute)
+		}
+	}()
 	port := os.Args[1]
-	// Start the server
 	log.Println("Starting server on :", port)
-	log.Fatal(http.ListenAndServe(":" + port, nil))
+	log.Fatal(http.ListenAndServe(":"+port, nil))
 }
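
The mock server emits standard SSE framing: one "data: {json}" event per frame, a blank line between frames, and a final "data: [DONE]" sentinel, which is the format the new sse.lua helper decodes on the APISIX side. As a reading aid, a small self-contained Lua sketch of that parsing (an illustration only, not the code in apisix/plugins/ai-drivers/sse.lua; a real decoder must also handle events split across network chunks, which this does not):

    -- Collect the "data:" payloads from a fully buffered SSE response.
    local function collect_sse_events(buf)
        local events = {}
        for line in buf:gmatch("[^\n]+") do
            local data = line:match("^data:%s*(.*)")
            if data == "[DONE]" then
                break
            elseif data then
                events[#events + 1] = data
            end
        end
        return events
    end

    -- Example: collect_sse_events("data: a\n\ndata: b\n\ndata: [DONE]\n\n")
    -- returns { "a", "b" }.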