This is an automated email from the ASF dual-hosted git repository. ashishtiwari pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/apisix.git
The following commit(s) were added to refs/heads/master by this push: new 8bb6802bd feat(ai-proxy-multi): add support for healthcheck (#12509) 8bb6802bd is described below commit 8bb6802bd2b2ccd4dabe78e90beb8f87ec046035 Author: Ashish Tiwari <ashishjaitiwari15112...@gmail.com> AuthorDate: Thu Aug 21 10:50:55 2025 +0530 feat(ai-proxy-multi): add support for healthcheck (#12509) --- apisix/healthcheck_manager.lua | 36 +- apisix/http/service.lua | 4 + apisix/plugin.lua | 26 ++ apisix/plugins/ai-proxy-multi.lua | 147 ++++++- apisix/plugins/ai-proxy/schema.lua | 9 + apisix/router.lua | 3 + apisix/schema_def.lua | 288 ++++++------ t/APISIX.pm | 1 + t/control/services.t | 2 +- t/plugin/ai-proxy-multi3.t | 865 +++++++++++++++++++++++++++++++++++++ 10 files changed, 1232 insertions(+), 149 deletions(-) diff --git a/apisix/healthcheck_manager.lua b/apisix/healthcheck_manager.lua index 066349829..8c93360fb 100644 --- a/apisix/healthcheck_manager.lua +++ b/apisix/healthcheck_manager.lua @@ -26,6 +26,8 @@ local healthcheck local events = require("apisix.events") local tab_clone = core.table.clone local timer_every = ngx.timer.every +local ngx_re = require('ngx.re') +local jp = require("jsonpath") local string_sub = string.sub local _M = {} @@ -58,8 +60,17 @@ local function remove_etcd_prefix(key) return string_sub(key, #prefix + 1) end +local function parse_path(resource_full_path) + local resource_path_parts = ngx_re.split(resource_full_path, "#") + local resource_path = resource_path_parts[1] or resource_full_path + local resource_sub_path = resource_path_parts[2] or "" + return resource_path, resource_sub_path +end local function fetch_latest_conf(resource_path) + -- if resource path contains json path, extract out the prefix + -- for eg: extracts /routes/1 from /routes/1#plugins.abc + resource_path = parse_path(resource_path) local resource_type, id -- Handle both formats: -- 1. /<etcd-prefix>/<resource_type>/<id> @@ -206,6 +217,15 @@ function _M.upstream_version(index, nodes_ver) end +local function get_plugin_name(path) + -- Extract JSON path (after '#') or use full path + local json_path = path:match("#(.+)$") or path + -- Match plugin name in the JSON path segment + return json_path:match("^plugins%['([^']+)'%]") + or json_path:match('^plugins%["([^"]+)"%]') + or json_path:match("^plugins%.([^%.]+)") +end + local function timer_create_checker() if core.table.nkeys(waiting_pool) == 0 then return @@ -224,7 +244,21 @@ local function timer_create_checker() if not res_conf then goto continue end - local upstream = res_conf.value.upstream or res_conf.value + local upstream + local plugin_name = get_plugin_name(resource_path) + if plugin_name and plugin_name ~= "" then + local _, sub_path = parse_path(resource_path) + local json_path = "$." .. sub_path + --- the users of the API pass the jsonpath(in resourcepath) to + --- upstream_constructor_config which is passed to the + --- callback construct_upstream to create an upstream dynamically + local upstream_constructor_config = jp.value(res_conf.value, json_path) + local plugin = require("apisix.plugins." .. plugin_name) + upstream = plugin.construct_upstream(upstream_constructor_config) + upstream.resource_key = resource_path + else + upstream = res_conf.value.upstream or res_conf.value + end local new_version = _M.upstream_version(res_conf.modifiedIndex, upstream._nodes_ver) core.log.info("checking waiting pool for resource: ", resource_path, " current version: ", new_version, " requested version: ", resource_ver) diff --git a/apisix/http/service.lua b/apisix/http/service.lua index 97b224d62..66bb21023 100644 --- a/apisix/http/service.lua +++ b/apisix/http/service.lua @@ -17,6 +17,7 @@ local core = require("apisix.core") local apisix_upstream = require("apisix.upstream") local plugin_checker = require("apisix.plugin").plugin_checker +local plugin = require("apisix.plugin") local services local error = error @@ -46,6 +47,9 @@ local function filter(service) return end + + plugin.set_plugins_meta_parent(service.value.plugins, service) + apisix_upstream.filter_upstream(service.value.upstream, service) core.log.info("filter service: ", core.json.delay_encode(service, true)) diff --git a/apisix/plugin.lua b/apisix/plugin.lua index 87f024d66..5eed30001 100644 --- a/apisix/plugin.lua +++ b/apisix/plugin.lua @@ -34,6 +34,8 @@ local type = type local local_plugins = core.table.new(32, 0) local tostring = tostring local error = error +local getmetatable = getmetatable +local setmetatable = setmetatable -- make linter happy to avoid error: getting the Lua global "load" -- luacheck: globals load, ignore lua_load local lua_load = load @@ -1234,6 +1236,30 @@ function _M.run_plugin(phase, plugins, api_ctx) return api_ctx, plugin_run end +function _M.set_plugins_meta_parent(plugins, parent) + if not plugins then + return + end + for _, plugin_conf in pairs(plugins) do + if not plugin_conf._meta then + plugin_conf._meta = {} + end + if not plugin_conf._meta.parent then + local parent_info = { + resource_key = parent.key, + resource_version = tostring(parent.modifiedIndex) + } + local mt_table = getmetatable(plugin_conf._meta) + if mt_table then + mt_table.parent = parent_info + else + plugin_conf._meta = setmetatable(plugin_conf._meta, + { __index = {parent = parent_info} }) + end + end + end +end + function _M.run_global_rules(api_ctx, global_rules, phase_name) if global_rules and #global_rules > 0 then diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua index 9e9ee93e7..4c2dff582 100644 --- a/apisix/plugins/ai-proxy-multi.lua +++ b/apisix/plugins/ai-proxy-multi.lua @@ -19,6 +19,10 @@ local core = require("apisix.core") local schema = require("apisix.plugins.ai-proxy.schema") local base = require("apisix.plugins.ai-proxy.base") local plugin = require("apisix.plugin") +local ipmatcher = require("resty.ipmatcher") +local healthcheck_manager = require("apisix.healthcheck_manager") +local tonumber = tonumber +local pairs = pairs local require = require local pcall = pcall @@ -118,17 +122,100 @@ local function transform_instances(new_instances, instance) new_instances[instance.priority][instance.name] = instance.weight end +local function parse_domain_for_node(node) + local host = node.domain or node.host + if not ipmatcher.parse_ipv4(host) + and not ipmatcher.parse_ipv6(host) + then + node.domain = host -local function create_server_picker(conf, ups_tab) + local ip, err = core.resolver.parse_domain(host) + if ip then + node.host = ip + end + + if err then + core.log.error("dns resolver domain: ", host, " error: ", err) + end + end +end + + +local function resolve_endpoint(instance_conf) + local endpoint = core.table.try_read_attr(instance_conf, "override", "endpoint") + local scheme, host, port, _ = endpoint:match("^(https?)://([^:/]+):?(%d*)(/?.*)$") + if port == "" then + port = (scheme == "https") and "443" or "80" + end + local node = { + host = host, + port = tonumber(port), + scheme = scheme, + } + parse_domain_for_node(node) + return node +end + + +local function get_checkers_status_ver(checkers) + local status_ver_total = 0 + for _, checker in pairs(checkers) do + status_ver_total = status_ver_total + checker.status_ver + end + return status_ver_total +end + + + +local function fetch_health_instances(conf, checkers) + local instances = conf.instances + local new_instances = core.table.new(0, #instances) + if not checkers then + for _, ins in ipairs(conf.instances) do + transform_instances(new_instances, ins) + end + return new_instances + end + + for _, ins in ipairs(instances) do + local checker = checkers[ins.name] + if checker then + local host = ins.checks and ins.checks.active and ins.checks.active.host + local port = ins.checks and ins.checks.active and ins.checks.active.port + + local node = resolve_endpoint(ins) + local ok, err = checker:get_target_status(node.host, port or node.port, host) + if ok then + transform_instances(new_instances, ins) + elseif err then + core.log.error("failed to get health check target status, addr: ", + node.host, ":", port or node.port, ", host: ", host, ", err: ", err) + end + else + transform_instances(new_instances, ins) + end + end + + if core.table.nkeys(new_instances) == 0 then + core.log.warn("all upstream nodes is unhealthy, use default") + for _, ins in ipairs(instances) do + transform_instances(new_instances, ins) + end + end + + return new_instances +end + + +local function create_server_picker(conf, ups_tab, checkers) local picker = pickers[conf.balancer.algorithm] -- nil check if not picker then pickers[conf.balancer.algorithm] = require("apisix.balancer." .. conf.balancer.algorithm) picker = pickers[conf.balancer.algorithm] end - local new_instances = {} - for _, ins in ipairs(conf.instances) do - transform_instances(new_instances, ins) - end + + local new_instances = fetch_health_instances(conf, checkers) + core.log.info("fetch health instances: ", core.json.delay_encode(new_instances)) if #new_instances._priority_index > 1 then core.log.info("new instances: ", core.json.delay_encode(new_instances)) @@ -149,11 +236,57 @@ local function get_instance_conf(instances, name) end +function _M.construct_upstream(instance) + local upstream = {} + local node = resolve_endpoint(instance) + if not node then + return nil, "failed to resolve endpoint for instance: " .. instance.name + end + + if not node.host or not node.port then + return nil, "invalid upstream node: " .. core.json.encode(node) + end + + parse_domain_for_node(node) + + local node = { + host = node.host, + port = node.port, + scheme = node.scheme, + weight = instance.weight or 1, + priority = instance.priority or 0, + name = instance.name, + } + upstream.nodes = {node} + upstream.checks = instance.checks + return upstream +end + + local function pick_target(ctx, conf, ups_tab) + local checkers + for i, instance in ipairs(conf.instances) do + if instance.checks then + -- json path is 0 indexed so we need to decrement i + local resource_path = conf._meta.parent.resource_key .. + "#plugins['ai-proxy-multi'].instances[" .. i-1 .. "]" + local resource_version = conf._meta.parent.resource_version + local checker = healthcheck_manager.fetch_checker(resource_path, resource_version) + checkers = checkers or {} + checkers[instance.name] = checker + end + end + + local version = plugin.conf_version(conf) + if checkers then + local status_ver = get_checkers_status_ver(checkers) + version = version .. "#" .. status_ver + end + local server_picker = ctx.server_picker if not server_picker then - server_picker = lrucache_server_picker(ctx.matched_route.key, plugin.conf_version(conf), - create_server_picker, conf, ups_tab) + server_picker = lrucache_server_picker(ctx.matched_route.key, version, + create_server_picker, conf, ups_tab, checkers) end if not server_picker then return nil, nil, "failed to fetch server picker" diff --git a/apisix/plugins/ai-proxy/schema.lua b/apisix/plugins/ai-proxy/schema.lua index 0a3c0280d..3510dca69 100644 --- a/apisix/plugins/ai-proxy/schema.lua +++ b/apisix/plugins/ai-proxy/schema.lua @@ -14,6 +14,8 @@ -- See the License for the specific language governing permissions and -- limitations under the License. -- +local schema_def = require("apisix.schema_def") + local _M = {} local auth_item_schema = { @@ -88,6 +90,13 @@ local ai_instance_schema = { }, }, }, + checks = { + type = "object", + properties = { + active = schema_def.health_checker_active, + }, + required = {"active"} + } }, required = {"name", "provider", "auth", "weight"} }, diff --git a/apisix/router.lua b/apisix/router.lua index 93b123e5b..19244da32 100644 --- a/apisix/router.lua +++ b/apisix/router.lua @@ -18,6 +18,7 @@ local require = require local http_route = require("apisix.http.route") local apisix_upstream = require("apisix.upstream") local core = require("apisix.core") +local set_plugins_meta_parent = require("apisix.plugin").set_plugins_meta_parent local str_lower = string.lower local ipairs = ipairs @@ -33,6 +34,8 @@ local function filter(route) return end + set_plugins_meta_parent(route.value.plugins, route) + if route.value.host then route.value.host = str_lower(route.value.host) elseif route.value.hosts then diff --git a/apisix/schema_def.lua b/apisix/schema_def.lua index d8b620884..78f59fcd3 100644 --- a/apisix/schema_def.lua +++ b/apisix/schema_def.lua @@ -126,170 +126,178 @@ local timeout_def = { } -local health_checker = { +local health_checker_active = { type = "object", properties = { - active = { + type = { + type = "string", + enum = {"http", "https", "tcp"}, + default = "http" + }, + timeout = {type = "number", default = 1}, + concurrency = {type = "integer", default = 10}, + host = host_def, + port = { + type = "integer", + minimum = 1, + maximum = 65535 + }, + http_path = {type = "string", default = "/"}, + https_verify_certificate = {type = "boolean", default = true}, + healthy = { type = "object", properties = { - type = { - type = "string", - enum = {"http", "https", "tcp"}, - default = "http" + interval = {type = "integer", minimum = 1, default = 1}, + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599 + }, + uniqueItems = true, + default = {200, 302} + }, + successes = { + type = "integer", + minimum = 1, + maximum = 254, + default = 2 + } + } + }, + unhealthy = { + type = "object", + properties = { + interval = {type = "integer", minimum = 1, default = 1}, + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599 + }, + uniqueItems = true, + default = {429, 404, 500, 501, 502, 503, 504, 505} }, - timeout = {type = "number", default = 1}, - concurrency = {type = "integer", default = 10}, - host = host_def, - port = { + http_failures = { type = "integer", minimum = 1, - maximum = 65535 + maximum = 254, + default = 5 }, - http_path = {type = "string", default = "/"}, - https_verify_certificate = {type = "boolean", default = true}, - healthy = { - type = "object", - properties = { - interval = {type = "integer", minimum = 1, default = 1}, - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599 - }, - uniqueItems = true, - default = {200, 302} - }, - successes = { - type = "integer", - minimum = 1, - maximum = 254, - default = 2 - } - } + tcp_failures = { + type = "integer", + minimum = 1, + maximum = 254, + default = 2 }, - unhealthy = { - type = "object", - properties = { - interval = {type = "integer", minimum = 1, default = 1}, - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599 - }, - uniqueItems = true, - default = {429, 404, 500, 501, 502, 503, 504, 505} - }, - http_failures = { - type = "integer", - minimum = 1, - maximum = 254, - default = 5 - }, - tcp_failures = { - type = "integer", - minimum = 1, - maximum = 254, - default = 2 - }, - timeouts = { - type = "integer", - minimum = 1, - maximum = 254, - default = 3 - } - } + timeouts = { + type = "integer", + minimum = 1, + maximum = 254, + default = 3 + } + } + }, + req_headers = { + type = "array", + minItems = 1, + items = { + type = "string", + uniqueItems = true, + }, + } + } +} +_M.health_checker_active = health_checker_active + + +local health_checker_passive = { + type = "object", + properties = { + type = { + type = "string", + enum = {"http", "https", "tcp"}, + default = "http" + }, + healthy = { + type = "object", + properties = { + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599, + }, + uniqueItems = true, + default = {200, 201, 202, 203, 204, 205, 206, 207, + 208, 226, 300, 301, 302, 303, 304, 305, + 306, 307, 308} }, - req_headers = { - type = "array", - minItems = 1, - items = { - type = "string", - uniqueItems = true, - }, + successes = { + type = "integer", + minimum = 0, + maximum = 254, + default = 5 } } }, - passive = { + unhealthy = { type = "object", properties = { - type = { - type = "string", - enum = {"http", "https", "tcp"}, - default = "http" + http_statuses = { + type = "array", + minItems = 1, + items = { + type = "integer", + minimum = 200, + maximum = 599, + }, + uniqueItems = true, + default = {429, 500, 503} }, - healthy = { - type = "object", - properties = { - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599, - }, - uniqueItems = true, - default = {200, 201, 202, 203, 204, 205, 206, 207, - 208, 226, 300, 301, 302, 303, 304, 305, - 306, 307, 308} - }, - successes = { - type = "integer", - minimum = 0, - maximum = 254, - default = 5 - } - } + tcp_failures = { + type = "integer", + minimum = 0, + maximum = 254, + default = 2 }, - unhealthy = { - type = "object", - properties = { - http_statuses = { - type = "array", - minItems = 1, - items = { - type = "integer", - minimum = 200, - maximum = 599, - }, - uniqueItems = true, - default = {429, 500, 503} - }, - tcp_failures = { - type = "integer", - minimum = 0, - maximum = 254, - default = 2 - }, - timeouts = { - type = "integer", - minimum = 0, - maximum = 254, - default = 7 - }, - http_failures = { - type = "integer", - minimum = 0, - maximum = 254, - default = 5 - }, - } - } - }, + timeouts = { + type = "integer", + minimum = 0, + maximum = 254, + default = 7 + }, + http_failures = { + type = "integer", + minimum = 0, + maximum = 254, + default = 5 + }, + } } }, +} +_M.health_checker_passive = health_checker_passive + + +local health_checker = { + type = "object", + properties = { + active = health_checker_active, + passive = health_checker_passive, + }, anyOf = { {required = {"active"}}, {required = {"active", "passive"}}, }, - additionalProperties = false, } +_M.health_checker = health_checker local nodes_schema = { diff --git a/t/APISIX.pm b/t/APISIX.pm index a3273976b..e09c3826e 100644 --- a/t/APISIX.pm +++ b/t/APISIX.pm @@ -616,6 +616,7 @@ _EOC_ lua_shared_dict xds-config 1m; lua_shared_dict xds-config-version 1m; lua_shared_dict cas_sessions 10m; + lua_shared_dict test 5m; proxy_ssl_name \$upstream_host; proxy_ssl_server_name on; diff --git a/t/control/services.t b/t/control/services.t index 3a959fe4c..b0fa844be 100644 --- a/t/control/services.t +++ b/t/control/services.t @@ -157,7 +157,7 @@ services: } } --- response_body eval -qr/\{"id":"5","plugins":\{"limit-count":\{"allow_degradation":false,"count":2,"key":"remote_addr","key_type":"var","policy":"local","rejected_code":503,"show_limit_quota_header":true,"time_window":60\}\},"upstream":\{"hash_on":"vars","nodes":\[\{"host":"127.0.0.1","port":1980,"weight":1\}\],"pass_host":"pass",.*"scheme":"http","type":"roundrobin"\}\}/ +qr/\{"id":"5","plugins":\{"limit-count":\{"_meta":\{\},"allow_degradation":false,"count":2,"key":"remote_addr","key_type":"var","policy":"local","rejected_code":503,"show_limit_quota_header":true,"time_window":60\}\},"upstream":\{"hash_on":"vars","nodes":\[\{"host":"127.0.0.1","port":1980,"weight":1\}\],"pass_host":"pass",.*"scheme":"http","type":"roundrobin"\}\}/ diff --git a/t/plugin/ai-proxy-multi3.t b/t/plugin/ai-proxy-multi3.t new file mode 100644 index 000000000..95ecef2fa --- /dev/null +++ b/t/plugin/ai-proxy-multi3.t @@ -0,0 +1,865 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +BEGIN { + if ($ENV{TEST_EVENTS_MODULE} ne "lua-resty-worker-events") { + $SkipReason = "Only for lua-resty-worker-events events module"; + } +} +use Test::Nginx::Socket::Lua $SkipReason ? (skip_all => $SkipReason) : (); +use t::APISIX 'no_plan'; + +log_level("info"); +repeat_each(1); +no_long_string(); +no_root_location(); + + +add_block_preprocessor(sub { + my ($block) = @_; + + if (!defined $block->request) { + $block->set_value("request", "GET /t"); + } + + my $http_config = $block->http_config // <<_EOC_; + server { + server_name openai; + listen 16724; + + default_type 'application/json'; + + location /anything { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body = ngx.req.get_body_data() + + if body ~= "SELECT * FROM STUDENTS" then + ngx.status = 503 + ngx.say("passthrough doesn't work") + return + end + ngx.say('{"foo", "bar"}') + } + } + + location /v1/chat/completions { + content_by_lua_block { + local json = require("cjson.safe") + + if ngx.req.get_method() ~= "POST" then + ngx.status = 400 + ngx.say("Unsupported request method: ", ngx.req.get_method()) + end + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local test_type = ngx.req.get_headers()["test-type"] + if test_type == "options" then + if body.foo == "bar" then + ngx.status = 200 + ngx.say("options works") + else + ngx.status = 500 + ngx.say("model options feature doesn't work") + end + return + end + + local header_auth = ngx.req.get_headers()["authorization"] + local query_auth = ngx.req.get_uri_args()["apikey"] + + if header_auth ~= "Bearer token" and query_auth ~= "apikey" then + ngx.status = 401 + ngx.say("Unauthorized") + return + end + + if header_auth == "Bearer token" or query_auth == "apikey" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if not body.messages or #body.messages < 1 then + ngx.status = 400 + ngx.say([[{ "error": "bad request"}]]) + return + end + + if body.messages[1].content == "write an SQL query to get all rows from student table" then + ngx.print("SELECT * FROM STUDENTS") + return + end + + ngx.status = 200 + ngx.say(string.format([[ +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { "content": "1 + 1 = 2.", "role": "assistant" } + } + ], + "created": 1723780938, + "id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P", + "model": "%s", + "object": "chat.completion", + "system_fingerprint": "fp_abc28019ad", + "usage": { "completion_tokens": 5, "prompt_tokens": 8, "total_tokens": 10 } +} + ]], body.model)) + return + end + + + ngx.status = 503 + ngx.say("reached the end of the test suite") + } + } + + location /random { + content_by_lua_block { + ngx.say("path override works") + } + } + + location ~ ^/status.* { + content_by_lua_block { + local test_dict = ngx.shared["test"] + local uri = ngx.var.uri + local total_key = uri .. "#total" + local count_key = uri .. "#count" + local total = test_dict:get(total_key) + if not total then + return + end + + local count = test_dict:incr(count_key, 1, 0) + ngx.log(ngx.INFO, "uri: ", uri, " total: ", total, " count: ", count) + if count < total then + return + end + ngx.status = 500 + ngx.say("error") + } + } + + location /error { + content_by_lua_block { + ngx.status = 500 + ngx.say("error") + } + } + } +_EOC_ + + $block->set_value("http_config", $http_config); +}); + +run_tests(); + +__DATA__ + +=== TEST 1: set route, only one instance has checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/gpt4", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": {"header": {"Authorization": "Bearer token"}}, + "options": {"model": "gpt-3"}, + "override": {"endpoint": "http://localhost:16724"} + } + ], + "ssl_verify": false + } + } + }]] + ) + + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 2: once instance changes from unhealthy to healthy +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + + -- set the instance to unhealthy + test_dict:set("/status/gpt4#total", 0) + -- trigger the health check + send_request() + ngx.sleep(1) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + + -- set the instance to healthy + test_dict:set("/status/gpt4#total", 30) + ngx.sleep(1) + + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + + local v = instances_count["gpt-4"] - instances_count["gpt-3"] + assert(v <= 2, "difference between gpt-4 and gpt-3 should be less than 2") + ngx.say("passed") + } + } +--- timeout: 20 +--- response_body +passed + + + +=== TEST 3: set service, only one instance has checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/services/1', + ngx.HTTP_PUT, + [[{ + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/gpt4", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": {"header": {"Authorization": "Bearer token"}}, + "options": {"model": "gpt-3"}, + "override": {"endpoint": "http://localhost:16724"} + } + ], + "ssl_verify": false + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 4: set route 1 related to service 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "service_id": 1 + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 5: instance changes from unhealthy to healthy +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + -- set the instance to unhealthy + test_dict:set("/status/gpt4#total", 0) + -- trigger the health check + send_request() + ngx.sleep(2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + -- set the instance to healthy + test_dict:set("/status/gpt4#total", 30) + ngx.sleep(2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + local diff = instances_count["gpt-4"] - instances_count["gpt-3"] + assert(diff <= 2, "difference between gpt-4 and gpt-3 should be less than 2") + ngx.say("passed") + } + } +--- timeout: 20 +--- response_body +passed + + + +=== TEST 6: set route, two instances have checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local checks_tmp = [[ + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/%s", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + ]] + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt4").. [[ + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt3") .. [[ + } + ], + "ssl_verify": false + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 7: healthy conversion of two instances +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + -- set the gpt4 instance to unhealthy + -- set the gpt3 instance to healthy + test_dict:set("/status/gpt4#total", 0) + test_dict:set("/status/gpt3#total", 50) + -- trigger the health check + send_request() + ngx.sleep(2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + -- set the gpt4 instance to healthy + -- set the gpt3 instance to unhealthy + test_dict:set("/status/gpt4#total", 50) + test_dict:set("/status/gpt3#total", 0) + ngx.sleep(2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] >= 8, "gpt-4 should be healthy") + assert(instances_count["gpt-3"] <= 2, "gpt-3 should be unhealthy") + ngx.say("passed") + } + } +--- timeout: 20 +--- response_body +passed + + + +=== TEST 8: set route, two instances have checker +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local checks_tmp = [[ + "checks": { + "active": { + "timeout": 5, + "http_path": "/status/%s", + "host": "foo.com", + "healthy": { + "interval": 1, + "successes": 1 + }, + "unhealthy": { + "interval": 1, + "http_failures": 1 + }, + "req_headers": ["User-Agent: curl/7.29.0"] + } + } + ]] + local code, body = t('/apisix/admin/services/1', + ngx.HTTP_PUT, + [[{ + "plugins": { + "ai-proxy-multi": { + "fallback_strategy": "instance_health_and_rate_limiting", + "instances": [ + { + "name": "openai-gpt4", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-4" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt4").. [[ + }, + { + "name": "openai-gpt3", + "provider": "openai", + "weight": 1, + "priority": 1, + "auth": { + "header": { + "Authorization": "Bearer token" + } + }, + "options": { + "model": "gpt-3" + }, + "override": { + "endpoint": "http://localhost:16724" + }, + ]] .. string.format(checks_tmp, "gpt3") .. [[ + } + ], + "ssl_verify": false + } + } + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 9: set route 1 related to service 1 +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local code, body = t('/apisix/admin/routes/1', + ngx.HTTP_PUT, + [[{ + "uri": "/ai", + "service_id": 1 + }]] + ) + if code >= 300 then + ngx.status = code + end + ngx.say(body) + } + } +--- response_body +passed + + + +=== TEST 10: healthy conversion of two instances +--- config + location /t { + content_by_lua_block { + local t = require("lib.test_admin").test + local core = require("apisix.core") + local test_dict = ngx.shared["test"] + local send_request = function() + local code, _, body = t("/ai", + ngx.HTTP_POST, + [[{ + "messages": [ + { "role": "system", "content": "You are a mathematician" }, + { "role": "user", "content": "What is 1+1?" } + ] + }]], + nil, + { + ["test-type"] = "options", + ["Content-Type"] = "application/json", + } + ) + assert(code == 200, "request should be successful") + return body + end + -- set the gpt4 instance to unhealthy + -- set the gpt3 instance to healthy + test_dict:set("/status/gpt4#total", 0) + test_dict:set("/status/gpt3#total", 50) + -- trigger the health check + send_request() + ngx.sleep(1.2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] <= 2, "gpt-4 should be unhealthy") + assert(instances_count["gpt-3"] >= 8, "gpt-3 should be healthy") + -- set the gpt4 instance to healthy + -- set the gpt3 instance to unhealthy + test_dict:set("/status/gpt4#total", 50) + test_dict:set("/status/gpt3#total", 0) + ngx.sleep(1.2) + local instances_count = { + ["gpt-4"] = 0, + ["gpt-3"] = 0, + } + for i = 1, 10 do + local resp = send_request() + if core.string.find(resp, "gpt-4") then + instances_count["gpt-4"] = instances_count["gpt-4"] + 1 + else + instances_count["gpt-3"] = instances_count["gpt-3"] + 1 + end + if i == 1 then + ngx.sleep(4) -- trigger healthcheck + end + end + ngx.log(ngx.INFO, "instances_count test:", core.json.delay_encode(instances_count)) + assert(instances_count["gpt-4"] >= 8, "gpt-4 should be healthy") + assert(instances_count["gpt-3"] <= 2, "gpt-3 should be unhealthy") + ngx.say("passed") + } + } +--- timeout: 20 +--- response_body +passed