This is an automated email from the ASF dual-hosted git repository.

liuhongyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git


The following commit(s) were added to refs/heads/master by this push:
     new 6898a9282e fix AiTokenLimiterPlugin appendResponse (#6027)
6898a9282e is described below

commit 6898a9282e41dc69a21d0b889b7e0f5c872d6cc5
Author: HY-love-sleep <73268470+hy-love-sl...@users.noreply.github.com>
AuthorDate: Tue May 20 09:43:38 2025 +0800

    fix AiTokenLimiterPlugin appendResponse (#6027)
    
    * fix: decompress the body once, after the response stream completes
    
    * chore: code style
    
    * feat: reduce memory use by decompressing incrementally across chunks
    
    * chore: code style
    
    * chore: remove unused imports
---
 .../ai/token/limiter/AiTokenLimiterPlugin.java     | 156 ++++++++++++---------
 1 file changed, 92 insertions(+), 64 deletions(-)

diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
index 8781885c7d..c7f66d2ae9 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import org.springframework.core.io.buffer.DataBuffer;
 import org.springframework.data.redis.core.ReactiveRedisTemplate;
 import org.springframework.http.HttpCookie;
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.HttpStatusCode;
 import org.springframework.http.server.reactive.ServerHttpRequest;
@@ -49,7 +50,6 @@ import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 import reactor.util.annotation.NonNull;
 
-import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -61,7 +61,8 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
-import java.util.zip.GZIPInputStream;
+import java.util.zip.DataFormatException;
+import java.util.zip.Inflater;
 
 /**
  * Shenyu ai token limiter plugin.
@@ -152,32 +153,22 @@ public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
         String key;
         // Determine the key based on the configured key resolver type
         AiTokenLimiterEnum tokenLimiterEnum = AiTokenLimiterEnum.getByName(tokenLimitType);
-        
-        switch (tokenLimiterEnum) {
-            case IP:
-                key = Objects.requireNonNull(request.getRemoteAddress()).getHostString();
-                break;
-            case URI:
-                key = request.getURI().getPath();
-                break;
-            case HEADER:
-                key = request.getHeaders().getFirst(keyName);
-                break;
-            case PARAMETER:
-                key = request.getQueryParams().getFirst(keyName);
-                break;
-            case COOKIE:
+
+        key = switch (tokenLimiterEnum) {
+            case IP -> Objects.requireNonNull(request.getRemoteAddress()).getHostString();
+            case URI -> request.getURI().getPath();
+            case HEADER -> request.getHeaders().getFirst(keyName);
+            case PARAMETER -> request.getQueryParams().getFirst(keyName);
+            case COOKIE -> {
                 HttpCookie cookie = request.getCookies().getFirst(keyName);
-                key = Objects.nonNull(cookie) ? cookie.getValue() : "";
-                break;
-            case CONTEXT_PATH:
-            default:
-                key = exchange.getAttribute(Constants.CONTEXT_PATH);
-        }
+                yield Objects.nonNull(cookie) ? cookie.getValue() : "";
+            }
+            default -> exchange.getAttribute(Constants.CONTEXT_PATH);
+        };
         
         return StringUtils.isBlank(key) ? "" : key;
     }
-    
+
     private void recordTokensUsage(final ReactiveRedisTemplate reactiveRedisTemplate, final String cacheKey, final Long tokens, final Long windowSeconds) {
         // Record token usage with expiration
         reactiveRedisTemplate.opsForValue()
@@ -216,55 +207,92 @@ public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
         public Mono<Void> writeWith(@NonNull final Publisher<? extends DataBuffer> body) {
             return super.writeWith(appendResponse(body));
         }
-        
+
         @NonNull
         private Flux<? extends DataBuffer> appendResponse(final Publisher<? extends DataBuffer> body) {
-            
             BodyWriter writer = new BodyWriter();
-            return Flux.from(body).doOnNext(buffer -> {
-                try (DataBuffer.ByteBufferIterator bufferIterator = buffer.readableByteBuffers()) {
-                    bufferIterator.forEachRemaining(byteBuffer -> {
-                        // Handle gzip encoded response
-                        if (serverHttpResponse.getHeaders().containsKey(Constants.CONTENT_ENCODING)
-                                && serverHttpResponse.getHeaders().getFirst(Constants.CONTENT_ENCODING).contains(Constants.HTTP_ACCEPT_ENCODING_GZIP)) {
-                            try {
-                                ByteBuffer readOnlyBuffer = byteBuffer.asReadOnlyBuffer();
-                                byte[] compressed = new byte[readOnlyBuffer.remaining()];
-                                readOnlyBuffer.get(compressed);
-                                
-                                // Decompress gzipped content
-                                byte[] decompressed = decompressGzip(compressed);
-                                writer.write(ByteBuffer.wrap(decompressed));
-                                
-                            } catch (IOException e) {
-                                LOG.error("Failed to decompress gzipped response", e);
-                                writer.write(byteBuffer.asReadOnlyBuffer());
-                            }
-                        } else {
-                            writer.write(byteBuffer.asReadOnlyBuffer());
+            HttpHeaders headers = serverHttpResponse.getHeaders();
+            boolean isGzip = headers.containsKey(Constants.CONTENT_ENCODING)
+                    && headers.getFirst(Constants.CONTENT_ENCODING)
+                    .contains(Constants.HTTP_ACCEPT_ENCODING_GZIP);
+
+            final Inflater inflater = isGzip ? new Inflater(true) : null;
+            final byte[] outBuf = new byte[4096];
+            final AtomicBoolean headerSkipped = new AtomicBoolean(!isGzip);
+
+            return Flux.<DataBuffer>from(body)
+                    .doOnNext(buffer -> {
+                        try (DataBuffer.ByteBufferIterator it = buffer.readableByteBuffers()) {
+                            it.forEachRemaining(bb -> {
+                                ByteBuffer ro = bb.asReadOnlyBuffer();
+                                byte[] inBytes = new byte[ro.remaining()];
+                                ro.get(inBytes);
+
+                                if (isGzip) {
+                                    int offset = 0;
+                                    int len = inBytes.length;
+                                    if (!headerSkipped.get()) {
+                                        offset = skipGzipHeader(inBytes);
+                                        headerSkipped.set(true);
+                                    }
+                                    inflater.setInput(inBytes, offset, len - offset);
+                                    try {
+                                        int cnt;
+                                        while ((cnt = inflater.inflate(outBuf)) > 0) {
+                                            writer.write(ByteBuffer.wrap(outBuf, 0, cnt));
+                                        }
+                                    } catch (DataFormatException ex) {
+                                        LOG.error("inflater decompression failed", ex);
+                                    }
+                                } else {
+                                    writer.write(ro);
+                                }
+                            });
+                        } catch (Exception e) {
+                            LOG.error("read dataBuffer error", e);
+                        }
+                    })
+                    .doFinally(signal -> {
+                        // release inflater
+                        if (Objects.nonNull(inflater)) {
+                            inflater.end();
                         }
+                        String responseBody = writer.output();
+                        AiModel aiModel = exchange.getAttribute(Constants.AI_MODEL);
+                        long tokens = Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
+                        tokensRecorder.accept(tokens);
                     });
-                }
-            }).doFinally(signal -> {
-                String responseBody = writer.output();
-                AiModel aiModel = exchange.getAttribute(Constants.AI_MODEL);
-                long tokens = Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
-                tokensRecorder.accept(tokens);
-            });
         }
-        
-        private byte[] decompressGzip(final byte[] compressed) throws IOException {
-            try (GZIPInputStream gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(compressed));
-                 ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
-                byte[] buffer = new byte[1024];
-                int len;
-                while ((len = gzipInputStream.read(buffer)) > 0) {
-                    outputStream.write(buffer, 0, len);
+
+        private int skipGzipHeader(final byte[] b) {
+            int pos = 10;
+            int flg = b[3] & 0xFF;
+
+            if ((flg & 0x04) != 0) {
+                int xlen = (b[pos] & 0xFF) | ((b[pos + 1] & 0xFF) << 8);
+                pos += 2 + xlen;
+            }
+
+            if ((flg & 0x08) != 0) {
+                while (b[pos] != 0) {
+                    pos++;
+                }
+                pos++;
+            }
+
+            if ((flg & 0x10) != 0) {
+                while (b[pos] != 0) {
+                    pos++;
                 }
-                return outputStream.toByteArray();
+                pos++;
+            }
+
+            if ((flg & 0x02) != 0) {
+                pos += 2;
             }
+            return pos;
         }
-        
+
     }
     
     static class BodyWriter {

Reply via email to