This is an automated email from the ASF dual-hosted git repository.
xuetaoli pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go-pixiu.git
The following commit(s) were added to refs/heads/develop by this push:
new ea927c12 feat: Add KVCache offload function in PIXIU gateway (#878)
ea927c12 is described below
commit ea927c124c2fc2e274a76aaa40790655937fb101
Author: 陈乐樂 <[email protected]>
AuthorDate: Mon Mar 9 11:20:14 2026 +0800
feat: Add KVCache offload function in PIXIU gateway (#878)
* feat: KV Cache
* fix: fix some problems
* feat: add KVCache feats
* feat: v2
* feat: v2
* chore: ignore .worktrees
* fix: fix some bugs
* feat: v3
* feat: v4
* feat: fix some problems
* fix: delete some unused gomod
* fix: fix some problems
* fix: fix some problems
* add Apache licenses
* fix: fix some problems
* fix: add unit test and fix some problems
* fix: delete some unused code
* fix: add docs
* feat: add some new features
---
configs/ai_kvcache_config.yaml | 37 +++
docs/ai/kvcache.md | 185 ++++++++++++
docs/ai/kvcache_CN.md | 185 ++++++++++++
go.mod | 8 +-
go.sum | 11 +-
pkg/common/constant/key.go | 1 +
pkg/filter/ai/kvcache/circuit_breaker.go | 106 +++++++
pkg/filter/ai/kvcache/config.go | 175 ++++++++++++
pkg/filter/ai/kvcache/filter.go | 142 ++++++++++
pkg/filter/ai/kvcache/handlers.go | 211 ++++++++++++++
pkg/filter/ai/kvcache/handlers_test.go | 129 +++++++++
pkg/filter/ai/kvcache/lmcache_client.go | 185 ++++++++++++
pkg/filter/ai/kvcache/lmcache_client_test.go | 97 +++++++
pkg/filter/ai/kvcache/load_monitor.go | 132 +++++++++
pkg/filter/ai/kvcache/strategy.go | 115 ++++++++
pkg/filter/ai/kvcache/strategy_test.go | 129 +++++++++
pkg/filter/ai/kvcache/test_helpers_test.go | 47 ++++
pkg/filter/ai/kvcache/token_manager.go | 403 +++++++++++++++++++++++++++
pkg/filter/ai/kvcache/token_manager_test.go | 99 +++++++
pkg/filter/ai/kvcache/types.go | 41 +++
pkg/filter/llm/proxy/filter.go | 23 +-
pkg/pluginregistry/registry.go | 1 +
pkg/server/cluster_manager.go | 17 ++
23 files changed, 2469 insertions(+), 10 deletions(-)
diff --git a/configs/ai_kvcache_config.yaml b/configs/ai_kvcache_config.yaml
new file mode 100644
index 00000000..33c932a5
--- /dev/null
+++ b/configs/ai_kvcache_config.yaml
@@ -0,0 +1,37 @@
+http_filters:
+ - name: dgp.filter.ai.kvcache
+ config:
+ enabled: true
+ vllm_endpoint: "http://localhost:8001"
+ lmcache_endpoint: "http://localhost:8080"
+ default_model: "Qwen2.5-3B-Instruct"
+ request_timeout: 5s
+ lookup_routing_timeout: 100ms
+ hot_window: 5m
+ hot_max_records: 500
+ token_cache:
+ enabled: true
+ max_size: 10000
+ ttl: 10m
+ cache_strategy:
+ enable_compression: true
+ enable_pinning: true
+ enable_eviction: true
+ memory_threshold: 0.8
+ hot_content_threshold: 50
+ load_threshold: 0.7
+ compress_method: "zstd"
+ pin_instance_id: "vllm-instance-1"
+ pin_location: "redis"
+ compress_instance_id: "vllm-instance-2"
+ compress_location: "redis"
+ evict_instance_id: "vllm-instance-1"
+
+clusters:
+ - name: vllm_cluster
+ lb_policy: round_robin
+ endpoints:
+ - id: vllm-instance-1
+ socket_address: { address: "localhost", port: 8001 }
+ - id: vllm-instance-2
+ socket_address: { address: "localhost", port: 8002 }
\ No newline at end of file
diff --git a/docs/ai/kvcache.md b/docs/ai/kvcache.md
new file mode 100644
index 00000000..ed3ec280
--- /dev/null
+++ b/docs/ai/kvcache.md
@@ -0,0 +1,185 @@
+## AI KVCache Filter Configuration
+
+English | [中文](./kvcache_CN.md)
+
+This document explains how to configure and use the `dgp.filter.ai.kvcache`
filter in Dubbo-go-Pixiu.
+
+The filter integrates with vLLM (`/tokenize`) and LMCache controller APIs
(`/lookup`, `/pin`, `/compress`, `/evict`) to:
+
+- provide cache-aware routing hints
+- trigger cache-management actions asynchronously
+- keep the main request path non-blocking
+
+---
+
+### Architecture and Request Flow
+
+`dgp.filter.ai.kvcache` is an HTTP decode filter. A typical request flow is:
+
+1. Parse request body and extract `model` and `prompt` (or fallback from
`messages`).
+2. Record local hotness statistics (`model + prompt`) in the token manager.
+3. Try cache-aware routing:
+ - read token cache for prompt
+ - call LMCache `/lookup`
+ - set a preferred endpoint hint in context (`llm_preferred_endpoint_id`)
+4. Start an async cache-management goroutine (best-effort):
+ - call vLLM `/tokenize`
+ - call LMCache `/lookup` if needed
+ - execute strategy decisions (`compress` / `pin` / `evict`)
+5. Continue the filter chain immediately (main request is not blocked by cache
management).
+
+---
+
+### Routing Contract (Important)
+
+Current cache-aware routing uses **instance id matching**:
+
+- The kvcache filter writes `llm_preferred_endpoint_id` into request context.
+- `dgp.filter.llm.proxy` reads this value and tries to select an endpoint by
`endpoint.id`.
+
+So for routing to work:
+
+- `LMCache lookup instance_id` must equal `pixiu cluster endpoint.id`.
+
+If no match exists, the request falls back to normal load-balancing behavior.
+
+Note:
+
+- Current implementation does **not** call LMCache `/query_worker_info`.
+- This means kvcache routing is based on instance-id contract, not dynamic
IP/port discovery.
+
+---
+
+### Configuration Example
+
+```yaml
+listeners:
+ - name: net/http
+ protocol_type: HTTP
+ address:
+ socket_address:
+ address: 0.0.0.0
+ port: 8888
+ filter_chains:
+ filters:
+ - name: dgp.filter.httpconnectionmanager
+ config:
+ route_config:
+ routes:
+ - match:
+ prefix: /
+ route:
+ cluster: vllm_cluster
+ http_filters:
+ - name: dgp.filter.ai.kvcache
+ config:
+ enabled: true
+ vllm_endpoint: "http://127.0.0.1:8000"
+ lmcache_endpoint: "http://127.0.0.1:9000"
+ default_model: "demo"
+ request_timeout: "2s"
+ lookup_routing_timeout: "50ms"
+ hot_window: "5m"
+ hot_max_records: 300
+ token_cache:
+ enabled: true
+ max_size: 1024
+ ttl: "10m"
+ cache_strategy:
+ enable_compression: true
+ enable_pinning: true
+ enable_eviction: true
+ load_threshold: 0.7
+ memory_threshold: 0.85
+ hot_content_threshold: 10
+ pin_instance_id: "vllm-instance-1"
+ pin_location: "LocalCPUBackend"
+ compress_instance_id: "vllm-instance-1"
+ compress_location: "LocalCPUBackend"
+ compress_method: "zstd"
+ evict_instance_id: "vllm-instance-1"
+ retry:
+ max_attempts: 3
+ base_backoff: "100ms"
+ max_backoff: "2s"
+ circuit_breaker:
+ failure_threshold: 5
+ recovery_timeout: "10s"
+ half_open_max_calls: 2
+ - name: dgp.filter.llm.proxy
+```
+
+---
+
+### Key Config Fields
+
+`enabled`
+- Enable/disable kvcache filter.
+
+`vllm_endpoint`
+- Base URL for `/tokenize`.
+
+`lmcache_endpoint`
+- Base URL for LMCache controller APIs.
+
+`lookup_routing_timeout`
+- Short timeout for synchronous routing lookup in decode path.
+
+`token_cache`
+- Local in-memory token cache settings.
+- Cache key is SHA-256 of `model + "\x00" + prompt`.
+
+`cache_strategy.load_threshold`
+- Ratio in `[0,1]`.
+- Current code compares this threshold with measured CPU usage ratio.
+
+`cache_strategy.memory_threshold`
+- Ratio in `[0,1]`, used for eviction decisions.
+
+`cache_strategy.hot_content_threshold`
+- Minimum access count within `hot_window` to mark content as hot for pinning.
+
+`retry`
+- Retry settings for LMCache API calls (`lookup/pin/compress/evict`).
+
+`circuit_breaker`
+- Protection for tokenizer and LMCache client calls.
+
+---
+
+### Operational Notes
+
+1. Best-effort behavior
+
+- If tokenization or LMCache calls fail, main request path still continues.
+- Failures are logged with `[kvcache]` prefix.
+
+2. Context cancellation
+
+- Cache-management work uses request-scoped context and timeout.
+- Client cancel/timeout can stop ongoing background operations.
+
+3. Compression trigger semantics
+
+- `load_threshold` is treated as a ratio.
+- Compression decision currently uses CPU usage ratio check.
+
+4. Real engine vs mock
+
+- Full performance benefit requires real vLLM + LMCache deployment.
+- For smoke validation, you can run mock `/tokenize` and LMCache APIs to
verify chain integration and routing hint behavior.
+
+---
+
+### Validation Checklist
+
+1. `endpoint.id` matches LMCache `instance_id` for targeted routing.
+2. `dgp.filter.ai.kvcache` is placed before `dgp.filter.llm.proxy`.
+3. Logs show:
+ - routing lookup success/failure
+ - strategy action failures (if any)
+4. Compare baseline vs enabled:
+ - hit ratio
+ - p95/p99 latency
+ - upstream load distribution
+
diff --git a/docs/ai/kvcache_CN.md b/docs/ai/kvcache_CN.md
new file mode 100644
index 00000000..bd7701c2
--- /dev/null
+++ b/docs/ai/kvcache_CN.md
@@ -0,0 +1,185 @@
+## AI KVCache 过滤器配置
+
+[English](./kvcache.md) | 中文
+
+本文档说明如何在 Dubbo-go-Pixiu 中配置和使用 `dgp.filter.ai.kvcache` 过滤器。
+
+该过滤器通过对接 vLLM(`/tokenize`)与 LMCache controller
API(`/lookup`、`/pin`、`/compress`、`/evict`),实现:
+
+- cache 感知的路由提示
+- 异步缓存管理动作触发
+- 主请求链路非阻塞
+
+---
+
+### 架构与请求链路
+
+`dgp.filter.ai.kvcache` 是一个 HTTP Decode 过滤器。典型处理流程如下:
+
+1. 读取请求体,提取 `model` 与 `prompt`(必要时从 `messages` 回退提取)。
+2. 在 TokenManager 中记录本地热点统计(`model + prompt`)。
+3. 尝试 cache 感知路由:
+ - 从本地 token cache 获取 token
+ - 调用 LMCache `/lookup`
+ - 把优选端点提示写入上下文(`llm_preferred_endpoint_id`)
+4. 启动异步缓存管理协程(best-effort):
+ - 调用 vLLM `/tokenize`
+ - 必要时调用 LMCache `/lookup`
+ - 按策略执行 `compress` / `pin` / `evict`
+5. 立即放行后续过滤器链(主请求不被缓存管理阻塞)。
+
+---
+
+### 路由约定(重点)
+
+当前 cache 路由依赖 **instance id 对齐**:
+
+- kvcache 过滤器在上下文写入 `llm_preferred_endpoint_id`
+- `dgp.filter.llm.proxy` 读取该值并按 `endpoint.id` 选目标实例
+
+因此要生效必须满足:
+
+- `LMCache lookup 返回的 instance_id` 与 `pixiu cluster endpoint.id` 一致
+
+如果不一致,请求会自动回退到正常负载均衡。
+
+说明:
+
+- 当前实现 **不会** 调 LMCache `/query_worker_info`
+- 也就是路由依据是 instance-id 合约,不是动态查询 IP/Port
+
+---
+
+### 配置示例
+
+```yaml
+listeners:
+ - name: net/http
+ protocol_type: HTTP
+ address:
+ socket_address:
+ address: 0.0.0.0
+ port: 8888
+ filter_chains:
+ filters:
+ - name: dgp.filter.httpconnectionmanager
+ config:
+ route_config:
+ routes:
+ - match:
+ prefix: /
+ route:
+ cluster: vllm_cluster
+ http_filters:
+ - name: dgp.filter.ai.kvcache
+ config:
+ enabled: true
+ vllm_endpoint: "http://127.0.0.1:8000"
+ lmcache_endpoint: "http://127.0.0.1:9000"
+ default_model: "demo"
+ request_timeout: "2s"
+ lookup_routing_timeout: "50ms"
+ hot_window: "5m"
+ hot_max_records: 300
+ token_cache:
+ enabled: true
+ max_size: 1024
+ ttl: "10m"
+ cache_strategy:
+ enable_compression: true
+ enable_pinning: true
+ enable_eviction: true
+ load_threshold: 0.7
+ memory_threshold: 0.85
+ hot_content_threshold: 10
+ pin_instance_id: "vllm-instance-1"
+ pin_location: "LocalCPUBackend"
+ compress_instance_id: "vllm-instance-1"
+ compress_location: "LocalCPUBackend"
+ compress_method: "zstd"
+ evict_instance_id: "vllm-instance-1"
+ retry:
+ max_attempts: 3
+ base_backoff: "100ms"
+ max_backoff: "2s"
+ circuit_breaker:
+ failure_threshold: 5
+ recovery_timeout: "10s"
+ half_open_max_calls: 2
+ - name: dgp.filter.llm.proxy
+```
+
+---
+
+### 关键配置项说明
+
+`enabled`
+- 开关,是否启用 kvcache 过滤器。
+
+`vllm_endpoint`
+- `/tokenize` 的上游地址。
+
+`lmcache_endpoint`
+- LMCache controller API 的上游地址。
+
+`lookup_routing_timeout`
+- Decode 路径里同步 lookup 的短超时时间。
+
+`token_cache`
+- 本地内存 token 缓存配置。
+- 当前 cache key 为 `model + "\x00" + prompt` 的 SHA-256。
+
+`cache_strategy.load_threshold`
+- 比例值,范围 `[0,1]`。
+- 当前代码里用于和 CPU 使用率比例比较,决定是否触发压缩。
+
+`cache_strategy.memory_threshold`
+- 比例值,范围 `[0,1]`,用于驱逐决策。
+
+`cache_strategy.hot_content_threshold`
+- 在 `hot_window` 内达到该访问次数后,判定为热点内容,用于 pin。
+
+`retry`
+- LMCache API(`lookup/pin/compress/evict`)调用的重试参数。
+
+`circuit_breaker`
+- tokenizer 与 LMCache 调用的熔断保护参数。
+
+---
+
+### 运行与行为说明
+
+1. Best-effort 行为
+
+- tokenize 或 LMCache 调用失败时,主请求链路仍会继续。
+- 错误日志统一使用 `[kvcache]` 前缀。
+
+2. 请求取消传播
+
+- 缓存管理使用 request-scoped context + timeout。
+- 客户端取消/超时可以及时中断后台缓存任务。
+
+3. 压缩触发语义
+
+- `load_threshold` 按比例处理。
+- 当前压缩决策依据 CPU 使用率比例。
+
+4. 真实引擎与 mock
+
+- 真正收益评估需要真实 vLLM + LMCache 实例部署。
+- 联调阶段可先用 mock `/tokenize` 与 mock LMCache API 验证链路接入和路由提示是否正确。
+
+---
+
+### 验证清单
+
+1. `endpoint.id` 与 LMCache `instance_id` 对齐。
+2. `dgp.filter.ai.kvcache` 放在 `dgp.filter.llm.proxy` 之前。
+3. 日志中可观测到:
+ - routing lookup 成功/失败
+ - strategy 执行失败(若有)
+4. 对比开启前后:
+ - 命中率
+ - p95/p99 延迟
+ - 上游实例流量分布
+
diff --git a/go.mod b/go.mod
index 4ffef83f..5346503a 100644
--- a/go.mod
+++ b/go.mod
@@ -21,6 +21,7 @@ require (
github.com/gin-gonic/gin v1.10.1
github.com/go-errors/errors v1.0.1
github.com/go-playground/assert/v2 v2.2.0
+ github.com/go-resty/resty/v2 v2.7.0
github.com/go-sql-driver/mysql v1.9.2
github.com/goinggo/mapstructure v0.0.0-20140717182941-194205d9b4a9
github.com/golang-jwt/jwt/v4 v4.5.2
@@ -39,7 +40,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.23.0
github.com/prometheus/common v0.65.0
- github.com/smartystreets/goconvey v1.7.2
+ github.com/smartystreets/goconvey v1.8.1
github.com/spf13/cast v1.7.1
github.com/spf13/cobra v1.5.0
github.com/spf13/viper v1.8.1
@@ -146,7 +147,6 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
- github.com/go-resty/resty/v2 v2.7.0 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/gobwas/glob v0.2.3 // indirect
github.com/goccy/go-json v0.10.3 // indirect
@@ -154,7 +154,7 @@ require (
github.com/golang/snappy v0.0.4 // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
github.com/google/uuid v1.6.0 // indirect
- github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 //
indirect
+ github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
@@ -219,7 +219,7 @@ require (
github.com/segmentio/asm v1.2.1 // indirect
github.com/shirou/gopsutil/v3 v3.22.2 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
- github.com/smartystreets/assertions v1.2.0 // indirect
+ github.com/smarty/assertions v1.15.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.10.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
diff --git a/go.sum b/go.sum
index 7e5adc62..adb0a94f 100644
--- a/go.sum
+++ b/go.sum
@@ -552,8 +552,9 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod
h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
github.com/googleapis/gax-go/v2 v2.0.5/go.mod
h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing
v0.0.0-20200911160855-bcd43fbb19e8/go.mod
h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod
h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
-github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00
h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod
h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
+github.com/gopherjs/gopherjs v1.17.2
h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
+github.com/gopherjs/gopherjs v1.17.2/go.mod
h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1/go.mod
h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2/go.mod
h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.7.3/go.mod
h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@@ -973,14 +974,14 @@ github.com/sirupsen/logrus v1.6.0/go.mod
h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf
github.com/sirupsen/logrus v1.7.0/go.mod
h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.0
h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod
h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/smarty/assertions v1.15.0
h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
+github.com/smarty/assertions v1.15.0/go.mod
h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod
h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/assertions v1.1.0/go.mod
h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
-github.com/smartystreets/assertions v1.2.0
h1:42S6lae5dvLc7BrLu/0ugRtcFVjoJNMC/N3yZFZkDFs=
-github.com/smartystreets/assertions v1.2.0/go.mod
h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod
h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.6.4/go.mod
h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
-github.com/smartystreets/goconvey v1.7.2
h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
-github.com/smartystreets/goconvey v1.7.2/go.mod
h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
+github.com/smartystreets/goconvey v1.8.1
h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
+github.com/smartystreets/goconvey v1.8.1/go.mod
h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
github.com/soheilhy/cmux v0.1.4/go.mod
h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5
h1:GJTW+uNMIV1RKwox+T4aN0/sQlYRg78uHZf2H0aBcDw=
github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5/go.mod
h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0=
diff --git a/pkg/common/constant/key.go b/pkg/common/constant/key.go
index 4f70cc32..43afe753 100644
--- a/pkg/common/constant/key.go
+++ b/pkg/common/constant/key.go
@@ -59,6 +59,7 @@ const (
LLMProxyFilter = "dgp.filter.llm.proxy"
LLMTokenizerFilter = "dgp.filter.llm.tokenizer"
+ AIKVCacheFilter = "dgp.filter.ai.kvcache"
MCPServerFilter = "dgp.filter.mcp.mcpserver"
)
diff --git a/pkg/filter/ai/kvcache/circuit_breaker.go
b/pkg/filter/ai/kvcache/circuit_breaker.go
new file mode 100644
index 00000000..8fb37de2
--- /dev/null
+++ b/pkg/filter/ai/kvcache/circuit_breaker.go
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "errors"
+ "sync"
+ "time"
+)
+
+var ErrCircuitBreakerOpen = errors.New("kvcache circuit breaker open")
+
+type CircuitState int
+
+const (
+ CircuitClosed CircuitState = 0
+ CircuitOpen CircuitState = 1
+ CircuitHalfOpen CircuitState = 2
+)
+
+type CircuitBreaker struct {
+ state CircuitState
+ failureCount int
+ lastFailTime time.Time
+ halfOpenCalls int
+ config CircuitBreakerConfig
+ mutex sync.Mutex
+}
+
+type CircuitBreakerConfig struct {
+ FailureThreshold int `yaml:"failure_threshold"
json:"failure_threshold" mapstructure:"failure_threshold"`
+ RecoveryTimeout time.Duration `yaml:"recovery_timeout"
json:"recovery_timeout" mapstructure:"recovery_timeout"`
+ HalfOpenMaxCalls int `yaml:"half_open_max_calls"
json:"half_open_max_calls" mapstructure:"half_open_max_calls"`
+}
+
+func NewCircuitBreaker(cfg CircuitBreakerConfig) *CircuitBreaker {
+ return &CircuitBreaker{config: cfg, state: CircuitClosed}
+}
+
+func (cb *CircuitBreaker) Execute(operation func() error) error {
+ if cb == nil {
+ return operation()
+ }
+ if !cb.allow() {
+ return ErrCircuitBreakerOpen
+ }
+ err := operation()
+ cb.recordResult(err)
+ return err
+}
+
+func (cb *CircuitBreaker) allow() bool {
+ cb.mutex.Lock()
+ defer cb.mutex.Unlock()
+
+ switch cb.state {
+ case CircuitOpen:
+ if time.Since(cb.lastFailTime) >= cb.config.RecoveryTimeout {
+ cb.state = CircuitHalfOpen
+ cb.halfOpenCalls = 0
+ } else {
+ return false
+ }
+ case CircuitHalfOpen:
+ if cb.halfOpenCalls >= cb.config.HalfOpenMaxCalls {
+ return false
+ }
+ cb.halfOpenCalls++
+ }
+ return true
+}
+
+func (cb *CircuitBreaker) recordResult(err error) {
+ cb.mutex.Lock()
+ defer cb.mutex.Unlock()
+
+ if err == nil {
+ cb.failureCount = 0
+ if cb.state == CircuitHalfOpen || cb.state == CircuitOpen {
+ cb.state = CircuitClosed
+ cb.halfOpenCalls = 0
+ }
+ return
+ }
+
+ cb.failureCount++
+ cb.lastFailTime = time.Now()
+ if cb.failureCount >= cb.config.FailureThreshold {
+ cb.state = CircuitOpen
+ }
+}
diff --git a/pkg/filter/ai/kvcache/config.go b/pkg/filter/ai/kvcache/config.go
new file mode 100644
index 00000000..3c93d7b3
--- /dev/null
+++ b/pkg/filter/ai/kvcache/config.go
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "fmt"
+ "strings"
+ "time"
+)
+
+const (
+ minRatio = 0.0
+ maxRatio = 1.0
+
+ defaultRequestTimeout = 2 * time.Second
+ defaultLookupRoutingTimeout = 50 * time.Millisecond
+ defaultHotWindow = 5 * time.Minute
+ defaultHotMaxRecords = 300
+ defaultHotMaxKeys = 0
+ defaultRetryMaxAttempts = 3
+ defaultRetryBaseBackoff = 100 * time.Millisecond
+ defaultRetryMaxBackoff = 2 * time.Second
+ defaultCBFailureThreshold = 5
+ defaultCBRecoveryTimeout = 10 * time.Second
+ defaultCBHalfOpenMaxCalls = 2
+ defaultCompressMethod = "zstd"
+)
+
+type Config struct {
+ Enabled bool `yaml:"enabled"
json:"enabled" mapstructure:"enabled"`
+ VLLMEndpoint string `yaml:"vllm_endpoint"
json:"vllm_endpoint" mapstructure:"vllm_endpoint"`
+ LMCacheEndpoint string `yaml:"lmcache_endpoint"
json:"lmcache_endpoint" mapstructure:"lmcache_endpoint"`
+ DefaultModel string `yaml:"default_model"
json:"default_model" mapstructure:"default_model"`
+ RequestTimeout time.Duration `yaml:"request_timeout"
json:"request_timeout" mapstructure:"request_timeout"`
+ LookupRoutingTimeout time.Duration
`yaml:"lookup_routing_timeout" json:"lookup_routing_timeout"
mapstructure:"lookup_routing_timeout"`
+ HotWindow time.Duration `yaml:"hot_window"
json:"hot_window" mapstructure:"hot_window"`
+ HotMaxRecords int `yaml:"hot_max_records"
json:"hot_max_records" mapstructure:"hot_max_records"`
+ HotMaxKeys int `yaml:"hot_max_keys"
json:"hot_max_keys" mapstructure:"hot_max_keys"`
+ MaxIdleConns int `yaml:"max_idle_conns"
json:"max_idle_conns" mapstructure:"max_idle_conns"`
+ MaxIdleConnsPerHost int
`yaml:"max_idle_conns_per_host" json:"max_idle_conns_per_host"
mapstructure:"max_idle_conns_per_host"`
+ MaxConnsPerHost int `yaml:"max_conns_per_host"
json:"max_conns_per_host" mapstructure:"max_conns_per_host"`
+ TokenCache TokenCacheConfig `yaml:"token_cache"
json:"token_cache" mapstructure:"token_cache"`
+ CacheStrategy CacheStrategyConfig `yaml:"cache_strategy"
json:"cache_strategy" mapstructure:"cache_strategy"`
+ CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"
json:"circuit_breaker" mapstructure:"circuit_breaker"`
+ Retry RetryConfig `yaml:"retry" json:"retry"
mapstructure:"retry"`
+}
+
+type TokenCacheConfig struct {
+ MaxSize int `yaml:"max_size" json:"max_size"
mapstructure:"max_size"`
+ TTL time.Duration `yaml:"ttl" json:"ttl" mapstructure:"ttl"`
+ Enabled bool `yaml:"enabled" json:"enabled"
mapstructure:"enabled"`
+}
+
+type CacheStrategyConfig struct {
+ EnableCompression bool `yaml:"enable_compression"
json:"enable_compression" mapstructure:"enable_compression"`
+ EnablePinning bool `yaml:"enable_pinning"
json:"enable_pinning" mapstructure:"enable_pinning"`
+ EnableEviction bool `yaml:"enable_eviction"
json:"enable_eviction" mapstructure:"enable_eviction"`
+ MemoryThreshold float64 `yaml:"memory_threshold"
json:"memory_threshold" mapstructure:"memory_threshold"`
+ HotContentThreshold int `yaml:"hot_content_threshold"
json:"hot_content_threshold" mapstructure:"hot_content_threshold"`
+ LoadThreshold float64 `yaml:"load_threshold"
json:"load_threshold" mapstructure:"load_threshold"`
+ PinInstanceID string `yaml:"pin_instance_id"
json:"pin_instance_id" mapstructure:"pin_instance_id"`
+ PinLocation string `yaml:"pin_location" json:"pin_location"
mapstructure:"pin_location"`
+ CompressInstanceID string `yaml:"compress_instance_id"
json:"compress_instance_id" mapstructure:"compress_instance_id"`
+ CompressLocation string `yaml:"compress_location"
json:"compress_location" mapstructure:"compress_location"`
+ CompressMethod string `yaml:"compress_method"
json:"compress_method" mapstructure:"compress_method"`
+ EvictInstanceID string `yaml:"evict_instance_id"
json:"evict_instance_id" mapstructure:"evict_instance_id"`
+}
+
+type RetryConfig struct {
+ MaxAttempts int `yaml:"max_attempts" json:"max_attempts"
mapstructure:"max_attempts"`
+ BaseBackoff time.Duration `yaml:"base_backoff" json:"base_backoff"
mapstructure:"base_backoff"`
+ MaxBackoff time.Duration `yaml:"max_backoff" json:"max_backoff"
mapstructure:"max_backoff"`
+}
+
+func (c *Config) Validate() error {
+ if !c.Enabled {
+ return nil
+ }
+ if strings.TrimSpace(c.VLLMEndpoint) == "" {
+ return fmt.Errorf("[kvcache] vllm_endpoint is required when
enabled")
+ }
+ if strings.TrimSpace(c.LMCacheEndpoint) == "" {
+ return fmt.Errorf("[kvcache] lmcache_endpoint is required when
enabled")
+ }
+ if c.TokenCache.MaxSize < 0 {
+ return fmt.Errorf("[kvcache] token_cache.max_size must be >= 0")
+ }
+ if c.CacheStrategy.MemoryThreshold < minRatio ||
c.CacheStrategy.MemoryThreshold > maxRatio {
+ return fmt.Errorf("[kvcache] cache_strategy.memory_threshold
must be between 0 and 1")
+ }
+ if c.CacheStrategy.LoadThreshold < minRatio ||
c.CacheStrategy.LoadThreshold > maxRatio {
+ return fmt.Errorf("[kvcache] cache_strategy.load_threshold must
be between 0 and 1")
+ }
+ if c.CacheStrategy.HotContentThreshold < 0 {
+ return fmt.Errorf("[kvcache]
cache_strategy.hot_content_threshold must be >= 0")
+ }
+ if c.Retry.MaxAttempts < 0 {
+ return fmt.Errorf("[kvcache] retry.max_attempts must be >= 0")
+ }
+ if c.Retry.BaseBackoff < 0 || c.Retry.MaxBackoff < 0 {
+ return fmt.Errorf("[kvcache] retry backoff durations must be >=
0")
+ }
+ if c.HotWindow < 0 {
+ return fmt.Errorf("[kvcache] hot_window must be >= 0")
+ }
+ if c.HotMaxRecords < 0 {
+ return fmt.Errorf("[kvcache] hot_max_records must be >= 0")
+ }
+ if c.HotMaxKeys < 0 {
+ return fmt.Errorf("[kvcache] hot_max_keys must be >= 0")
+ }
+ return nil
+}
+
+func (c *Config) ApplyDefaults() {
+ if c.RequestTimeout <= 0 {
+ c.RequestTimeout = defaultRequestTimeout
+ }
+ if c.LookupRoutingTimeout <= 0 {
+ c.LookupRoutingTimeout = defaultLookupRoutingTimeout
+ }
+ if c.HotWindow <= 0 {
+ c.HotWindow = defaultHotWindow
+ }
+ if c.HotMaxRecords <= 0 {
+ c.HotMaxRecords = defaultHotMaxRecords
+ }
+ if c.HotMaxKeys < 0 {
+ c.HotMaxKeys = defaultHotMaxKeys
+ }
+ if c.Retry.MaxAttempts <= 0 {
+ c.Retry.MaxAttempts = defaultRetryMaxAttempts
+ }
+ if c.Retry.BaseBackoff <= 0 {
+ c.Retry.BaseBackoff = defaultRetryBaseBackoff
+ }
+ if c.Retry.MaxBackoff <= 0 {
+ c.Retry.MaxBackoff = defaultRetryMaxBackoff
+ }
+ if c.CircuitBreaker.FailureThreshold <= 0 {
+ c.CircuitBreaker.FailureThreshold = defaultCBFailureThreshold
+ }
+ if c.CircuitBreaker.RecoveryTimeout <= 0 {
+ c.CircuitBreaker.RecoveryTimeout = defaultCBRecoveryTimeout
+ }
+ if c.CircuitBreaker.HalfOpenMaxCalls <= 0 {
+ c.CircuitBreaker.HalfOpenMaxCalls = defaultCBHalfOpenMaxCalls
+ }
+ if c.CacheStrategy.CompressMethod == "" {
+ c.CacheStrategy.CompressMethod = defaultCompressMethod
+ }
+}
+
+func (c *Config) DeepCopy() *Config {
+ if c == nil {
+ return nil
+ }
+ cp := *c
+ return &cp
+}
diff --git a/pkg/filter/ai/kvcache/filter.go b/pkg/filter/ai/kvcache/filter.go
new file mode 100644
index 00000000..46da5287
--- /dev/null
+++ b/pkg/filter/ai/kvcache/filter.go
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "net/http"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+)
+
+import (
+ "github.com/apache/dubbo-go-pixiu/pkg/common/constant"
+ "github.com/apache/dubbo-go-pixiu/pkg/common/extension/filter"
+ contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+ "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
+const (
+ Kind = constant.AIKVCacheFilter
+)
+
+func init() {
+ filter.RegisterHttpFilter(&Plugin{})
+}
+
+type (
+ Plugin struct{}
+
+ FilterFactory struct {
+ cfg *Config
+ httpClient *http.Client
+ resty *resty.Client
+ tokenManager *TokenManager
+ lmcacheClient *LMCacheClient
+ cacheStrategy *CacheStrategy
+ }
+
+ Filter struct {
+ cfg *Config
+ tokenManager *TokenManager
+ lmcacheClient *LMCacheClient
+ cacheStrategy *CacheStrategy
+ }
+)
+
+func (p *Plugin) Kind() string { return Kind }
+
+func (p *Plugin) CreateFilterFactory() (filter.HttpFilterFactory, error) {
+ return &FilterFactory{cfg: &Config{}}, nil
+}
+
+func (factory *FilterFactory) Config() any { return factory.cfg }
+
+func (factory *FilterFactory) Apply() error {
+ factory.cfg.ApplyDefaults()
+ if err := factory.cfg.Validate(); err != nil {
+ return err
+ }
+ cfg := factory.cfg
+ factory.httpClient = &http.Client{
+ Timeout: cfg.RequestTimeout,
+ Transport: &http.Transport{
+ MaxIdleConns: cfg.MaxIdleConns,
+ MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
+ MaxConnsPerHost: cfg.MaxConnsPerHost,
+ },
+ }
+ factory.resty = resty.NewWithClient(factory.httpClient).
+ SetTimeout(cfg.RequestTimeout)
+
+ cbToken := NewCircuitBreaker(cfg.CircuitBreaker)
+ cbLMCache := NewCircuitBreaker(cfg.CircuitBreaker)
+ factory.tokenManager = NewTokenManager(cfg.VLLMEndpoint, factory.resty,
cfg.TokenCache, cbToken, cfg.HotWindow, cfg.HotMaxRecords, cfg.HotMaxKeys)
+ factory.lmcacheClient = NewLMCacheClient(cfg.LMCacheEndpoint,
factory.resty, cfg.Retry, cbLMCache)
+ factory.cacheStrategy = NewCacheStrategy(cfg.CacheStrategy,
factory.lmcacheClient, factory.tokenManager)
+ return nil
+}
+
+func (factory *FilterFactory) PrepareFilterChain(_ *contexthttp.HttpContext,
chain filter.FilterChain) error {
+ f := &Filter{
+ cfg: factory.cfg,
+ tokenManager: factory.tokenManager,
+ lmcacheClient: factory.lmcacheClient,
+ cacheStrategy: factory.cacheStrategy,
+ }
+ chain.AppendDecodeFilters(f)
+ return nil
+}
+
+func (f *Filter) Decode(hc *contexthttp.HttpContext) filter.FilterStatus {
+ if f.cfg == nil || !f.cfg.Enabled {
+ return filter.Continue
+ }
+ if f.cacheStrategy != nil {
+ f.cacheStrategy.RecordRequest()
+ }
+ body, err := readRequestBody(hc.Request)
+ if err != nil {
+ logger.Warnf("[kvcache] read request body failed: %v", err)
+ return filter.Continue
+ }
+ prompt, model, err := extractPromptAndModel(body)
+ if err != nil {
+ logger.Warnf("[kvcache] parse request body failed: %v", err)
+ return filter.Continue
+ }
+ if prompt == "" {
+ return filter.Continue
+ }
+ if model == "" {
+ model = f.cfg.DefaultModel
+ }
+
+ f.tokenManager.RecordHot(model, prompt)
+
+ cacheStatus, routed := f.tryRouteToCachedInstance(hc, model, prompt)
+
+ ctx, cancel := context.WithTimeout(requestScopedContext(hc),
effectiveTimeout(hc, f.cfg))
+ go func() {
+ defer cancel()
+ f.manageCache(ctx, model, prompt, body, cacheStatus, routed)
+ }()
+ return filter.Continue
+}
diff --git a/pkg/filter/ai/kvcache/handlers.go
b/pkg/filter/ai/kvcache/handlers.go
new file mode 100644
index 00000000..3d75a487
--- /dev/null
+++ b/pkg/filter/ai/kvcache/handlers.go
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+)
+
+import (
+ contexthttp "github.com/apache/dubbo-go-pixiu/pkg/context/http"
+ "github.com/apache/dubbo-go-pixiu/pkg/logger"
+)
+
// llmPreferredEndpointIDKey is the HttpContext.Params key under which the
// preferred upstream endpoint ID is published. NOTE(review): presumably
// consumed by pkg/filter/llm/proxy (also modified in this commit) — confirm.
const llmPreferredEndpointIDKey = "llm_preferred_endpoint_id"
+
+func (f *Filter) manageCache(ctx context.Context, model string, prompt string,
rawBody []byte, cacheStatus *LookupResponse, lookupDone bool) {
+ if ctx.Err() != nil {
+ return
+ }
+ tokens, err := f.tokenManager.GetTokens(ctx, model, prompt, rawBody)
+ if err != nil {
+ logger.Warnf("[kvcache] tokenize failed: %v", err)
+ return
+ }
+ if ctx.Err() != nil {
+ return
+ }
+ if !lookupDone || cacheStatus == nil {
+ cacheStatus, err = f.lmcacheClient.Lookup(ctx,
&LookupRequest{Tokens: tokens})
+ if err != nil {
+ logger.Warnf("[kvcache] lookup failed: %v", err)
+ return
+ }
+ }
+ decision := f.cacheStrategy.MakeDecision(ctx, cacheStatus, model,
prompt)
+ if ctx.Err() != nil {
+ return
+ }
+ if err := f.cacheStrategy.ExecuteDecision(ctx, decision, tokens); err
!= nil {
+ logger.Warnf("[kvcache] execute strategy failed: %v", err)
+ }
+}
+
+func readRequestBody(req *http.Request) ([]byte, error) {
+ if req == nil || req.Body == nil {
+ return nil, nil
+ }
+ if req.GetBody != nil {
+ reader, err := req.GetBody()
+ if err == nil {
+ defer reader.Close()
+ return io.ReadAll(reader)
+ }
+ }
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ req.GetBody = func() (io.ReadCloser, error) {
+ return io.NopCloser(bytes.NewReader(bodyBytes)), nil
+ }
+ return bodyBytes, nil
+}
+
+func extractPromptAndModel(body []byte) (string, string, error) {
+ if len(body) == 0 {
+ return "", "", nil
+ }
+ var payload map[string]any
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return "", "", err
+ }
+ model, _ := payload["model"].(string)
+ prompt := coercePrompt(payload["prompt"])
+ if prompt == "" {
+ prompt = extractPromptFromMessages(payload["messages"])
+ }
+ return strings.TrimSpace(prompt), model, nil
+}
+
// coercePrompt normalizes the "prompt" JSON field to a single string: a
// plain string is returned as-is, an array keeps only its string elements
// joined with newlines, and anything else becomes "".
func coercePrompt(value any) string {
	if s, ok := value.(string); ok {
		return s
	}
	items, ok := value.([]any)
	if !ok {
		return ""
	}
	collected := make([]string, 0, len(items))
	for _, item := range items {
		if s, ok := item.(string); ok {
			collected = append(collected, s)
		}
	}
	return strings.Join(collected, "\n")
}
+
// extractPromptFromMessages flattens an OpenAI chat "messages" array into
// a newline-joined string of each message's string "content" field.
// Non-map entries and non-string contents are skipped; a value that is
// not an array yields "".
func extractPromptFromMessages(value any) string {
	messages, ok := value.([]any)
	if !ok {
		return ""
	}
	contents := make([]string, 0, len(messages))
	for _, raw := range messages {
		entry, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if text, isStr := entry["content"].(string); isStr {
			contents = append(contents, text)
		}
	}
	return strings.Join(contents, "\n")
}
+
+func selectPreferredInstanceID(resp *LookupResponse) string {
+ if resp == nil || len(resp.LayoutInfo) == 0 {
+ return ""
+ }
+ var (
+ selected string
+ maxCount int
+ )
+ for instanceID, layout := range resp.LayoutInfo {
+ if layout.TokenCount > maxCount || selected == "" {
+ selected = instanceID
+ maxCount = layout.TokenCount
+ }
+ }
+ return selected
+}
+
+func effectiveTimeout(hc *contexthttp.HttpContext, cfg *Config) time.Duration {
+ if cfg == nil {
+ return 0
+ }
+ timeout := cfg.RequestTimeout
+ if hc != nil && hc.Timeout > 0 && (timeout <= 0 || hc.Timeout <
timeout) {
+ timeout = hc.Timeout
+ }
+ if timeout <= 0 {
+ return 2 * time.Second
+ }
+ return timeout
+}
+
+func requestScopedContext(hc *contexthttp.HttpContext) context.Context {
+ if hc != nil && hc.Request != nil {
+ return hc.Request.Context()
+ }
+ if hc != nil && hc.Ctx != nil {
+ return hc.Ctx
+ }
+ return context.Background()
+}
+
+func (f *Filter) tryRouteToCachedInstance(hc *contexthttp.HttpContext, model
string, prompt string) (*LookupResponse, bool) {
+ if f == nil || f.tokenManager == nil || f.lmcacheClient == nil {
+ return nil, false
+ }
+ tokens, ok := f.tokenManager.GetCachedTokens(model, prompt)
+ if !ok || len(tokens) == 0 {
+ logger.Debugf("[kvcache] routing lookup skipped: token cache
miss")
+ return nil, false
+ }
+ timeout := effectiveTimeout(hc, f.cfg)
+ if f.cfg != nil && f.cfg.LookupRoutingTimeout > 0 &&
f.cfg.LookupRoutingTimeout < timeout {
+ timeout = f.cfg.LookupRoutingTimeout
+ }
+ ctx, cancel := context.WithTimeout(requestScopedContext(hc), timeout)
+ defer cancel()
+ cacheStatus, err := f.lmcacheClient.Lookup(ctx, &LookupRequest{Tokens:
tokens})
+ if err != nil {
+ logger.Debugf("[kvcache] routing lookup failed: %v", err)
+ return nil, false
+ }
+ instanceID := selectPreferredInstanceID(cacheStatus)
+ if instanceID == "" {
+ logger.Debugf("[kvcache] routing lookup returned empty
instance")
+ return cacheStatus, false
+ }
+ if hc.Params == nil {
+ hc.Params = make(map[string]any)
+ }
+ hc.Params[llmPreferredEndpointIDKey] = instanceID
+ logger.Debugf("[kvcache] routing preferred endpoint set: %s",
instanceID)
+ return cacheStatus, true
+}
diff --git a/pkg/filter/ai/kvcache/handlers_test.go
b/pkg/filter/ai/kvcache/handlers_test.go
new file mode 100644
index 00000000..42a3f011
--- /dev/null
+++ b/pkg/filter/ai/kvcache/handlers_test.go
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "net/http"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+import (
+ "github.com/apache/dubbo-go-pixiu/pkg/context/mock"
+)
+
+func TestExtractPromptAndModel(t *testing.T) {
+ testCases := []struct {
+ name string
+ body string
+ wantPrompt string
+ wantModel string
+ wantErr bool
+ }{
+ {
+ name: "prompt field",
+ body: `{"model":"m1","prompt":"hello"}`,
+ wantPrompt: "hello",
+ wantModel: "m1",
+ },
+ {
+ name: "prompt array",
+ body: `{"model":"m2","prompt":["a","b"]}`,
+ wantPrompt: "a\nb",
+ wantModel: "m2",
+ },
+ {
+ name: "messages fallback",
+ body:
`{"model":"m3","messages":[{"role":"user","content":"hi"},{"role":"assistant","content":"there"}]}`,
+ wantPrompt: "hi\nthere",
+ wantModel: "m3",
+ },
+ {
+ name: "invalid json",
+ body: `{"model":`,
+ wantErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ prompt, model, err :=
extractPromptAndModel([]byte(tc.body))
+ if tc.wantErr {
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ assert.Equal(t, tc.wantPrompt, prompt)
+ assert.Equal(t, tc.wantModel, model)
+ })
+ }
+}
+
+func TestTryRouteToCachedInstance(t *testing.T) {
+ var lookupCalls int64
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ atomic.AddInt64(&lookupCalls, 1)
+ assert.Equal(t, "/lookup", r.URL.Path)
+ return newHTTPResponse(http.StatusOK,
`{"event_id":"evt","layout_info":{"node-a":{"0":"x","1":2},"node-b":{"0":"y","1":8}}}`),
nil
+ })
+
+ tm := NewTokenManager("", resty.New(), TokenCacheConfig{
+ Enabled: true,
+ MaxSize: 10,
+ TTL: time.Minute,
+ }, nil, time.Minute, 10)
+ key := tm.cacheKey("m1", "prompt-1")
+ tm.storeCache(key, []int{1, 2, 3})
+
+ lmcacheClient := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 1,
+ BaseBackoff: time.Millisecond,
+ MaxBackoff: time.Millisecond,
+ }, nil)
+ f := &Filter{
+ cfg: &Config{
+ RequestTimeout: time.Second,
+ LookupRoutingTimeout: 100 * time.Millisecond,
+ },
+ tokenManager: tm,
+ lmcacheClient: lmcacheClient,
+ }
+
+ req, err := http.NewRequest(http.MethodPost, "http://example.com", nil)
+ require.NoError(t, err)
+ hc := mock.GetMockHTTPContext(req)
+
+ cacheStatus, routed := f.tryRouteToCachedInstance(hc, "m1", "prompt-1")
+ require.True(t, routed)
+ require.NotNil(t, cacheStatus)
+ assert.Equal(t, "node-b", hc.Params[llmPreferredEndpointIDKey])
+ assert.Equal(t, int64(1), atomic.LoadInt64(&lookupCalls))
+
+ cacheStatus, routed = f.tryRouteToCachedInstance(hc, "m1", "prompt-2")
+ assert.False(t, routed)
+ assert.Nil(t, cacheStatus)
+ assert.Equal(t, int64(1), atomic.LoadInt64(&lookupCalls))
+}
diff --git a/pkg/filter/ai/kvcache/lmcache_client.go
b/pkg/filter/ai/kvcache/lmcache_client.go
new file mode 100644
index 00000000..c0410692
--- /dev/null
+++ b/pkg/filter/ai/kvcache/lmcache_client.go
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+)
+
// LMCacheClient is an HTTP client for the LMCache controller API
// (pin/lookup/compress/evict) with retry and optional circuit breaking.
type LMCacheClient struct {
	httpClient     *resty.Client
	baseURL        string // normalized without trailing slash
	retry          RetryConfig
	circuitBreaker *CircuitBreaker // nil disables circuit breaking
}

// PinRequest asks the controller to pin tokens on an instance/location.
type PinRequest struct {
	Tokens     []int  `json:"tokens"`
	InstanceID string `json:"instance_id"`
	Location   string `json:"location"`
}

// LookupRequest queries which instances hold the given tokens.
type LookupRequest struct {
	Tokens []int `json:"tokens"`
}

// CompressRequest asks the controller to compress cached tokens with the
// named method on an instance/location.
type CompressRequest struct {
	Tokens     []int  `json:"tokens"`
	InstanceID string `json:"instance_id"`
	Location   string `json:"location"`
	Method     string `json:"method"`
}

// EvictRequest asks the controller to drop cached tokens from an instance.
type EvictRequest struct {
	Tokens     []int  `json:"tokens"`
	InstanceID string `json:"instance_id"`
}

// PinResponse reports the controller-side event ID and affected tokens.
type PinResponse struct {
	EventID   string `json:"event_id"`
	NumTokens int    `json:"num_tokens"`
}

// CompressResponse reports the controller-side event ID and affected tokens.
type CompressResponse struct {
	EventID   string `json:"event_id"`
	NumTokens int    `json:"num_tokens"`
}

// EvictResponse reports the controller-side event ID and affected tokens.
type EvictResponse struct {
	EventID   string `json:"event_id"`
	NumTokens int    `json:"num_tokens"`
}
+
+func NewLMCacheClient(baseURL string, httpClient *resty.Client, retry
RetryConfig, cb *CircuitBreaker) *LMCacheClient {
+ return &LMCacheClient{
+ httpClient: httpClient,
+ baseURL: strings.TrimRight(baseURL, "/"),
+ retry: retry,
+ circuitBreaker: cb,
+ }
+}
+
+func (lc *LMCacheClient) Pin(ctx context.Context, req *PinRequest)
(*PinResponse, error) {
+ var resp PinResponse
+ if err := lc.doRequestWithRetry(ctx, "/pin", req, &resp, "pin"); err !=
nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func (lc *LMCacheClient) Lookup(ctx context.Context, req *LookupRequest)
(*LookupResponse, error) {
+ var resp LookupResponse
+ if err := lc.doRequestWithRetry(ctx, "/lookup", req, &resp, "lookup");
err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func (lc *LMCacheClient) Compress(ctx context.Context, req *CompressRequest)
(*CompressResponse, error) {
+ var resp CompressResponse
+ if err := lc.doRequestWithRetry(ctx, "/compress", req, &resp,
"compress"); err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func (lc *LMCacheClient) Evict(ctx context.Context, req *EvictRequest)
(*EvictResponse, error) {
+ var resp EvictResponse
+ if err := lc.doRequestWithRetry(ctx, "/evict", req, &resp, "evict");
err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func (lc *LMCacheClient) doRequestWithRetry(ctx context.Context, path string,
payload any, out any, op string) error {
+ var lastErr error
+ maxAttempts := lc.retry.MaxAttempts
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ err := lc.execute(func() error {
+ return lc.doRequest(ctx, path, payload, out)
+ })
+ if err == nil {
+ return nil
+ }
+ if err == ErrCircuitBreakerOpen {
+ return err
+ }
+ lastErr = err
+ backoff := lc.backoffDuration(attempt)
+ select {
+ case <-time.After(backoff):
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+ return lastErr
+}
+
+func (lc *LMCacheClient) doRequest(ctx context.Context, path string, payload
any, out any) error {
+ url := lc.baseURL + path
+ resp, err := lc.httpClient.R().
+ SetContext(ctx).
+ SetHeader("Content-Type", "application/json").
+ SetBody(payload).
+ Post(url)
+ if err != nil {
+ return fmt.Errorf("[kvcache] call lmcache: %w", err)
+ }
+ if resp.StatusCode() < 200 || resp.StatusCode() >= 300 {
+ return fmt.Errorf("[kvcache] lmcache status %d: %s",
resp.StatusCode(), strings.TrimSpace(string(resp.Body())))
+ }
+ if out == nil {
+ return nil
+ }
+ if err := json.Unmarshal(resp.Body(), out); err != nil {
+ return fmt.Errorf("[kvcache] decode lmcache response: %w", err)
+ }
+ return nil
+}
+
+func (lc *LMCacheClient) execute(operation func() error) error {
+ if lc.circuitBreaker == nil {
+ return operation()
+ }
+ return lc.circuitBreaker.Execute(operation)
+}
+
+func (lc *LMCacheClient) backoffDuration(attempt int) time.Duration {
+ backoff := lc.retry.BaseBackoff
+ for i := 0; i < attempt; i++ {
+ backoff *= 2
+ if backoff >= lc.retry.MaxBackoff {
+ return lc.retry.MaxBackoff
+ }
+ }
+ if backoff > lc.retry.MaxBackoff {
+ return lc.retry.MaxBackoff
+ }
+ return backoff
+}
diff --git a/pkg/filter/ai/kvcache/lmcache_client_test.go
b/pkg/filter/ai/kvcache/lmcache_client_test.go
new file mode 100644
index 00000000..346d2380
--- /dev/null
+++ b/pkg/filter/ai/kvcache/lmcache_client_test.go
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "net/http"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestLMCacheClientRetrySuccess(t *testing.T) {
+ var attempts int64
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ assert.Equal(t, "/lookup", r.URL.Path)
+ attempt := atomic.AddInt64(&attempts, 1)
+ if attempt < 3 {
+ return newHTTPResponse(http.StatusBadGateway,
"temporary failure"), nil
+ }
+ return newHTTPResponse(http.StatusOK,
`{"event_id":"evt-1","layout_info":{"node-a":{"0":"mem0","1":6}}}`), nil
+ })
+
+ client := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 3,
+ BaseBackoff: time.Millisecond,
+ MaxBackoff: 2 * time.Millisecond,
+ }, nil)
+
+ resp, err := client.Lookup(context.Background(), &LookupRequest{Tokens:
[]int{1, 2}})
+ require.NoError(t, err)
+ assert.Equal(t, int64(3), atomic.LoadInt64(&attempts))
+ assert.Equal(t, "evt-1", resp.EventID)
+ assert.Equal(t, 6, resp.LayoutInfo["node-a"].TokenCount)
+}
+
+func TestLMCacheClientContextCancelDuringBackoff(t *testing.T) {
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ return newHTTPResponse(http.StatusBadGateway, "retry"), nil
+ })
+
+ client := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 3,
+ BaseBackoff: 100 * time.Millisecond,
+ MaxBackoff: 100 * time.Millisecond,
+ }, nil)
+
+ ctx, cancel := context.WithTimeout(context.Background(),
20*time.Millisecond)
+ defer cancel()
+ _, err := client.Pin(ctx, &PinRequest{Tokens: []int{1}})
+ assert.ErrorIs(t, err, context.DeadlineExceeded)
+}
+
+func TestLMCacheClientCircuitBreakerOpen(t *testing.T) {
+ var called int64
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ atomic.AddInt64(&called, 1)
+ return newHTTPResponse(http.StatusOK,
`{"event_id":"evt-2","num_tokens":1}`), nil
+ })
+
+ cb := NewCircuitBreaker(CircuitBreakerConfig{
+ FailureThreshold: 1,
+ RecoveryTimeout: time.Minute,
+ HalfOpenMaxCalls: 1,
+ })
+ cb.state = CircuitOpen
+ cb.lastFailTime = time.Now()
+
+ client := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 2,
+ BaseBackoff: time.Millisecond,
+ MaxBackoff: time.Millisecond,
+ }, cb)
+ _, err := client.Evict(context.Background(), &EvictRequest{Tokens:
[]int{1}})
+ assert.ErrorIs(t, err, ErrCircuitBreakerOpen)
+ assert.Equal(t, int64(0), atomic.LoadInt64(&called))
+}
diff --git a/pkg/filter/ai/kvcache/load_monitor.go
b/pkg/filter/ai/kvcache/load_monitor.go
new file mode 100644
index 00000000..8fee9974
--- /dev/null
+++ b/pkg/filter/ai/kvcache/load_monitor.go
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "runtime"
+ runtimemetrics "runtime/metrics"
+ "sync"
+ "time"
+)
+
// LoadMonitor tracks process load for cache-strategy decisions: a sliding
// request-rate window plus CPU usage sampled from runtime/metrics.
// All fields are guarded by mutex.
type LoadMonitor struct {
	window time.Duration // request-rate aggregation window
	last   time.Time     // start of the current window
	count  int64         // requests seen in the current window
	rate   float64       // last computed requests/second
	mutex  sync.Mutex

	// CPU sampling state: previous cumulative CPU-seconds reading and when
	// it was taken. hasCPUSample is false until the baseline exists.
	lastCPUSampleAt time.Time
	lastCPUSeconds  float64
	hasCPUSample    bool
}
+
+func NewLoadMonitor() *LoadMonitor {
+ return &LoadMonitor{
+ window: time.Second,
+ last: time.Now(),
+ }
+}
+
+func (lm *LoadMonitor) RecordRequest() {
+ if lm == nil {
+ return
+ }
+ lm.mutex.Lock()
+ lm.count++
+ lm.mutex.Unlock()
+}
+
+func (lm *LoadMonitor) Snapshot() LoadMetrics {
+ if lm == nil {
+ return LoadMetrics{}
+ }
+ lm.mutex.Lock()
+ defer lm.mutex.Unlock()
+
+ cpuUsage := lm.sampleCPUUsage()
+ now := time.Now()
+ elapsed := now.Sub(lm.last)
+ if elapsed >= lm.window && elapsed > 0 {
+ lm.rate = float64(lm.count) / elapsed.Seconds()
+ lm.count = 0
+ lm.last = now
+ }
+ var ms runtime.MemStats
+ runtime.ReadMemStats(&ms)
+ memUsage := 0.0
+ if ms.Sys > 0 {
+ memUsage = float64(ms.Alloc) / float64(ms.Sys)
+ }
+ return LoadMetrics{
+ CPUUsage: cpuUsage,
+ MemoryUsage: memUsage,
+ RequestRate: lm.rate,
+ }
+}
+
+func (lm *LoadMonitor) sampleCPUUsage() float64 {
+ totalCPUSeconds, gomaxprocs, ok := readRuntimeCPUStats()
+ if !ok {
+ return 0
+ }
+
+ now := time.Now()
+ if !lm.hasCPUSample {
+ lm.lastCPUSeconds = totalCPUSeconds
+ lm.lastCPUSampleAt = now
+ lm.hasCPUSample = true
+ return 0
+ }
+
+ wall := now.Sub(lm.lastCPUSampleAt).Seconds()
+ if wall <= 0 || gomaxprocs <= 0 {
+ return 0
+ }
+
+ cpuDelta := totalCPUSeconds - lm.lastCPUSeconds
+ lm.lastCPUSeconds = totalCPUSeconds
+ lm.lastCPUSampleAt = now
+ if cpuDelta <= 0 {
+ return 0
+ }
+
+ usage := cpuDelta / (wall * gomaxprocs)
+ if usage < 0 {
+ return 0
+ }
+ if usage > 1 {
+ return 1
+ }
+ return usage
+}
+
// readRuntimeCPUStats samples the runtime/metrics counters needed for CPU
// usage estimation: cumulative CPU seconds consumed by the process and the
// current GOMAXPROCS. ok is false when either metric is not supported by
// the running Go version.
func readRuntimeCPUStats() (totalCPUSeconds float64, gomaxprocs float64, ok bool) {
	samples := []runtimemetrics.Sample{
		{Name: "/cpu/classes/total:cpu-seconds"},
		{Name: "/sched/gomaxprocs:threads"},
	}
	runtimemetrics.Read(samples)

	total, procs := samples[0].Value, samples[1].Value
	if total.Kind() != runtimemetrics.KindFloat64 || procs.Kind() != runtimemetrics.KindUint64 {
		return 0, 0, false
	}
	return total.Float64(), float64(procs.Uint64()), true
}
diff --git a/pkg/filter/ai/kvcache/strategy.go
b/pkg/filter/ai/kvcache/strategy.go
new file mode 100644
index 00000000..3aef2a4f
--- /dev/null
+++ b/pkg/filter/ai/kvcache/strategy.go
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+)
+
// CacheStrategy decides and executes KV-cache maintenance actions
// (compress/pin/evict) based on load metrics and content heat.
type CacheStrategy struct {
	config        CacheStrategyConfig
	loadMonitor   *LoadMonitor
	lmcacheClient *LMCacheClient
	tokenManager  *TokenManager // hot-content detection; may be nil
}

// StrategyDecision is the outcome of MakeDecision. Reason records the last
// condition that fired when several actions are selected.
type StrategyDecision struct {
	ShouldCompress bool
	ShouldPin      bool
	ShouldEvict    bool
	Reason         string
}
+
+func NewCacheStrategy(cfg CacheStrategyConfig, client *LMCacheClient,
tokenManager *TokenManager) *CacheStrategy {
+ return &CacheStrategy{
+ config: cfg,
+ loadMonitor: NewLoadMonitor(),
+ lmcacheClient: client,
+ tokenManager: tokenManager,
+ }
+}
+
+func (cs *CacheStrategy) RecordRequest() {
+ if cs == nil || cs.loadMonitor == nil {
+ return
+ }
+ cs.loadMonitor.RecordRequest()
+}
+
+func (cs *CacheStrategy) MakeDecision(_ context.Context, cacheStatus
*LookupResponse, model string, prompt string) *StrategyDecision {
+ if cs == nil {
+ return &StrategyDecision{}
+ }
+ decision := &StrategyDecision{}
+ metrics := cs.loadMonitor.Snapshot()
+
+ // load_threshold is validated as a ratio [0,1], so only ratio-based
metrics
+ // should participate in this decision.
+ if cs.config.EnableCompression && cs.config.LoadThreshold > 0 &&
+ metrics.CPUUsage >= cs.config.LoadThreshold {
+ decision.ShouldCompress = true
+ decision.Reason = "high_load"
+ }
+ if cs.config.EnableEviction && cs.config.MemoryThreshold > 0 &&
metrics.MemoryUsage >= cs.config.MemoryThreshold {
+ decision.ShouldEvict = true
+ decision.Reason = "memory_threshold"
+ }
+ if cs.config.EnablePinning && cs.tokenManager != nil &&
+ cs.tokenManager.IsHot(model, prompt,
cs.config.HotContentThreshold) {
+ decision.ShouldPin = true
+ decision.Reason = "hot_content"
+ }
+ return decision
+}
+
+func (cs *CacheStrategy) ExecuteDecision(ctx context.Context, decision
*StrategyDecision, tokens []int) error {
+ if cs == nil || decision == nil {
+ return nil
+ }
+ if decision.ShouldCompress {
+ _, err := cs.lmcacheClient.Compress(ctx, &CompressRequest{
+ Tokens: tokens,
+ InstanceID: cs.config.CompressInstanceID,
+ Location: cs.config.CompressLocation,
+ Method: cs.config.CompressMethod,
+ })
+ if err != nil {
+ return err
+ }
+ }
+ if decision.ShouldPin {
+ _, err := cs.lmcacheClient.Pin(ctx, &PinRequest{
+ Tokens: tokens,
+ InstanceID: cs.config.PinInstanceID,
+ Location: cs.config.PinLocation,
+ })
+ if err != nil {
+ return err
+ }
+ }
+ if decision.ShouldEvict {
+ _, err := cs.lmcacheClient.Evict(ctx, &EvictRequest{
+ Tokens: tokens,
+ InstanceID: cs.config.EvictInstanceID,
+ })
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/filter/ai/kvcache/strategy_test.go
b/pkg/filter/ai/kvcache/strategy_test.go
new file mode 100644
index 00000000..dd7ec390
--- /dev/null
+++ b/pkg/filter/ai/kvcache/strategy_test.go
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCacheStrategyMakeDecision(t *testing.T) {
+ tm := NewTokenManager("", resty.New(), TokenCacheConfig{}, nil,
time.Minute, 10)
+ tm.RecordHot("m1", "p1")
+ tm.RecordHot("m1", "p1")
+
+ cs := NewCacheStrategy(CacheStrategyConfig{
+ EnableCompression: true,
+ EnableEviction: true,
+ EnablePinning: true,
+ LoadThreshold: 0,
+ MemoryThreshold: 0.0000001,
+ HotContentThreshold: 2,
+ }, &LMCacheClient{}, tm)
+
+ decision := cs.MakeDecision(context.Background(), nil, "m1", "p1")
+ assert.True(t, decision.ShouldPin)
+ assert.True(t, decision.ShouldEvict)
+ assert.Equal(t, "hot_content", decision.Reason)
+}
+
+func TestCacheStrategyExecuteDecision(t *testing.T) {
+ var (
+ mu sync.Mutex
+ paths []string
+ )
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ mu.Lock()
+ paths = append(paths, r.URL.Path)
+ mu.Unlock()
+ return newHTTPResponse(http.StatusOK,
`{"event_id":"ok","num_tokens":2}`), nil
+ })
+
+ client := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 1,
+ BaseBackoff: time.Millisecond,
+ MaxBackoff: time.Millisecond,
+ }, nil)
+ cs := NewCacheStrategy(CacheStrategyConfig{
+ CompressInstanceID: "compress-i",
+ CompressLocation: "loc-a",
+ CompressMethod: "zstd",
+ PinInstanceID: "pin-i",
+ PinLocation: "loc-b",
+ EvictInstanceID: "evict-i",
+ }, client, nil)
+
+ decision := &StrategyDecision{
+ ShouldCompress: true,
+ ShouldPin: true,
+ ShouldEvict: true,
+ }
+ err := cs.ExecuteDecision(context.Background(), decision, []int{1, 2})
+ require.NoError(t, err)
+
+ mu.Lock()
+ defer mu.Unlock()
+ assert.Equal(t, []string{"/compress", "/pin", "/evict"}, paths)
+}
+
+func TestCacheStrategyExecuteDecisionStopsOnFirstError(t *testing.T) {
+ var (
+ mu sync.Mutex
+ paths []string
+ )
+ restyClient := newRestyClientWithRoundTripper(func(r *http.Request)
(*http.Response, error) {
+ mu.Lock()
+ paths = append(paths, r.URL.Path)
+ mu.Unlock()
+ if r.URL.Path == "/compress" {
+ return newHTTPResponse(http.StatusInternalServerError,
"compress failed"), nil
+ }
+ return newHTTPResponse(http.StatusOK,
`{"event_id":"ok","num_tokens":2}`), nil
+ })
+
+ client := NewLMCacheClient("http://lmcache.local", restyClient,
RetryConfig{
+ MaxAttempts: 1,
+ BaseBackoff: time.Millisecond,
+ MaxBackoff: time.Millisecond,
+ }, nil)
+ cs := NewCacheStrategy(CacheStrategyConfig{
+ CompressMethod: "zstd",
+ }, client, nil)
+
+ decision := &StrategyDecision{
+ ShouldCompress: true,
+ ShouldPin: true,
+ ShouldEvict: true,
+ }
+ err := cs.ExecuteDecision(context.Background(), decision, []int{1, 2})
+ assert.Error(t, err)
+
+ mu.Lock()
+ defer mu.Unlock()
+ assert.Equal(t, []string{"/compress"}, paths)
+}
diff --git a/pkg/filter/ai/kvcache/test_helpers_test.go
b/pkg/filter/ai/kvcache/test_helpers_test.go
new file mode 100644
index 00000000..77f2575f
--- /dev/null
+++ b/pkg/filter/ai/kvcache/test_helpers_test.go
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "io"
+ "net/http"
+ "strings"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+)
+
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return f(req)
+}
+
+func newRestyClientWithRoundTripper(fn roundTripFunc) *resty.Client {
+ httpClient := &http.Client{Transport: fn}
+ return resty.NewWithClient(httpClient)
+}
+
+func newHTTPResponse(statusCode int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: statusCode,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
diff --git a/pkg/filter/ai/kvcache/token_manager.go
b/pkg/filter/ai/kvcache/token_manager.go
new file mode 100644
index 00000000..0b0314cb
--- /dev/null
+++ b/pkg/filter/ai/kvcache/token_manager.go
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+import (
+ "github.com/go-resty/resty/v2"
+)
+
// hotMapMinSweepInterval caps how often the hot-prompt map is fully swept
// when the configured hot window is large, bounding per-call sweep overhead.
const hotMapMinSweepInterval = 30 * time.Second

// TokenManager tokenizes prompts via a remote /tokenize endpoint, caches the
// resulting token IDs, and tracks "hot" prompts (prompts seen repeatedly
// inside a sliding time window).
type TokenManager struct {
	httpClient     *resty.Client // client used for /tokenize calls
	endpoint       string        // base URL of the tokenizer service
	cache          sync.Map      // cacheKey -> *tokenCacheEntry
	config         TokenCacheConfig
	circuitBreaker *CircuitBreaker // optional; nil disables breaker protection

	// Atomic counters; cacheSize mirrors the live entry count because
	// sync.Map has no O(1) length.
	cacheSize int64
	hitCount  int64
	missCount int64

	// Hot-prompt tracking: recent sighting timestamps per cache key.
	hotWindow  time.Duration // sliding window; <= 0 disables tracking
	hotMax     int           // max timestamps kept per key (0 = unlimited)
	hotMaxKeys int           // max distinct keys in hotMap (0 = unlimited)
	hotMu      sync.Mutex    // guards hotMap and hotLastSweep
	hotMap     map[string][]time.Time

	hotSweepInterval time.Duration // minimum gap between full hotMap sweeps
	hotLastSweep     time.Time     // when the last full sweep ran
}

// TokenizeRequest is the JSON body sent to the tokenizer's /tokenize API
// when the caller did not supply a raw body.
type TokenizeRequest struct {
	Model  string `json:"model,omitempty"`
	Prompt string `json:"prompt"`
}

// TokenizeResponse is the JSON reply of the /tokenize API.
type TokenizeResponse struct {
	Count  int   `json:"count"`
	Tokens []int `json:"tokens"`
	MaxLen int   `json:"max_model_len"`
}

// tokenCacheEntry is a cached tokenization result with its expiry instant.
type tokenCacheEntry struct {
	tokens    []int
	expiresAt time.Time
}
+
+func NewTokenManager(endpoint string, httpClient *resty.Client, cfg
TokenCacheConfig, cb *CircuitBreaker, hotWindow time.Duration, hotMax int,
hotMaxKeys ...int) *TokenManager {
+ maxKeys := 0
+ if len(hotMaxKeys) > 0 {
+ maxKeys = hotMaxKeys[0]
+ }
+ return &TokenManager{
+ httpClient: httpClient,
+ endpoint: endpoint,
+ config: cfg,
+ circuitBreaker: cb,
+ hotWindow: hotWindow,
+ hotMax: hotMax,
+ hotMaxKeys: maxKeys,
+ hotMap: make(map[string][]time.Time),
+ hotSweepInterval: computeHotSweepInterval(hotWindow),
+ }
+}
+
+func (tm *TokenManager) GetTokens(ctx context.Context, model string, prompt
string, rawBody []byte) ([]int, error) {
+ cacheKey := tm.cacheKey(model, prompt)
+ if tm.config.Enabled {
+ if tokens, ok := tm.loadCache(cacheKey); ok {
+ atomic.AddInt64(&tm.hitCount, 1)
+ return tokens, nil
+ }
+ atomic.AddInt64(&tm.missCount, 1)
+ }
+
+ tokens, err := tm.tokenize(ctx, model, prompt, rawBody)
+ if err != nil {
+ return nil, err
+ }
+
+ if tm.config.Enabled {
+ tm.storeCache(cacheKey, tokens)
+ }
+ return tokens, nil
+}
+
+func (tm *TokenManager) GetCachedTokens(model string, prompt string) ([]int,
bool) {
+ if !tm.config.Enabled {
+ return nil, false
+ }
+ cacheKey := tm.cacheKey(model, prompt)
+ tokens, ok := tm.loadCache(cacheKey)
+ if ok {
+ atomic.AddInt64(&tm.hitCount, 1)
+ } else {
+ atomic.AddInt64(&tm.missCount, 1)
+ }
+ return tokens, ok
+}
+
// InvalidateCache removes the cached tokens for (model, prompt), if present.
func (tm *TokenManager) InvalidateCache(model string, prompt string) {
	cacheKey := tm.cacheKey(model, prompt)
	tm.deleteCache(cacheKey)
}
+
+func (tm *TokenManager) GetCacheStats() CacheStats {
+ size := atomic.LoadInt64(&tm.cacheSize)
+ hit := atomic.LoadInt64(&tm.hitCount)
+ miss := atomic.LoadInt64(&tm.missCount)
+ total := hit + miss
+ var hitRate float64
+ if total > 0 {
+ hitRate = float64(hit) / float64(total)
+ }
+ return CacheStats{
+ Size: int(size),
+ HitRate: hitRate,
+ HitCount: hit,
+ MissCount: miss,
+ }
+}
+
+func (tm *TokenManager) tokenize(ctx context.Context, model string, prompt
string, rawBody []byte) ([]int, error) {
+ var tokens []int
+ err := tm.execute(ctx, func() error {
+ body, err := tm.buildTokenizeBody(model, prompt, rawBody)
+ if err != nil {
+ return err
+ }
+ resp, err := tm.doTokenizeRequest(ctx, body)
+ if err != nil {
+ return err
+ }
+ tokens = resp.Tokens
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return tokens, nil
+}
+
+func (tm *TokenManager) buildTokenizeBody(model string, prompt string, rawBody
[]byte) (any, error) {
+ if len(rawBody) > 0 {
+ return rawBody, nil
+ }
+ return TokenizeRequest{Model: model, Prompt: prompt}, nil
+}
+
+func (tm *TokenManager) doTokenizeRequest(ctx context.Context, body any)
(*TokenizeResponse, error) {
+ tokenizeURL := strings.TrimRight(tm.endpoint, "/") + "/tokenize"
+ resp, err := tm.httpClient.R().
+ SetContext(ctx).
+ SetHeader("Content-Type", "application/json").
+ SetBody(body).
+ Post(tokenizeURL)
+ if err != nil {
+ return nil, fmt.Errorf("[kvcache] call tokenize: %w", err)
+ }
+ if resp.StatusCode() < 200 || resp.StatusCode() >= 300 {
+ return nil, fmt.Errorf("[kvcache] tokenize status %d: %s",
resp.StatusCode(), strings.TrimSpace(string(resp.Body())))
+ }
+ var tokenResp TokenizeResponse
+ if err := json.Unmarshal(resp.Body(), &tokenResp); err != nil {
+ return nil, fmt.Errorf("[kvcache] decode tokenize response:
%w", err)
+ }
+ return &tokenResp, nil
+}
+
+func (tm *TokenManager) RecordHot(model string, prompt string) {
+ if tm == nil || tm.hotWindow <= 0 || model == "" || prompt == "" {
+ return
+ }
+ now := time.Now()
+ key := tm.cacheKey(model, prompt)
+ tm.hotMu.Lock()
+ defer tm.hotMu.Unlock()
+ tm.maybeSweepHotMapLocked(now)
+ entries := tm.hotMap[key]
+ entries = append(entries, now)
+ entries = trimHotWindow(entries, now, tm.hotWindow)
+ if tm.hotMax > 0 && len(entries) > tm.hotMax {
+ entries = entries[len(entries)-tm.hotMax:]
+ }
+ if len(entries) == 0 {
+ delete(tm.hotMap, key)
+ return
+ }
+ tm.hotMap[key] = entries
+ tm.enforceHotMapLimitLocked(now)
+}
+
// IsHot reports whether (model, prompt) was recorded at least threshold
// times inside the hot window. It opportunistically trims the key's entries
// while holding the lock and deletes the key when everything has expired.
// Safe on a nil receiver; always false when tracking is disabled.
func (tm *TokenManager) IsHot(model string, prompt string, threshold int) bool {
	if tm == nil || tm.hotWindow <= 0 || threshold <= 0 || model == "" || prompt == "" {
		return false
	}
	now := time.Now()
	key := tm.cacheKey(model, prompt)
	tm.hotMu.Lock()
	defer tm.hotMu.Unlock()
	entries := tm.hotMap[key]
	if len(entries) == 0 {
		return false
	}
	// Drop sightings older than the window, then cap to the hotMax newest.
	entries = trimHotWindow(entries, now, tm.hotWindow)
	if tm.hotMax > 0 && len(entries) > tm.hotMax {
		entries = entries[len(entries)-tm.hotMax:]
	}
	if len(entries) == 0 {
		delete(tm.hotMap, key)
		return false
	}
	// Write back the compacted slice so future calls start from it.
	tm.hotMap[key] = entries
	return len(entries) >= threshold
}
+
+func computeHotSweepInterval(hotWindow time.Duration) time.Duration {
+ if hotWindow <= 0 {
+ return 0
+ }
+ interval := hotWindow / 2
+ if interval <= 0 {
+ interval = hotWindow
+ }
+ if hotWindow > hotMapMinSweepInterval && interval <
hotMapMinSweepInterval {
+ interval = hotMapMinSweepInterval
+ }
+ return interval
+}
+
// maybeSweepHotMapLocked trims every key in hotMap when at least
// hotSweepInterval has elapsed since the previous sweep, amortizing cleanup
// of stale keys across RecordHot calls. Caller must hold hotMu.
func (tm *TokenManager) maybeSweepHotMapLocked(now time.Time) {
	if tm.hotSweepInterval <= 0 {
		return
	}
	// Rate limit: skip when the previous sweep is still recent.
	if !tm.hotLastSweep.IsZero() && now.Sub(tm.hotLastSweep) < tm.hotSweepInterval {
		return
	}
	// Deleting during range is safe for Go maps.
	for key, entries := range tm.hotMap {
		entries = trimHotWindow(entries, now, tm.hotWindow)
		if tm.hotMax > 0 && len(entries) > tm.hotMax {
			entries = entries[len(entries)-tm.hotMax:]
		}
		if len(entries) == 0 {
			delete(tm.hotMap, key)
			continue
		}
		tm.hotMap[key] = entries
	}
	tm.hotLastSweep = now
}
+
// enforceHotMapLimitLocked bounds the number of distinct keys in hotMap to
// hotMaxKeys. It first reclaims expired keys via a full trim and, only if
// still over the cap, evicts the keys whose most recent sighting is oldest
// (approximate LRU). No-op when hotMaxKeys <= 0. Caller must hold hotMu.
func (tm *TokenManager) enforceHotMapLimitLocked(now time.Time) {
	if tm.hotMaxKeys <= 0 || len(tm.hotMap) <= tm.hotMaxKeys {
		return
	}

	// First, aggressively drop expired keys when we are already over the key cap.
	for key, entries := range tm.hotMap {
		entries = trimHotWindow(entries, now, tm.hotWindow)
		if tm.hotMax > 0 && len(entries) > tm.hotMax {
			entries = entries[len(entries)-tm.hotMax:]
		}
		if len(entries) == 0 {
			delete(tm.hotMap, key)
			continue
		}
		tm.hotMap[key] = entries
	}
	if len(tm.hotMap) <= tm.hotMaxKeys {
		return
	}

	// Still over the cap: rank surviving keys by their newest timestamp and
	// evict the least recently seen ones.
	type hotKeyLastSeen struct {
		key      string
		lastSeen time.Time
	}
	candidates := make([]hotKeyLastSeen, 0, len(tm.hotMap))
	for key, entries := range tm.hotMap {
		if len(entries) == 0 {
			delete(tm.hotMap, key)
			continue
		}
		candidates = append(candidates, hotKeyLastSeen{
			key:      key,
			lastSeen: entries[len(entries)-1],
		})
	}
	if len(candidates) <= tm.hotMaxKeys {
		return
	}
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].lastSeen.Before(candidates[j].lastSeen)
	})
	excess := len(candidates) - tm.hotMaxKeys
	for i := 0; i < excess; i++ {
		delete(tm.hotMap, candidates[i].key)
	}
}
+
+func trimHotWindow(entries []time.Time, now time.Time, window time.Duration)
[]time.Time {
+ if window <= 0 || len(entries) == 0 {
+ return entries
+ }
+ cutoff := now.Add(-window)
+ idx := 0
+ for idx < len(entries) && entries[idx].Before(cutoff) {
+ idx++
+ }
+ if idx == 0 {
+ return entries
+ }
+ return append([]time.Time(nil), entries[idx:]...)
+}
+
// execute runs operation through the circuit breaker when one is configured,
// otherwise directly. ctx is accepted but unused here — the HTTP request
// itself carries the context via SetContext in doTokenizeRequest.
func (tm *TokenManager) execute(ctx context.Context, operation func() error) error {
	if tm.circuitBreaker == nil {
		return operation()
	}
	return tm.circuitBreaker.Execute(operation)
}
+
+func (tm *TokenManager) cacheKey(model string, prompt string) string {
+ sum := sha256.Sum256([]byte(model + "\x00" + prompt))
+ return hex.EncodeToString(sum[:])
+}
+
// deleteCache removes key from the cache, decrementing the size counter only
// when an entry was actually present so cacheSize stays accurate under
// concurrent deletes.
func (tm *TokenManager) deleteCache(key string) {
	if _, loaded := tm.cache.LoadAndDelete(key); loaded {
		atomic.AddInt64(&tm.cacheSize, -1)
	}
}
+
+func (tm *TokenManager) loadCache(key string) ([]int, bool) {
+ entryAny, ok := tm.cache.Load(key)
+ if !ok {
+ return nil, false
+ }
+ entry, ok := entryAny.(*tokenCacheEntry)
+ if !ok {
+ tm.deleteCache(key)
+ return nil, false
+ }
+ if tm.config.TTL > 0 && time.Now().After(entry.expiresAt) {
+ tm.deleteCache(key)
+ return nil, false
+ }
+ return entry.tokens, true
+}
+
+func (tm *TokenManager) storeCache(key string, tokens []int) {
+ if tm.config.MaxSize > 0 {
+ for atomic.LoadInt64(&tm.cacheSize) >= int64(tm.config.MaxSize)
{
+ if !tm.evictOne() {
+ break
+ }
+ }
+ }
+ entry := &tokenCacheEntry{
+ tokens: tokens,
+ expiresAt: time.Now().Add(tm.config.TTL),
+ }
+ if _, loaded := tm.cache.LoadOrStore(key, entry); !loaded {
+ atomic.AddInt64(&tm.cacheSize, 1)
+ }
+}
+
// evictOne removes a single arbitrary entry from the cache (sync.Map's
// Range order is unspecified, so this is effectively random eviction) and
// reports whether anything was removed, i.e. whether the cache was non-empty.
func (tm *TokenManager) evictOne() bool {
	evicted := false
	tm.cache.Range(func(key, value any) bool {
		tm.deleteCache(key.(string))
		evicted = true
		// Stop after the first entry.
		return false
	})
	return evicted
}
diff --git a/pkg/filter/ai/kvcache/token_manager_test.go
b/pkg/filter/ai/kvcache/token_manager_test.go
new file mode 100644
index 00000000..1eabfc16
--- /dev/null
+++ b/pkg/filter/ai/kvcache/token_manager_test.go
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
// TestTokenManagerGetTokensUsesCache verifies that the second GetTokens call
// for the same (model, prompt) is served from the local cache — the stubbed
// transport is hit exactly once — and that the stats afterwards show one
// miss (first call) and one hit (second call).
func TestTokenManagerGetTokensUsesCache(t *testing.T) {
	var callCount int64
	client := newRestyClientWithRoundTripper(func(r *http.Request) (*http.Response, error) {
		assert.Equal(t, "/tokenize", r.URL.Path)
		atomic.AddInt64(&callCount, 1)
		return newHTTPResponse(http.StatusOK, `{"count":2,"tokens":[11,22],"max_model_len":4096}`), nil
	})

	tm := NewTokenManager("http://tokenizer.local", client, TokenCacheConfig{
		Enabled: true,
		MaxSize: 10,
		TTL:     time.Minute,
	}, nil, time.Minute, 10)

	tokens1, err := tm.GetTokens(context.Background(), "m1", "hello", nil)
	require.NoError(t, err)
	assert.Equal(t, []int{11, 22}, tokens1)

	// Second identical call must not reach the transport.
	tokens2, err := tm.GetTokens(context.Background(), "m1", "hello", nil)
	require.NoError(t, err)
	assert.Equal(t, []int{11, 22}, tokens2)
	assert.Equal(t, int64(1), atomic.LoadInt64(&callCount))

	stats := tm.GetCacheStats()
	assert.Equal(t, 1, stats.Size)
	assert.Equal(t, int64(1), stats.HitCount)
	assert.Equal(t, int64(1), stats.MissCount)
	assert.Equal(t, 0.5, stats.HitRate)
}
+
// TestTokenManagerRawBodyAndErrorHandling covers two GetTokens paths: a
// caller-supplied raw body is forwarded verbatim to the tokenizer, and a
// non-2xx tokenizer reply surfaces as an error without populating the cache.
func TestTokenManagerRawBodyAndErrorHandling(t *testing.T) {
	t.Run("raw body is passed through", func(t *testing.T) {
		var gotBody []byte
		client := newRestyClientWithRoundTripper(func(r *http.Request) (*http.Response, error) {
			var err error
			gotBody, err = io.ReadAll(r.Body)
			require.NoError(t, err)
			return newHTTPResponse(http.StatusOK, `{"count":1,"tokens":[7],"max_model_len":4096}`), nil
		})

		tm := NewTokenManager("http://tokenizer.local", client, TokenCacheConfig{Enabled: false}, nil, time.Minute, 10)
		rawBody := []byte(`{"model":"raw-model","prompt":"raw-prompt"}`)
		// model/prompt arguments are ignored when rawBody is provided.
		tokens, err := tm.GetTokens(context.Background(), "ignored", "ignored", rawBody)
		require.NoError(t, err)
		assert.Equal(t, []int{7}, tokens)
		assert.JSONEq(t, string(rawBody), string(gotBody))
	})

	t.Run("status code error does not populate cache", func(t *testing.T) {
		client := newRestyClientWithRoundTripper(func(r *http.Request) (*http.Response, error) {
			return newHTTPResponse(http.StatusInternalServerError, "boom"), nil
		})

		tm := NewTokenManager("http://tokenizer.local", client, TokenCacheConfig{
			Enabled: true,
			MaxSize: 10,
			TTL:     time.Minute,
		}, nil, time.Minute, 10)

		tokens, err := tm.GetTokens(context.Background(), "m1", "p1", nil)
		assert.Nil(t, tokens)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "tokenize status 500")
		// A failed tokenize call must leave the cache empty.
		assert.Equal(t, 0, tm.GetCacheStats().Size)
	})
}
diff --git a/pkg/filter/ai/kvcache/types.go b/pkg/filter/ai/kvcache/types.go
new file mode 100644
index 00000000..fcade842
--- /dev/null
+++ b/pkg/filter/ai/kvcache/types.go
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kvcache
+
// CacheStats is a snapshot of the token cache's size and effectiveness
// counters.
type CacheStats struct {
	Size      int     `json:"size"`     // current number of cached entries
	HitRate   float64 `json:"hit_rate"` // hits / (hits + misses); 0 before any lookup
	HitCount  int64   `json:"hit_count"`
	MissCount int64   `json:"miss_count"`
}

// LoadMetrics describes the instantaneous load of a backend instance.
type LoadMetrics struct {
	CPUUsage    float64 `json:"cpu_usage"`
	MemoryUsage float64 `json:"memory_usage"`
	RequestRate float64 `json:"request_rate"`
}

// LookupResponse is the reply of a KV-cache lookup call, mapping layout
// identifiers to their cache layout descriptions.
type LookupResponse struct {
	EventID    string                 `json:"event_id"`
	LayoutInfo map[string]CacheLayout `json:"layout_info"`
}

// CacheLayout describes where a cached segment lives and how many tokens it
// covers. NOTE(review): the JSON keys really are the literal strings "0" and
// "1" — presumably the remote service encodes each layout positionally;
// confirm against the LMCache API before changing.
type CacheLayout struct {
	Location   string `json:"0"`
	TokenCount int    `json:"1"`
}
diff --git a/pkg/filter/llm/proxy/filter.go b/pkg/filter/llm/proxy/filter.go
index e9194e4d..a06cbcba 100644
--- a/pkg/filter/llm/proxy/filter.go
+++ b/pkg/filter/llm/proxy/filter.go
@@ -45,7 +45,8 @@ const (
LLMUnhealthyKey = "LLMUnhealthy"
HealthyCheckTimeKey = "HealthyCheckTime"
// Context key to pass attempt data from proxy to downstream filters
- LLMUpstreamAttemptsKey = "llm_upstream_attempts"
+ LLMUpstreamAttemptsKey = "llm_upstream_attempts"
+ llmPreferredEndpointIDKey = "llm_preferred_endpoint_id"
)
// UpstreamAttempt holds details for a single request attempt to an endpoint.
@@ -98,6 +99,21 @@ type (
}
)
+func getPreferredEndpointID(hc *contexthttp.HttpContext) string {
+ if hc == nil || hc.Params == nil {
+ return ""
+ }
+ val, ok := hc.Params[llmPreferredEndpointIDKey]
+ if !ok {
+ return ""
+ }
+ endpointID, ok := val.(string)
+ if !ok {
+ return ""
+ }
+ return endpointID
+}
+
// Kind returns the unique name of this filter, used by the plugin registry
// to look it up.
func (p *Plugin) Kind() string {
	return Kind
}
@@ -265,6 +281,11 @@ func (s *Strategy) Execute(executor *RequestExecutor)
(*http.Response, error) {
// 1. Pick initial endpoint from the cluster based on load balancing.
endpoint := executor.clusterManager.PickEndpoint(executor.clusterName,
executor.hc)
+ if preferred := getPreferredEndpointID(executor.hc); preferred != "" {
+ if target :=
executor.clusterManager.GetEndpointByID(executor.clusterName, preferred);
target != nil {
+ endpoint = target
+ }
+ }
// 2. The main fallback loop. It continues as long as we have a valid
endpoint to try.
for endpoint != nil {
diff --git a/pkg/pluginregistry/registry.go b/pkg/pluginregistry/registry.go
index 2b3de88e..deaa932d 100644
--- a/pkg/pluginregistry/registry.go
+++ b/pkg/pluginregistry/registry.go
@@ -31,6 +31,7 @@ import (
_
"github.com/apache/dubbo-go-pixiu/pkg/cluster/retry/exponentialbackoff"
_ "github.com/apache/dubbo-go-pixiu/pkg/cluster/retry/noretry"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/accesslog"
+ _ "github.com/apache/dubbo-go-pixiu/pkg/filter/ai/kvcache"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/auth/jwt"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/auth/mcp"
_ "github.com/apache/dubbo-go-pixiu/pkg/filter/authority"
diff --git a/pkg/server/cluster_manager.go b/pkg/server/cluster_manager.go
index 69ea4ecb..db142eac 100644
--- a/pkg/server/cluster_manager.go
+++ b/pkg/server/cluster_manager.go
@@ -191,6 +191,23 @@ func (cm *ClusterManager) PickNextEndpoint(clusterName
string, curEndpointID str
return nil
}
+// GetEndpointByID returns the endpoint by ID in the given cluster.
+func (cm *ClusterManager) GetEndpointByID(clusterName string, endpointID
string) *model.Endpoint {
+ cm.rw.RLock()
+ defer cm.rw.RUnlock()
+
+ c := cm.getCluster(clusterName)
+ if c == nil {
+ return nil
+ }
+ for _, endpoint := range c.Endpoints {
+ if endpoint.ID == endpointID && !endpoint.UnHealthy {
+ return endpoint
+ }
+ }
+ return nil
+}
+
// getCluster returns the cluster configuration by its name.
func (cm *ClusterManager) getCluster(clusterName string) *model.ClusterConfig {
for _, c := range cm.store.Config {