This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git
The following commit(s) were added to refs/heads/develop by this push:
new 510ea58c [AI] feat: Implement LLM proxy filter with retry and fallback
mechanisms (#685)
510ea58c is described below
commit 510ea58c5fde255ac96eb08c41228693c4c0d609
Author: Xuetao Li <[email protected]>
AuthorDate: Wed Jul 2 09:49:10 2025 +0800
[AI] feat: Implement LLM proxy filter with retry and fallback mechanisms
(#685)
* feat: Implement LLM proxy filter with retry and fallback mechanisms
* fix: Ensure original request body is closed properly using defer
* fix: Ensure original request body is closed properly using defer
* delete debug
* add docs
* Update pkg/filter/llm/proxy/filter.go
Co-authored-by: Copilot <[email protected]>
* use hc.Request.GetBody
* update
---------
Co-authored-by: Copilot <[email protected]>
---
pkg/common/constant/key.go | 1 +
pkg/common/http/manager.go | 2 +-
pkg/common/util/response.go | 5 +
pkg/filter/llm/proxy/filter.go | 235 +++++++++++++++++++++++++++++++++++++++++
pkg/model/llm.go | 6 +-
pkg/pluginregistry/registry.go | 1 +
pkg/server/cluster_manager.go | 37 ++++++-
7 files changed, 283 insertions(+), 4 deletions(-)
diff --git a/pkg/common/constant/key.go b/pkg/common/constant/key.go
index 762465f5..11d6cf63 100644
--- a/pkg/common/constant/key.go
+++ b/pkg/common/constant/key.go
@@ -52,6 +52,7 @@ const (
DubboHttpFilter = "dgp.filter.dubbo.http"
DubboProxyFilter = "dgp.filter.dubbo.proxy"
+ LLMProxyFilter = "dgp.filter.llm.proxy"
LLMTokenizerFilter = "dgp.filter.llm.tokenizer"
)
diff --git a/pkg/common/http/manager.go b/pkg/common/http/manager.go
index ee9a5833..8555483b 100644
--- a/pkg/common/http/manager.go
+++ b/pkg/common/http/manager.go
@@ -100,7 +100,7 @@ func (hcm *HttpConnectionManager) handleHTTPRequest(c
*pch.HttpContext) {
// recover any err when filterChain run
defer func() {
if err := recover(); err != nil {
- logger.Warnf("[dubbopixiu go] Occur An Unexpected Err:
%+v", err)
+ logger.Warnf("[dubbo-go-pixiu] Occur An Unexpected Err:
%+v", err)
c.SendLocalReply(stdHttp.StatusInternalServerError,
[]byte(fmt.Sprintf("Occur An Unexpected Err: %v", err)))
}
}()
diff --git a/pkg/common/util/response.go b/pkg/common/util/response.go
index 347632e6..25394970 100644
--- a/pkg/common/util/response.go
+++ b/pkg/common/util/response.go
@@ -180,3 +180,8 @@ func struct2Map(obj any) map[string]any {
}
return data
}
+
+// IsHTTPRespSuccessful checks if the HTTP response status code indicates
success (2xx).
+func IsHTTPRespSuccessful(statusCode int) bool {
+ return statusCode >= 200 && statusCode < 300
+}
diff --git a/pkg/filter/llm/proxy/filter.go b/pkg/filter/llm/proxy/filter.go
new file mode 100644
index 00000000..0b89d0d5
--- /dev/null
+++ b/pkg/filter/llm/proxy/filter.go
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package proxy
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+import (
+ "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+ "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
+ "github.com/apache/dubbo-go-pixiu/pkg/common/util"
+ contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+ "github.com/apache/dubbo-go-pixiu/pkg/logger"
+ "github.com/apache/dubbo-go-pixiu/pkg/model"
+ "github.com/apache/dubbo-go-pixiu/pkg/server"
+)
+
+const (
+ Kind = constant.LLMProxyFilter
+)
+
+func init() {
+ filter.RegisterHttpFilter(&Plugin{})
+}
+
+type (
+ // Plugin is http filter plugin.
+ Plugin struct {
+ }
+ // FilterFactory is the http filter factory that creates Filter instances
+ FilterFactory struct {
+ cfg *Config
+ client http.Client
+ }
+ Filter struct {
+ client http.Client
+ scheme string
+ }
+ // Config describes the config of FilterFactory
+ Config struct {
+ Timeout time.Duration `yaml:"timeout"
json:"timeout,omitempty"`
+ MaxIdleConns int `yaml:"maxIdleConns"
json:"maxIdleConns,omitempty"`
+ MaxIdleConnsPerHost int `yaml:"maxIdleConnsPerHost"
json:"maxIdleConnsPerHost,omitempty"`
+ MaxConnsPerHost int `yaml:"maxConnsPerHost"
json:"maxConnsPerHost,omitempty"`
+ Scheme string `yaml:"scheme"
json:"scheme,omitempty" default:"http"`
+ }
+)
+
+func (p *Plugin) Kind() string {
+ return Kind
+}
+
+func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
+ return &FilterFactory{cfg: &Config{}}, nil
+}
+
+func (factory *FilterFactory) Config() any {
+ return factory.cfg
+}
+
+func (factory *FilterFactory) Apply() error {
+ scheme := strings.TrimSpace(strings.ToLower(factory.cfg.Scheme))
+
+ if scheme != "http" && scheme != "https" {
+ return fmt.Errorf("%s: scheme must be http or https", Kind)
+ }
+
+ factory.cfg.Scheme = scheme
+
+ cfg := factory.cfg
+ client := http.Client{
+ Timeout: cfg.Timeout,
+ Transport: http.RoundTripper(&http.Transport{
+ MaxIdleConns: cfg.MaxIdleConns,
+ MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+ MaxConnsPerHost: cfg.MaxConnsPerHost,
+ }),
+ }
+ factory.client = client
+ return nil
+}
+
+func (factory *FilterFactory) PrepareFilterChain(ctx *contexthttp.HttpContext,
chain filter.FilterChain) error {
+ //reuse http client
+ f := &Filter{factory.client, factory.cfg.Scheme}
+ chain.AppendDecodeFilters(f)
+ return nil
+}
+
+func (f *Filter) Decode(hc *contexthttp.HttpContext) filter.FilterStatus {
+ rEntry := hc.GetRouteEntry()
+ if rEntry == nil {
+ bt, _ := json.Marshal(contexthttp.ErrResponse{Message: "no
route entry"})
+ hc.SendLocalReply(http.StatusBadRequest, bt)
+ return filter.Stop
+ }
+
+ logger.Debugf("[dubbo-go-pixiu] client choose endpoint from cluster:
%v", rEntry.Cluster)
+
+ var (
+ clusterName = rEntry.Cluster
+ clusterManager = server.GetClusterManager()
+ endpoint = clusterManager.PickEndpoint(clusterName, hc)
+ )
+
+ if endpoint == nil {
+ logger.Debugf("[dubbo-go-pixiu] cluster not found endpoint")
+ bt, _ := json.Marshal(contexthttp.ErrResponse{Message: "cluster
not found endpoint"})
+ hc.SendLocalReply(http.StatusServiceUnavailable, bt)
+ return filter.Stop
+ }
+
+ r := hc.Request
+ defer r.Body.Close()
+
+ var (
+ req *http.Request
+ resp *http.Response
+ err error
+ )
+
+ if hc.Request.Body != nil && hc.Request.GetBody == nil {
+ bodyBytes, err := io.ReadAll(hc.Request.Body)
+ hc.Request.Body.Close()
+
+ if err != nil {
+ bt, _ := json.Marshal(contexthttp.ErrResponse{Message:
fmt.Sprintf("failed to read request body: %v", err)})
+ hc.SendLocalReply(http.StatusInternalServerError, bt)
+ return filter.Stop
+ }
+
+ hc.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ hc.Request.GetBody = func() (io.ReadCloser, error) {
+ return io.NopCloser(bytes.NewReader(bodyBytes)), nil
+ }
+ }
+
+ logger.Debugf("[dubbo-go-pixiu] client choose endpoint [%s: %v]",
endpoint.ID, endpoint.Address.GetAddress())
+
+ // make request
+FALLBACK:
+ for {
+ RETRY:
+ for retry := uint(0); retry <= endpoint.LLMMeta.RetryTimes;
retry++ {
+ req, err = f.assembleRequest(endpoint, r)
+ if err != nil {
+ logger.Warnf("[dubbo-go-pixiu] client assemble
request failed: %v", err)
+ break RETRY
+ }
+
+ resp, err = f.client.Do(req)
+ if err != nil {
+ logger.Warnf("[dubbo-go-pixiu] client call
endpoint [%s: %v] failed: %v", endpoint.ID, endpoint.Address.GetAddress(), err)
+ break RETRY
+ }
+ if util.IsHTTPRespSuccessful(resp.StatusCode) {
+ // If the response is successful, we can break
out of the fallback loop.
+ break FALLBACK
+ }
+ // If the response is not successful, we will retry
the same endpoint until RetryTimes is exhausted.
+ logger.Debugf("[dubbo-go-pixiu] client retry endpoint
[%s: %v]", endpoint.ID, endpoint.Address.GetAddress())
+ }
+
+ if !endpoint.LLMMeta.Fallback {
+ // If fallback is not enabled, we will break out of the
fallback loop.
+ break FALLBACK
+ }
+
+ endpoint = clusterManager.PickNextEndpoint(clusterName,
endpoint.ID)
+ if endpoint == nil {
+ break FALLBACK
+ }
+
+ // If we have a next endpoint, we will retry with the next
endpoint.
+ logger.Debugf("[dubbo-go-pixiu] client fallback to endpoint
[%s: %v]", endpoint.ID, endpoint.Address.GetAddress())
+ }
+
+ if err != nil {
+ var urlErr *url.Error
+ ok := errors.As(err, &urlErr)
+ if ok && urlErr.Timeout() {
+ hc.SendLocalReply(http.StatusGatewayTimeout,
[]byte(err.Error()))
+ return filter.Stop
+ }
+ hc.SendLocalReply(http.StatusServiceUnavailable,
[]byte(err.Error()))
+ return filter.Stop
+ }
+
+ logger.Debugf("[dubbo-go-pixiu] client call resp:%v", resp)
+ hc.SourceResp = resp
+ // the response is written in hcm
+ return filter.Continue
+
+}
+
+func (f *Filter) assembleRequest(endpoint *model.Endpoint, r *http.Request)
(*http.Request, error) {
+ parsedURL := url.URL{
+ Host: endpoint.Address.GetAddress(),
+ Scheme: f.scheme,
+ Path: r.URL.Path,
+ RawQuery: r.URL.RawQuery,
+ }
+
+ req, err := http.NewRequest(r.Method, parsedURL.String(), r.Body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header = r.Header
+
+ return req, nil
+}
diff --git a/pkg/model/llm.go b/pkg/model/llm.go
index d74b1250..1ea4bcaa 100644
--- a/pkg/model/llm.go
+++ b/pkg/model/llm.go
@@ -32,8 +32,10 @@ import (
type (
// LLMMeta LLM metadata for llm call
LLMMeta struct {
- Provider string `yaml:"provider" json:"provider"`
// Provider the cluster unique name
- APIKeys []LLMAPIKey `yaml:"api_keys" json:"api_keys"
mapstructure:"api_keys"` // APIKey the cluster unique name
+ Provider string `yaml:"provider" json:"provider"`
// Provider the cluster unique name
+ APIKeys []LLMAPIKey `yaml:"api_keys" json:"api_keys"
mapstructure:"api_keys"` // APIKeys the API keys configured for the provider
+ RetryTimes uint `yaml:"retry_times" json:"retry_times"
mapstructure:"retry_times" default:"0"` // Retry times for failed call
+ Fallback bool `yaml:"fallback" json:"fallback"
mapstructure:"fallback"` // Fallback to other provider if
failed
}
LLMAPIKey struct {
diff --git a/pkg/pluginregistry/registry.go b/pkg/pluginregistry/registry.go
index 02e3034c..e4094231 100644
--- a/pkg/pluginregistry/registry.go
+++ b/pkg/pluginregistry/registry.go
@@ -40,6 +40,7 @@ import (
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/http/loadbalancer"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/http/proxyrewrite"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/http/remote"
+ _ "github.com/apache/dubbo-go-pixiu/pkg/filter/llm/proxy"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/llm/tokenizer"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/metric"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/network/dubboproxy"
diff --git a/pkg/server/cluster_manager.go b/pkg/server/cluster_manager.go
index acbbebf4..fb0a3a13 100644
--- a/pkg/server/cluster_manager.go
+++ b/pkg/server/cluster_manager.go
@@ -153,13 +153,48 @@ func (cm *ClusterManager) CompareAndSetStore(store
*ClusterStore) bool {
return true
}
+// PickEndpoint picks an endpoint from the cluster by its name and load
balancing policy.
func (cm *ClusterManager) PickEndpoint(clusterName string, policy
model.LbPolicy) *model.Endpoint {
cm.rw.RLock()
defer cm.rw.RUnlock()
+ c := cm.getCluster(clusterName)
+ if c == nil {
+ logger.Warnf("[dubbo-go-pixiu] cluster %s not found",
clusterName)
+ return nil
+ }
+ return cm.pickOneEndpoint(c, policy)
+}
+
+// PickNextEndpoint picks the next endpoint in the cluster after the current
endpoint ID.
+func (cm *ClusterManager) PickNextEndpoint(clusterName string, curEndpointID
string) *model.Endpoint {
+ cm.rw.RLock()
+ defer cm.rw.RUnlock()
+
+ c := cm.getCluster(clusterName)
+ if c == nil {
+ logger.Warnf("[dubbo-go-pixiu] cluster %s not found",
clusterName)
+ return nil
+ }
+
+ for i, endpoint := range c.Endpoints {
+ if endpoint.ID == curEndpointID {
+ // pick next endpoint
+ if i < len(c.Endpoints)-1 {
+ return c.Endpoints[i+1]
+ }
+ return nil // have tried all endpoints
+ }
+ }
+
+ return nil
+}
+
+// getCluster returns the cluster configuration by its name.
+func (cm *ClusterManager) getCluster(clusterName string) *model.ClusterConfig {
for _, c := range cm.store.Config {
if c.Name == clusterName {
- return cm.pickOneEndpoint(c, policy)
+ return c
}
}
return nil