This is an automated email from the ASF dual-hosted git repository.

liuhongyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git

The following commit(s) were added to refs/heads/master by this push:
     new 6898a9282e fix AiTokenLimiterPlugin appendResponse (#6027)
6898a9282e is described below

commit 6898a9282e41dc69a21d0b889b7e0f5c872d6cc5
Author: HY-love-sleep <73268470+hy-love-sl...@users.noreply.github.com>
AuthorDate: Tue May 20 09:43:38 2025 +0800

    fix AiTokenLimiterPlugin appendResponse (#6027)

    * fix: one-time decompression after the flow is finished

    * chore: code style

    * feat: save memory by streaming cross-block decompression

    * chore: code style

    * chore: del useless imports
---
 .../ai/token/limiter/AiTokenLimiterPlugin.java     | 156 ++++++++++++---------
 1 file changed, 92 insertions(+), 64 deletions(-)

diff --git a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
index 8781885c7d..c7f66d2ae9 100644
--- a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
+++ b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-token-limiter/src/main/java/org/apache/shenyu/plugin/ai/token/limiter/AiTokenLimiterPlugin.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import org.springframework.core.io.buffer.DataBuffer;
 import org.springframework.data.redis.core.ReactiveRedisTemplate;
 import org.springframework.http.HttpCookie;
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.HttpStatusCode;
 import org.springframework.http.server.reactive.ServerHttpRequest;
@@ -49,7 +50,6 @@ import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 import reactor.util.annotation.NonNull;
 
-import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -61,7 +61,8 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
-import java.util.zip.GZIPInputStream;
+import java.util.zip.DataFormatException;
+import java.util.zip.Inflater;
 
 /**
  * Shenyu ai token limiter plugin.
  */
@@ -152,32 +153,22 @@ public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
         String key;
         // Determine the key based on the configured key resolver type
         AiTokenLimiterEnum tokenLimiterEnum = AiTokenLimiterEnum.getByName(tokenLimitType);
-
-        switch (tokenLimiterEnum) {
-            case IP:
-                key = Objects.requireNonNull(request.getRemoteAddress()).getHostString();
-                break;
-            case URI:
-                key = request.getURI().getPath();
-                break;
-            case HEADER:
-                key = request.getHeaders().getFirst(keyName);
-                break;
-            case PARAMETER:
-                key = request.getQueryParams().getFirst(keyName);
-                break;
-            case COOKIE:
+
+        key = switch (tokenLimiterEnum) {
+            case IP -> Objects.requireNonNull(request.getRemoteAddress()).getHostString();
+            case URI -> request.getURI().getPath();
+            case HEADER -> request.getHeaders().getFirst(keyName);
+            case PARAMETER -> request.getQueryParams().getFirst(keyName);
+            case COOKIE -> {
                 HttpCookie cookie = request.getCookies().getFirst(keyName);
-                key = Objects.nonNull(cookie) ? cookie.getValue() : "";
-                break;
-            case CONTEXT_PATH:
-            default:
-                key = exchange.getAttribute(Constants.CONTEXT_PATH);
-        }
+                yield Objects.nonNull(cookie) ? cookie.getValue() : "";
+            }
+            default -> exchange.getAttribute(Constants.CONTEXT_PATH);
+        };
         return StringUtils.isBlank(key) ? "" : key;
     }
-
+
     private void recordTokensUsage(final ReactiveRedisTemplate reactiveRedisTemplate, final String cacheKey, final Long tokens, final Long windowSeconds) {
         // Record token usage with expiration
         reactiveRedisTemplate.opsForValue()
@@ -216,55 +207,92 @@ public class AiTokenLimiterPlugin extends AbstractShenyuPlugin {
         public Mono<Void> writeWith(@NonNull final Publisher<? extends DataBuffer> body) {
             return super.writeWith(appendResponse(body));
         }
-
+
         @NonNull
         private Flux<? extends DataBuffer> appendResponse(final Publisher<? extends DataBuffer> body) {
-            BodyWriter writer = new BodyWriter();
-            return Flux.from(body).doOnNext(buffer -> {
-                try (DataBuffer.ByteBufferIterator bufferIterator = buffer.readableByteBuffers()) {
-                    bufferIterator.forEachRemaining(byteBuffer -> {
-                        // Handle gzip encoded response
-                        if (serverHttpResponse.getHeaders().containsKey(Constants.CONTENT_ENCODING)
-                                && serverHttpResponse.getHeaders().getFirst(Constants.CONTENT_ENCODING).contains(Constants.HTTP_ACCEPT_ENCODING_GZIP)) {
-                            try {
-                                ByteBuffer readOnlyBuffer = byteBuffer.asReadOnlyBuffer();
-                                byte[] compressed = new byte[readOnlyBuffer.remaining()];
-                                readOnlyBuffer.get(compressed);
-
-                                // Decompress gzipped content
-                                byte[] decompressed = decompressGzip(compressed);
-                                writer.write(ByteBuffer.wrap(decompressed));
-
-                            } catch (IOException e) {
-                                LOG.error("Failed to decompress gzipped response", e);
-                                writer.write(byteBuffer.asReadOnlyBuffer());
-                            }
-                        } else {
-                            writer.write(byteBuffer.asReadOnlyBuffer());
-                        }
-                    });
-                }
-            }).doFinally(signal -> {
-                String responseBody = writer.output();
-                AiModel aiModel = exchange.getAttribute(Constants.AI_MODEL);
-                long tokens = Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
-                tokensRecorder.accept(tokens);
-            });
-        }
-
-        private byte[] decompressGzip(final byte[] compressed) throws IOException {
-            try (GZIPInputStream gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(compressed));
-                     ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) {
-                byte[] buffer = new byte[1024];
-                int len;
-                while ((len = gzipInputStream.read(buffer)) > 0) {
-                    outputStream.write(buffer, 0, len);
-                }
-                return outputStream.toByteArray();
-            }
-        }
+            HttpHeaders headers = serverHttpResponse.getHeaders();
+            boolean isGzip = headers.containsKey(Constants.CONTENT_ENCODING)
+                    && headers.getFirst(Constants.CONTENT_ENCODING)
+                    .contains(Constants.HTTP_ACCEPT_ENCODING_GZIP);
+
+            final Inflater inflater = isGzip ? new Inflater(true) : null;
+            final byte[] outBuf = new byte[4096];
+            final AtomicBoolean headerSkipped = new AtomicBoolean(!isGzip);
+            final BodyWriter writer = new BodyWriter();
+
+            return Flux.<DataBuffer>from(body)
+                    .doOnNext(buffer -> {
+                        try (DataBuffer.ByteBufferIterator it = buffer.readableByteBuffers()) {
+                            it.forEachRemaining(bb -> {
+                                ByteBuffer ro = bb.asReadOnlyBuffer();
+                                byte[] inBytes = new byte[ro.remaining()];
+                                ro.get(inBytes);
+
+                                if (isGzip) {
+                                    int offset = 0;
+                                    int len = inBytes.length;
+                                    if (!headerSkipped.get()) {
+                                        offset = skipGzipHeader(inBytes);
+                                        headerSkipped.set(true);
+                                    }
+                                    inflater.setInput(inBytes, offset, len - offset);
+                                    try {
+                                        int cnt;
+                                        while ((cnt = inflater.inflate(outBuf)) > 0) {
+                                            writer.write(ByteBuffer.wrap(outBuf, 0, cnt));
+                                        }
+                                    } catch (DataFormatException ex) {
+                                        LOG.error("inflater decompression failed", ex);
+                                    }
+                                } else {
+                                    writer.write(ro);
+                                }
+                            });
+                        } catch (Exception e) {
+                            LOG.error("read dataBuffer error", e);
+                        }
+                    })
+                    .doFinally(signal -> {
+                        // release inflater
+                        if (Objects.nonNull(inflater)) {
+                            inflater.end();
+                        }
+                        String responseBody = writer.output();
+                        AiModel aiModel = exchange.getAttribute(Constants.AI_MODEL);
+                        long tokens = Objects.requireNonNull(aiModel).getCompletionTokens(responseBody);
+                        tokensRecorder.accept(tokens);
+                    });
+        }
+
+        private int skipGzipHeader(final byte[] b) {
+            int pos = 10;
+            int flg = b[3] & 0xFF;
+
+            if ((flg & 0x04) != 0) {
+                int xlen = (b[pos] & 0xFF) | ((b[pos + 1] & 0xFF) << 8);
+                pos += 2 + xlen;
+            }
+
+            if ((flg & 0x08) != 0) {
+                while (b[pos] != 0) {
+                    pos++;
+                }
+                pos++;
+            }
+
+            if ((flg & 0x10) != 0) {
+                while (b[pos] != 0) {
+                    pos++;
+                }
+                pos++;
+            }
+
+            if ((flg & 0x02) != 0) {
+                pos += 2;
+            }
+            return pos;
+        }
-
+
     }
 
     static class BodyWriter {
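
The heart of this patch is replacing one-shot GZIPInputStream decompression, which had to buffer the whole compressed response, with a single java.util.zip.Inflater in raw-deflate mode that is fed each DataBuffer chunk as it arrives. The same technique can be tried outside the plugin. The sketch below assumes a plain 10-byte gzip header with no optional fields; the class and method names are illustrative, not ShenYu API.

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.zip.DataFormatException;
import java.util.zip.GZIPOutputStream;
import java.util.zip.Inflater;

public final class StreamingGunzip {

    private final Inflater inflater = new Inflater(true); // true = raw deflate, no gzip/zlib framing
    private final ByteArrayOutputStream out = new ByteArrayOutputStream();
    private final byte[] buf = new byte[4096];
    private boolean headerSkipped;

    // Feed one arriving chunk; inflate whatever is decodable so far.
    void feed(final byte[] chunk) throws DataFormatException {
        int offset = 0;
        if (!headerSkipped) {
            offset = 10; // assumes no FEXTRA/FNAME/FCOMMENT/FHCRC; the patch's skipGzipHeader handles those
            headerSkipped = true;
        }
        inflater.setInput(chunk, offset, chunk.length - offset);
        int n;
        while ((n = inflater.inflate(buf)) > 0) {
            out.write(buf, 0, n);
        }
    }

    String finish() {
        inflater.end(); // free native zlib state, as the patch does in doFinally
        return out.toString(StandardCharsets.UTF_8);
    }

    public static void main(final String[] args) throws Exception {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (GZIPOutputStream gz = new GZIPOutputStream(bos)) {
            gz.write("hello streaming gzip".getBytes(StandardCharsets.UTF_8));
        }
        byte[] gzipped = bos.toByteArray();

        // Split the compressed stream mid-body to simulate two network buffers.
        StreamingGunzip gunzip = new StreamingGunzip();
        int mid = gzipped.length / 2;
        gunzip.feed(Arrays.copyOf(gzipped, mid));
        gunzip.feed(Arrays.copyOfRange(gzipped, mid, gzipped.length));
        System.out.println(gunzip.finish()); // prints: hello streaming gzip
    }
}

Because the Inflater keeps partial bit state between setInput calls, a deflate block split across two network buffers still decompresses correctly; that is what lets the plugin stop accumulating the full compressed body in memory before decoding.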
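The new skipGzipHeader walks the RFC 1952 layout: a fixed 10-byte prefix (magic, compression method, FLG, MTIME, XFL, OS) followed by optional FEXTRA, FNAME, FCOMMENT, and FHCRC fields selected by bits in FLG. Below is a commented sketch of the same walk with a hand-built header as a self-check; GzipHeader and headerLength are illustrative names, not plugin code.

import java.nio.charset.StandardCharsets;

public final class GzipHeader {

    static final int FHCRC = 0x02;    // 2-byte CRC16 of the header follows the optional fields
    static final int FEXTRA = 0x04;   // 2-byte little-endian length, then "extra" payload
    static final int FNAME = 0x08;    // zero-terminated original file name
    static final int FCOMMENT = 0x10; // zero-terminated comment

    static int headerLength(final byte[] b) {
        int flg = b[3] & 0xFF; // byte 3 of the fixed header holds the flags
        int pos = 10;          // ID1 ID2 CM FLG MTIME(4) XFL OS
        if ((flg & FEXTRA) != 0) {
            int xlen = (b[pos] & 0xFF) | ((b[pos + 1] & 0xFF) << 8);
            pos += 2 + xlen;
        }
        if ((flg & FNAME) != 0) {
            while (b[pos++] != 0) { } // skip name including its NUL terminator
        }
        if ((flg & FCOMMENT) != 0) {
            while (b[pos++] != 0) { } // skip comment including its NUL terminator
        }
        if ((flg & FHCRC) != 0) {
            pos += 2;
        }
        return pos;
    }

    public static void main(final String[] args) {
        // Hand-built header: fixed 10 bytes with FNAME set, then "a.txt" and a NUL.
        byte[] name = "a.txt".getBytes(StandardCharsets.US_ASCII);
        byte[] h = new byte[10 + name.length + 1];
        h[0] = 0x1f;
        h[1] = (byte) 0x8b;    // gzip magic
        h[2] = 8;              // CM = deflate
        h[3] = FNAME;          // FLG
        System.arraycopy(name, 0, h, 10, name.length);
        System.out.println(headerLength(h)); // 16 = 10 fixed + 5 name bytes + 1 NUL
    }
}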