This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git
The following commit(s) were added to refs/heads/develop by this push:
new 951234dc feat(ai-gateway): enhance API key handling and health check
logic for endpoints (#731)
951234dc is described below
commit 951234dc9a553788caa1130659fc29c6c2a29bad
Author: Xuetao Li <[email protected]>
AuthorDate: Sun Sep 21 22:53:50 2025 +0800
feat(ai-gateway): enhance API key handling and health check logic for
endpoints (#731)
* feat(ai-gateway): enhance API key handling and health check logic for
endpoints
* fix copilot
---
cmd/pixiu/pixiu.go | 8 ++------
pkg/common/constant/http.go | 3 +++
pkg/common/constant/pixiu.go | 7 +++++++
pkg/filter/llm/proxy/filter.go | 42 +++++++++++++++++++++++++++++++++++-------
pkg/model/llm.go | 14 +++++---------
5 files changed, 52 insertions(+), 22 deletions(-)
diff --git a/cmd/pixiu/pixiu.go b/cmd/pixiu/pixiu.go
index d07c7a02..c5a4a902 100644
--- a/cmd/pixiu/pixiu.go
+++ b/cmd/pixiu/pixiu.go
@@ -29,14 +29,10 @@ import (
import (
"github.com/apache/dubbo-go-pixiu/pkg/cmd"
+ "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
_ "github.com/apache/dubbo-go-pixiu/pkg/pluginregistry"
)
-const (
- // Version pixiu version
- Version = "1.0.0"
-)
-
// main pixiu run method
func main() {
app := getRootCmd()
@@ -53,7 +49,7 @@ func getRootCmd() *cobra.Command {
"services. It supports HTTP-to-Dubbo and HTTP-to-HTTP
proxy and more protocols will be supported in the near \n" +
"future. \n" +
"(c) " + strconv.Itoa(time.Now().Year()) + " Dubbogo",
- Version: Version,
+ Version: constant.Version,
}
rootCmd.AddCommand(cmd.GatewayCmd)
diff --git a/pkg/common/constant/http.go b/pkg/common/constant/http.go
index 767d49c0..bbb98dd4 100644
--- a/pkg/common/constant/http.go
+++ b/pkg/common/constant/http.go
@@ -24,6 +24,7 @@ const (
HeaderKeyTransferEncoding = "Transfer-Encoding"
HeaderKeyContentLength = "Content-Length"
HeaderKeyContentEncoding = "Content-Encoding"
+ HeaderKeyUserAgent = "User-Agent"
HeaderKeyAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HeaderKeyAccessControlAllowHeaders = "Access-Control-Allow-Headers"
@@ -48,6 +49,8 @@ const (
HeaderValueKeepAlive = "keep-alive"
HeaderValueNoCache = "no-cache"
+ HeaderValueAuthorization = "Authorization"
+
HeaderValueAll = "*"
PathSlash = "/"
diff --git a/pkg/common/constant/pixiu.go b/pkg/common/constant/pixiu.go
index f78d64df..7b3f8409 100644
--- a/pkg/common/constant/pixiu.go
+++ b/pkg/common/constant/pixiu.go
@@ -21,6 +21,13 @@ import (
"time"
)
+const (
+ // Name pixiu name
+ Name = "dubbo-go-pixiu"
+ // Version pixiu version
+ Version = "1.0.0"
+)
+
// default timeout 1s.
const (
DefaultTimeoutStr = "1s"
diff --git a/pkg/filter/llm/proxy/filter.go b/pkg/filter/llm/proxy/filter.go
index c6be8956..7edc8bdc 100644
--- a/pkg/filter/llm/proxy/filter.go
+++ b/pkg/filter/llm/proxy/filter.go
@@ -41,7 +41,10 @@ import (
)
const (
- Kind = constant.LLMProxyFilter
+ Kind = constant.LLMProxyFilter
+ APIKeyPrefix = "Bearer"
+ LLMUnhealthyKey = "LLMUnhealthy"
+ HealthyCheckTimeKey = "HealthyCheckTime"
// Context key to pass attempt data from proxy to downstream filters
LLMUpstreamAttemptsKey = "llm_upstream_attempts"
)
@@ -236,6 +239,12 @@ func (f *Filter) assembleRequest(endpoint *model.Endpoint,
r *http.Request) (*ht
// Copy headers from original request
req.Header = r.Header
+ // replace the header value with the endpoint's api key
+ if apiKey := endpoint.LLMMeta.APIKey; apiKey != "" {
+ req.Header.Set(constant.HeaderValueAuthorization,
fmt.Sprintf("%s %s", APIKeyPrefix, apiKey))
+ }
+ req.Header.Set(constant.HeaderKeyUserAgent, fmt.Sprintf("%s %s",
constant.Name, constant.Version))
+
return req, nil
}
@@ -250,24 +259,41 @@ func (s *Strategy) Execute(executor *RequestExecutor)
(*http.Response, error) {
attempts []UpstreamAttempt
)
- // 1. Pick initial endpoint from the cluster
+ // 1. Pick initial endpoint from the cluster based on load balancing.
endpoint := executor.clusterManager.PickEndpoint(executor.clusterName,
executor.hc)
// 2. The main fallback loop. It continues as long as we have a valid
endpoint to try.
for endpoint != nil {
logger.Debugf("[dubbo-go-pixiu] client attempting endpoint [%s:
%v]", endpoint.ID, endpoint.Address.GetAddress())
- // 3. Dynamically load the retry policy for the current endpoint
+ // 3. Check the health of the current endpoint.
+ if unhealthy, ok := endpoint.Metadata[LLMUnhealthyKey]; ok &&
unhealthy == "true" {
+ // check the health cooldown time
+ if t, ok := endpoint.Metadata[HealthyCheckTimeKey]; ok {
+ lt, err := time.Parse(time.RFC3339, t)
+ if err == nil && time.Since(lt) <
time.Millisecond*time.Duration(endpoint.LLMMeta.HealthCheckInterval) {
+ logger.Debugf("[dubbo-go-pixiu]
endpoint [%s: %v] is still in unhealthy cooldown period. Skipping to next
endpoint.", endpoint.ID, endpoint.Address.GetAddress())
+ endpoint =
getNextFallbackEndpoint(endpoint, executor)
+ continue
+ }
+ // The cooldown period has passed; the endpoint is ready for a
new attempt
+ delete(endpoint.Metadata, LLMUnhealthyKey)
+ delete(endpoint.Metadata, HealthyCheckTimeKey)
+ logger.Debugf("[dubbo-go-pixiu] endpoint [%s:
%v] cooldown period passed. Retrying this endpoint.", endpoint.ID,
endpoint.Address.GetAddress())
+ }
+ }
+
+ // 4. Dynamically load the retry policy for the current endpoint
var retryPolicy retry.RetryPolicy
retryPolicy, err = retry.GetRetryPolicy(endpoint)
if err != nil {
- logger.Errorf("could not load retry policy for endpoint
%s: %v. Skipping to next endpoint.", endpoint.ID, err)
+ logger.Errorf("could not load retry policy for endpoint
[%s: %v]. Skipping to next endpoint.", endpoint.ID, err)
endpoint = getNextFallbackEndpoint(endpoint, executor)
continue
}
retryPolicy.Reset()
- // 4. The retry loop for the current endpoint.
+ // 5. The retry loop for the current endpoint.
for retryPolicy.Attempt() {
var req *http.Request
req, err = executor.filter.assembleRequest(endpoint,
executor.hc.Request)
@@ -309,14 +335,16 @@ func (s *Strategy) Execute(executor *RequestExecutor)
(*http.Response, error) {
endpoint.ID, endpoint.Address.GetAddress(),
err, resp.Status)
}
- // 5. If we are here, all retries for the current endpoint are
exhausted.
+ // 6. If we are here, all retries for the current endpoint are
exhausted.
// Get the next endpoint for fallback. The loop will terminate
if it's nil.
+ endpoint.Metadata[LLMUnhealthyKey] = "true"
+ endpoint.Metadata[HealthyCheckTimeKey] =
time.Now().Format(time.RFC3339)
endpoint = getNextFallbackEndpoint(endpoint, executor)
}
+ // 7. If we've exited the loop, all attempts and fallbacks have failed.
executor.hc.Params[LLMUpstreamAttemptsKey] = attempts
- // 6. If we've exited the loop, all attempts and fallbacks have failed.
// Return the last known error and response.
if err == nil && resp != nil {
err = fmt.Errorf("request failed with status code %d after all
retries and fallbacks", resp.StatusCode)
diff --git a/pkg/model/llm.go b/pkg/model/llm.go
index 7b0889a1..d229f398 100644
--- a/pkg/model/llm.go
+++ b/pkg/model/llm.go
@@ -32,15 +32,11 @@ import (
type (
// LLMMeta LLM metadata for llm call
LLMMeta struct {
- Provider string `yaml:"provider" json:"provider"`
// Provider the cluster unique name
- APIKeys []LLMAPIKey `yaml:"api_keys" json:"api_keys"
mapstructure:"api_keys"` // APIKey the cluster unique name
- RetryPolicy RetryPolicy `yaml:"retry_policy"
json:"retry_policy" mapstructure:"retry_policy"` // RetryPolicy key
- Fallback bool `yaml:"fallback" json:"fallback"
mapstructure:"fallback"` // Fallback to the next provider if failed
- }
-
- LLMAPIKey struct {
- Name string `yaml:"name" json:"name"` // Name of the api key
- Key string `yaml:"key" json:"key"` // Real Key
+ Provider string `yaml:"provider"
json:"provider"`
// Provider the cluster unique name
+ APIKey string `yaml:"api_key" json:"api_key"
mapstructure:"api_key"`
// APIKey the API key sent upstream as the "Authorization: Bearer <key>" header
+ RetryPolicy RetryPolicy `yaml:"retry_policy"
json:"retry_policy" mapstructure:"retry_policy"`
// RetryPolicy key
+ Fallback bool `yaml:"fallback"
json:"fallback" mapstructure:"fallback"`
// Fallback to the next provider if failed
+ HealthCheckInterval int64 `yaml:"health_check_interval"
json:"health_check_interval" mapstructure:"health_check_interval"
default:"5000"` // HealthCheckInterval the interval for health check
}
LLMProviderDomains struct {