This is an automated email from the ASF dual-hosted git repository.

alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git


The following commit(s) were added to refs/heads/develop by this push:
     new bd840052 [AI] feat: implementation of token calculation filter (#659)
bd840052 is described below

commit bd84005231b12184bfc6400951f628ab87e9ccaa
Author: Alan <[email protected]>
AuthorDate: Sat May 10 18:27:44 2025 +0800

    [AI] feat: implementation of token calculation filter (#659)
    
    * impl: implementation of token calculation filter
    
    * go mod tidy
    
    * change token calc logic
    
    * register filter
    
    * feat: add HTTP/HTTPS scheme in http filter, enabling the http filter to make https calls
    
    * change config validation location
    
    * revert
    
    * impl: implementation of token calculation filter
    
    * go mod tidy
    
    * change token calc logic
    
    * register filter
    
    * tokenizer calculation
    
    * update
    
    * update
    
    * add comment
    
    * update
    
    * update
    
    * go mod tidy
    
    ---------
    
    Co-authored-by: Xuetao Li <[email protected]>
---
 go.mod                                 |   2 +-
 go.sum                                 |   3 +-
 pkg/common/constant/key.go             |   2 +
 pkg/filter/tokenizer/tokenizer.go      | 220 +++++++++++++++++++++++++++++++++
 pkg/filter/tokenizer/tokenizer_test.go |  86 +++++++++++++
 pkg/pluginregistry/registry.go         |   1 +
 6 files changed, 312 insertions(+), 2 deletions(-)

diff --git a/go.mod b/go.mod
index d7a560e0..8cc62525 100644
--- a/go.mod
+++ b/go.mod
@@ -74,7 +74,7 @@ require (
        github.com/coreos/go-semver v0.3.0 // indirect
        github.com/coreos/go-systemd/v22 v22.3.2 // indirect
        github.com/davecgh/go-spew v1.1.1 // indirect
-       github.com/dlclark/regexp2 v1.7.0 // indirect
+       github.com/dlclark/regexp2 v1.10.0 // indirect
        github.com/dustin/go-humanize v1.0.1 // indirect
        github.com/eapache/go-resiliency v1.7.0 // indirect
        github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect
diff --git a/go.sum b/go.sum
index 298f3ef0..0260750d 100644
--- a/go.sum
+++ b/go.sum
@@ -538,8 +538,9 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
 github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
-github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo=
 github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
+github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
 github.com/dubbo-go-pixiu/pixiu-api v0.1.6-0.20220612115254-d9a176b25b99 h1:UjDxgIEu6DbJVJTrxm5mwC0j54jNao1pkYVlT8X+KgY=
 github.com/dubbo-go-pixiu/pixiu-api v0.1.6-0.20220612115254-d9a176b25b99/go.mod h1:1l+6pDTdEHwCyyyJmfckOAdGp6f5PZ33ZVMgxso9q/U=
 github.com/dubbogo/go-zookeeper v1.0.3/go.mod h1:fn6n2CAEer3novYgk9ULLwAjuV8/g4DdC2ENwRb6E+c=
diff --git a/pkg/common/constant/key.go b/pkg/common/constant/key.go
index 5de033fc..762465f5 100644
--- a/pkg/common/constant/key.go
+++ b/pkg/common/constant/key.go
@@ -51,6 +51,8 @@ const (
 
        DubboHttpFilter  = "dgp.filter.dubbo.http"
        DubboProxyFilter = "dgp.filter.dubbo.proxy"
+
+       LLMTokenizerFilter = "dgp.filter.llm.tokenizer"
 )
 
 const (
diff --git a/pkg/filter/tokenizer/tokenizer.go b/pkg/filter/tokenizer/tokenizer.go
new file mode 100644
index 00000000..b9a9a122
--- /dev/null
+++ b/pkg/filter/tokenizer/tokenizer.go
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package tokenizer
+
+import (
+       "bufio"
+       "encoding/json"
+       "fmt"
+       "io"
+       "strings"
+       "sync"
+)
+
+import (
+       "github.com/apache/dubbo-go-pixiu/pkg/client"
+       "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+       "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
+       "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+       "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
+const (
+       Kind                = constant.LLMTokenizerFilter
+       LoggerFmt           = "[Tokenizer] [DOWNSTREAM] "
+       PromptTokensDetails = "prompt_tokens_details"
+)
+
+func init() {
+       filter.RegisterHttpFilter(&Plugin{})
+}
+
+type (
+       // Plugin is http filter plugin.
+       Plugin struct {
+       }
+       // FilterFactory is http filter instance
+       FilterFactory struct {
+               cfg *Config
+       }
+       // Filter is http filter instance
+       Filter struct {
+               cfg *Config
+       }
+       // Config describe the config of FilterFactory
+       Config struct {
+       }
+)
+
+func (p *Plugin) Kind() string {
+       return Kind
+}
+
+func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
+       return &FilterFactory{cfg: &Config{}}, nil
+}
+
+func (factory *FilterFactory) Config() any {
+       return factory.cfg
+}
+
+func (factory *FilterFactory) Apply() error {
+       return nil
+}
+
+func (factory *FilterFactory) PrepareFilterChain(ctx *http.HttpContext, chain filter.FilterChain) error {
+       f := &Filter{
+               cfg: factory.cfg,
+       }
+       chain.AppendEncodeFilters(f)
+       return nil
+}
+
+func (f *Filter) Encode(hc *http.HttpContext) filter.FilterStatus {
+       switch res := hc.TargetResp.(type) {
+       case *client.StreamResponse:
+               pr, pw := io.Pipe()
+               res.Stream = newTeeReadCloser(res.Stream, pw)
+               go f.processStreamResponse(pr)
+       case *client.UnaryResponse:
+               f.processUsageData(res.Data)
+       default:
+               logger.Warnf(LoggerFmt+"Response type not suitable for token calc: %T", res)
+       }
+
+       return filter.Continue
+}
+
+func (f *Filter) processStreamResponse(stream io.Reader) {
+       scanner := bufio.NewScanner(stream)
+       currentLine := make([]byte, 0, 1024)
+       // Read the stream line by line.
+       // Lines prefixed with "data:" start a new event; an event's payload is a
+       // JSON string that may span several physical lines, so the loop keeps
+       // concatenating lines until the next "data:" prefix (or EOF) and then
+       // hands the assembled payload to processUsageData.
+       for scanner.Scan() {
+               line := scanner.Text()
+               line = strings.TrimSpace(line)
+               if strings.HasPrefix(line, "data:") {
+                       f.processUsageData(currentLine)
+                       currentLine = make([]byte, 0, 1024)
+                       line = strings.TrimPrefix(line, "data:")
+               }
+               currentLine = append(currentLine, line...)
+       }
+       f.processUsageData(currentLine)
+       if err := scanner.Err(); err != nil && err != io.EOF {
+               logger.Errorf(LoggerFmt+"Error reading stream: %v", err)
+       }
+}
+
+func (f *Filter) processUsageData(data []byte) {
+       var dataCont map[string]any
+       err := json.Unmarshal(data, &dataCont)
+       if err != nil {
+               return
+       }
+
+       usage, ok := dataCont["usage"].(map[string]any)
+       if !ok || usage == nil {
+               return
+       }
+
+       // todo: currently we only log the usage, we should export it to metrics
+       f.logUsage(usage)
+}
+
+func (f *Filter) logUsage(usage map[string]any) {
+       for key, value := range usage {
+               if key == PromptTokensDetails {
+                       details, ok := value.(map[string]any)
+                       if !ok {
+                               logger.Warnf(LoggerFmt+PromptTokensDetails+" is not a map, value: %+v", value)
+                               continue
+                       }
+                       for detailKey, detailValue := range details {
+                               logger.Infof(LoggerFmt+"Usage | %s: %v", detailKey, detailValue)
+                       }
+               } else {
+                       logger.Infof(LoggerFmt+"Usage | %s: %v", key, value)
+               }
+       }
+}
+
+type teeReadCloser struct {
+       reader   io.Reader
+       closer   io.Closer
+       writer   io.Writer
+       once     sync.Once
+       closeErr error
+}
+
+func newTeeReadCloser(r io.ReadCloser, w io.Writer) *teeReadCloser {
+       return &teeReadCloser{
+               reader: r,
+               closer: r,
+               writer: w,
+       }
+}
+
+func (t *teeReadCloser) Read(p []byte) (n int, err error) {
+       n, err = t.reader.Read(p)
+       if n <= 0 || err != nil {
+               return
+       }
+       nw, err := t.writer.Write(p[:n])
+       if err != nil {
+               logger.Errorf(LoggerFmt+"Error writing to tee writer: %v", err)
+               return
+       }
+       if nw != n {
+               logger.Errorf(LoggerFmt+"Short write to tee writer: %d/%d", nw, n)
+               //err = fmt.Errorf("short write to tee writer: %d/%d", nw, n)
+       }
+       return n, nil
+}
+
+func (t *teeReadCloser) Close() (err error) {
+       var (
+               closerErr error
+               writerErr error
+       )
+
+       t.once.Do(func() {
+               closerErr = t.closer.Close()
+               if closerErr != nil {
+                       logger.Errorf(LoggerFmt+"Error closing closer: %v", closerErr)
+               }
+
+               if t.writer != nil {
+                       writerCloser, ok := t.writer.(io.Closer)
+                       if ok {
+                               writerErr = writerCloser.Close()
+                               if writerErr != nil {
+                                       logger.Errorf(LoggerFmt+"Error closing writer: %v", writerErr)
+                               }
+                       }
+               }
+       })
+
+       if closerErr != nil || writerErr != nil {
+               err = fmt.Errorf("closing closer error: %w. closing writer error: %w", closerErr, writerErr)
+       }
+       return err
+}
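A note for readers skimming the diff: processStreamResponse above treats the body as OpenAI-style server-sent events, where each "data:" line starts a new JSON event whose payload may continue over several physical lines. The following standalone sketch is illustrative only; it shows the same parsing idea with just the standard library, independent of the filter types in this commit:

    package main

    import (
        "bufio"
        "encoding/json"
        "fmt"
        "strings"
    )

    // flushEvent mirrors processUsageData: skip chunks that are not valid JSON
    // or carry no "usage" object, otherwise report the usage map.
    func flushEvent(buf []byte) {
        var body map[string]any
        if err := json.Unmarshal(buf, &body); err != nil {
            return
        }
        if usage, ok := body["usage"].(map[string]any); ok {
            fmt.Printf("usage: %v\n", usage)
        }
    }

    func main() {
        stream := "data: {\"usage\": {\"prompt_tokens\": 7,\n \"total_tokens\": 39}}\n"
        scanner := bufio.NewScanner(strings.NewReader(stream))
        current := make([]byte, 0, 1024)
        for scanner.Scan() {
            line := strings.TrimSpace(scanner.Text())
            if strings.HasPrefix(line, "data:") {
                flushEvent(current) // a new "data:" line closes the previous event
                current = current[:0]
                line = strings.TrimPrefix(line, "data:")
            }
            current = append(current, line...)
        }
        flushEvent(current) // the last event is only complete at EOF
    }
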
diff --git a/pkg/filter/tokenizer/tokenizer_test.go b/pkg/filter/tokenizer/tokenizer_test.go
new file mode 100644
index 00000000..7881099f
--- /dev/null
+++ b/pkg/filter/tokenizer/tokenizer_test.go
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package tokenizer
+
+import (
+       "bytes"
+       "io"
+       "net/http"
+       "strings"
+       "testing"
+       "time"
+)
+
+import (
+       "github.com/stretchr/testify/assert"
+)
+
+import (
+       "github.com/apache/dubbo-go-pixiu/pkg/client"
+       "github.com/apache/dubbo-go-pixiu/pkg/context/mock"
+)
+
+func TestUnaryResponse(t *testing.T) {
+       filter := &Filter{}
+
+       request, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+       assert.NoError(t, err)
+       c := mock.GetMockHTTPContext(request)
+       c.TargetResp = &client.UnaryResponse{
+               Data: []byte(`{
+               "usage": {
+                       "prompt_tokens": 7,
+                       "completion_tokens": 32,
+                       "total_tokens": 39,
+                       "prompt_tokens_details": {
+                               "cached_tokens": 0
+                       },
+                       "prompt_cache_hit_tokens": 0,
+                       "prompt_cache_miss_tokens": 7
+               }
+       }`)}
+       filter.Encode(c)
+}
+
+func TestStreamResponse(t *testing.T) {
+       filter := &Filter{}
+
+       request, err := http.NewRequest("POST", "http://www.dubbogopixiu.com/mock/test?name=tc", bytes.NewReader([]byte("{\"id\":\"12345\"}")))
+       assert.NoError(t, err)
+       c := mock.GetMockHTTPContext(request)
+       s := io.NopCloser(strings.NewReader(`data: {
+               "usage": {
+                       "prompt_tokens": 7,
+                       "completion_tokens": 32,
+                       "total_tokens": 39,
+                       "prompt_tokens_details": {
+                               "cached_tokens": 0
+                       },
+                       "prompt_cache_hit_tokens": 0,
+                       "prompt_cache_miss_tokens": 7
+               }
+       }
+
+`))
+       c.TargetResp = &client.StreamResponse{Stream: s}
+       filter.Encode(c)
+       buf := make([]byte, 1024)
+       c.TargetResp.(*client.StreamResponse).Stream.Read(buf)
+       time.Sleep(3 * time.Millisecond)
+       c.TargetResp.(*client.StreamResponse).Stream.Close()
+}
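The stream test above only exercises this indirectly, so it may help to spell out the splitting trick in Encode: the upstream body is wrapped so that every downstream Read is mirrored into an io.Pipe, and the pipe's read end feeds processStreamResponse in a goroutine. A minimal, self-contained illustration of that pattern using io.TeeReader (the commit's teeReadCloser additionally propagates Close) follows; it is a sketch, not part of the change:

    package main

    import (
        "bufio"
        "fmt"
        "io"
        "strings"
    )

    func main() {
        upstream := strings.NewReader("data: {\"usage\": {\"total_tokens\": 39}}\n")

        pr, pw := io.Pipe()
        // Every byte the downstream consumer reads is also written into the pipe.
        tee := io.TeeReader(upstream, pw)

        done := make(chan struct{})
        go func() {
            defer close(done)
            // Side consumer: sees the same bytes without stealing them
            // from the downstream copy.
            sc := bufio.NewScanner(pr)
            for sc.Scan() {
                fmt.Println("side channel saw:", sc.Text())
            }
        }()

        // Downstream consumer: reads the body as usual.
        body, _ := io.ReadAll(tee)
        fmt.Println("downstream saw:  ", strings.TrimSpace(string(body)))

        pw.Close() // signal EOF to the side consumer once downstream is done
        <-done
    }
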
diff --git a/pkg/pluginregistry/registry.go b/pkg/pluginregistry/registry.go
index 0209d96f..1dd1d599 100644
--- a/pkg/pluginregistry/registry.go
+++ b/pkg/pluginregistry/registry.go
@@ -47,6 +47,7 @@ import (
        _ "github.com/apache/dubbo-go-pixiu/pkg/filter/network/httpconnectionmanager"
        _ "github.com/apache/dubbo-go-pixiu/pkg/filter/prometheus"
        _ "github.com/apache/dubbo-go-pixiu/pkg/filter/sentinel/ratelimit"
+       _ "github.com/apache/dubbo-go-pixiu/pkg/filter/tokenizer"
        _ "github.com/apache/dubbo-go-pixiu/pkg/filter/tracing"
        _ "github.com/apache/dubbo-go-pixiu/pkg/filter/traffic"
        _ "github.com/apache/dubbo-go-pixiu/pkg/listener/http"
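Finally, the one-line registry change relies on Go's blank-import side effect: linking in the tokenizer package runs its init(), which registers the plugin via filter.RegisterHttpFilter under constant.LLMTokenizerFilter. A generic, single-file sketch of that pattern (the names below are hypothetical, not the Pixiu extension API):

    package main

    import "fmt"

    // A tiny stand-in for an extension registry keyed by filter kind.
    var filters = map[string]string{}

    func register(kind, desc string) { filters[kind] = desc }

    // In the real code this registration happens in the tokenizer package's
    // init(), which runs as soon as the package is linked in through the
    // blank import added to pkg/pluginregistry/registry.go.
    func init() {
        register("dgp.filter.llm.tokenizer", "token usage logging filter")
    }

    func main() {
        // By the time main runs, the filter is already registered.
        fmt.Println(filters)
    }
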
