Copilot commented on code in PR #6060:
URL: https://github.com/apache/shenyu/pull/6060#discussion_r2280397566


##########
shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/spring/ai/factory/OpenAiModelFactory.java:
##########
@@ -42,7 +42,7 @@ public ChatModel createAiModel(final AiCommonConfig config) {
                 .build();
         OpenAiChatOptions.Builder model = OpenAiChatOptions.builder().model(config.getModel());
         Optional.ofNullable(config.getTemperature()).ifPresent(model::temperature);
-        Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxTokens);
+        Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxCompletionTokens);

Review Comment:
   The change from `maxTokens` to `maxCompletionTokens` appears to be a breaking API change. It should either be documented or verified against the Spring AI version in use, since calling a builder method that version does not provide could cause runtime errors.
   ```suggestion
           Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxTokens);
   ```
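   Because `model::maxCompletionTokens` is a method reference, a method missing from the compile-time Spring AI version would normally fail the build; the runtime risk comes from a different Spring AI jar on the classpath than the one compiled against. A minimal sketch of a reflective check that a unit test could run against the jar actually on the classpath (passing `OpenAiChatOptions.Builder.class`). `SetterCompatibilityCheck` and `LegacyOptionsBuilder` are hypothetical names; the stub exists only so the example runs on its own:

```java
import java.util.Arrays;

public final class SetterCompatibilityCheck {

    // Hypothetical stand-in for an options builder that only knows maxTokens.
    // In a real test you would pass OpenAiChatOptions.Builder.class instead.
    public static final class LegacyOptionsBuilder {
        private Integer maxTokens;

        public LegacyOptionsBuilder maxTokens(final Integer value) {
            this.maxTokens = value;
            return this;
        }
    }

    private SetterCompatibilityCheck() {
    }

    // True if `type` exposes a public single-argument method `name` taking
    // Integer or int (inherited public methods are included by getMethods()).
    public static boolean hasIntegerSetter(final Class<?> type, final String name) {
        return Arrays.stream(type.getMethods())
                .anyMatch(m -> m.getName().equals(name)
                        && m.getParameterCount() == 1
                        && (m.getParameterTypes()[0] == Integer.class
                            || m.getParameterTypes()[0] == int.class));
    }

    public static void main(final String[] args) {
        System.out.println(hasIntegerSetter(LegacyOptionsBuilder.class, "maxTokens"));
        System.out.println(hasIntegerSetter(LegacyOptionsBuilder.class, "maxCompletionTokens"));
    }
}
```

   A failing check in CI would flag the mismatch before deployment rather than on the first request.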



##########
shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java:
##########
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced.service;
+
+import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.retry.NonTransientAiException;
+import org.springframework.stereotype.Service;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
+import reactor.util.retry.Retry;
+
+import java.time.Duration;
+import java.util.Optional;
+
+/**
+ * AI proxy executor service.
+ */
+@Service
+public class AiProxyExecutorService {
+
+    private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class);
+
+    /**
+     * Execute the AI call with retry and fallback.
+     *
+     * @param mainClient      the main chat client
+     * @param fallbackClientOpt the optional fallback chat client
+     * @param requestBody     the request body
+     * @return a Mono containing the ChatResponse
+     */
+    public Mono<ChatResponse> execute(final ChatClient mainClient, final Optional<ChatClient> fallbackClientOpt, final String requestBody) {
+        final Mono<ChatResponse> mainCall = doChatCall(mainClient, requestBody);
+
+        return mainCall

Review Comment:
   [nitpick] The retry configuration is hardcoded (3 retries, 1 second 
backoff). Consider making these values configurable through the AiCommonConfig 
or application properties to allow for different retry strategies based on 
deployment environment or specific AI provider requirements.
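   A minimal sketch of how the two values could be externalized, assuming a flat property map as the source. The property keys and the `RetrySettings` class are hypothetical (they are not existing ShenYu or `AiCommonConfig` settings); the defaults mirror the hardcoded 3 retries and 1 second backoff described above:

```java
import java.time.Duration;
import java.util.Map;

public final class RetrySettings {

    // Hypothetical property keys; ShenYu does not necessarily define these.
    public static final String MAX_RETRIES_KEY = "ai.proxy.retry.max-retries";
    public static final String MIN_BACKOFF_MILLIS_KEY = "ai.proxy.retry.min-backoff-millis";

    // Defaults mirror the values described as hardcoded today.
    public static final long DEFAULT_MAX_RETRIES = 3L;
    public static final Duration DEFAULT_MIN_BACKOFF = Duration.ofSeconds(1);

    private final long maxRetries;

    private final Duration minBackoff;

    public RetrySettings(final long maxRetries, final Duration minBackoff) {
        if (maxRetries < 0) {
            throw new IllegalArgumentException("maxRetries must not be negative");
        }
        if (minBackoff == null || minBackoff.isNegative()) {
            throw new IllegalArgumentException("minBackoff must be non-null and non-negative");
        }
        this.maxRetries = maxRetries;
        this.minBackoff = minBackoff;
    }

    // Builds settings from a flat property map, falling back to the defaults for missing keys.
    public static RetrySettings fromProperties(final Map<String, String> props) {
        final long retries = props.containsKey(MAX_RETRIES_KEY)
                ? Long.parseLong(props.get(MAX_RETRIES_KEY).trim())
                : DEFAULT_MAX_RETRIES;
        final Duration backoff = props.containsKey(MIN_BACKOFF_MILLIS_KEY)
                ? Duration.ofMillis(Long.parseLong(props.get(MIN_BACKOFF_MILLIS_KEY).trim()))
                : DEFAULT_MIN_BACKOFF;
        return new RetrySettings(retries, backoff);
    }

    public long maxRetries() {
        return maxRetries;
    }

    public Duration minBackoff() {
        return minBackoff;
    }
}
```

   In the executor, the values could then be handed to Reactor, e.g. `mainCall.retryWhen(Retry.backoff(settings.maxRetries(), settings.minBackoff()))`; whether the settings belong per selector (in the handle) or globally (in application properties) is a design choice left open here.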



##########
shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java:
##########
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shenyu.plugin.ai.proxy.enhanced;
+
+import org.apache.shenyu.common.constant.Constants;
+import org.apache.shenyu.common.dto.RuleData;
+import org.apache.shenyu.common.dto.SelectorData;
+import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
+import org.apache.shenyu.common.enums.AiModelProviderEnum;
+import org.apache.shenyu.common.enums.PluginEnum;
+import org.apache.shenyu.common.utils.JsonUtils;
+import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
+import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
+import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
+import org.apache.shenyu.plugin.api.ShenyuPluginChain;
+import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
+import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
+import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.core.io.buffer.DataBuffer;
+import org.springframework.core.io.buffer.DataBufferUtils;
+import org.springframework.http.MediaType;
+import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * AI proxy plugin.
+ * This plugin is used to proxy requests to AI services.
+ */
+public class AiProxyPlugin extends AbstractShenyuPlugin {
+
+    private static final Logger LOG = LoggerFactory.getLogger(AiProxyPlugin.class);
+
+    private final AiModelFactoryRegistry aiModelFactoryRegistry;
+
+    private final AiProxyConfigService aiProxyConfigService;
+
+    private final AiProxyExecutorService aiProxyExecutorService;
+
+    private final ChatClientCache chatClientCache;
+
+    private final AiProxyPluginHandler aiProxyPluginHandler;
+
+    public AiProxyPlugin(final AiModelFactoryRegistry aiModelFactoryRegistry,
+                         final AiProxyConfigService aiProxyConfigService,
+                         final AiProxyExecutorService aiProxyExecutorService,
+                         final ChatClientCache chatClientCache,
+                         final AiProxyPluginHandler aiProxyPluginHandler) {
+        this.aiModelFactoryRegistry = aiModelFactoryRegistry;
+        this.aiProxyConfigService = aiProxyConfigService;
+        this.aiProxyExecutorService = aiProxyExecutorService;
+        this.chatClientCache = chatClientCache;
+        this.aiProxyPluginHandler = aiProxyPluginHandler;
+    }
+
+    @Override
+    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain,
+                                   final SelectorData selector, final RuleData rule) {
+        final AiProxyHandle selectorHandle = aiProxyPluginHandler.getSelectorCachedHandle()
+                .obtainHandle(CacheKeyUtils.INST.getKey(selector.getId(), Constants.DEFAULT_RULE));
+
+        return DataBufferUtils.join(exchange.getRequest().getBody())
+                .flatMap(dataBuffer -> {
+                    final String requestBody = dataBuffer.toString(StandardCharsets.UTF_8);
+                    DataBufferUtils.release(dataBuffer);
+
+                    final AiCommonConfig primaryConfig = aiProxyConfigService.resolvePrimaryConfig(selectorHandle);
+
+                    if (Boolean.TRUE.equals(primaryConfig.getStream())) {
+                        return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
+                    } else {
+                        return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle);
+                    }
+                });
+    }
+
+    private Mono<Void> handleStreamRequest(final ServerWebExchange exchange, final SelectorData selector,
+                                           final String requestBody, final AiCommonConfig primaryConfig,
+                                           final AiProxyHandle selectorHandle) {
+        final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
+        final Optional<ChatClient> fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody);
+        final ServerHttpResponse response = exchange.getResponse();
+        response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
+
+        final Flux<ChatResponse> chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient, requestBody);
+
+        final Flux<DataBuffer> sseFlux = chatResponseFlux.map(chatResponse -> {
+            final String json = JsonUtils.toJson(chatResponse);
+            final String sseData = "data: " + json + "\n\n";
+            return response.bufferFactory().wrap(sseData.getBytes(StandardCharsets.UTF_8));
+        });
+
+        return response.writeWith(sseFlux);
+    }
+
+    private Mono<Void> handleNonStreamRequest(final ServerWebExchange exchange, final SelectorData selector,
+                                              final String requestBody, final AiCommonConfig primaryConfig,
+                                              final AiProxyHandle selectorHandle) {
+        final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig);
+        final Optional<ChatClient> fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody);
+
+        return aiProxyExecutorService.execute(mainClient, fallbackClient, requestBody)
+                .flatMap(response -> {
+                    byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8);
+                    return WebFluxResultUtils.result(exchange, jsonBytes);
+                });
+    }
+
+    private Optional<ChatClient> resolveFallbackClient(final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle,
+                                                       final String selectorId, final String requestBody) {
+        return aiProxyConfigService
+                .resolveDynamicFallbackConfig(primaryConfig, requestBody)
+                .map(this::createDynamicFallbackClient)
+                .or(() -> aiProxyConfigService.resolveAdminFallbackConfig(primaryConfig, selectorHandle)
+                        .map(adminFallbackConfig -> createAdminFallbackClient(selectorId, adminFallbackConfig)));
+    }
+
+    private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) {
+        return chatClientCache.computeIfAbsent(selectorId,
+                () -> {
+                    LOG.info("Creating and caching main model for selector: {}", selectorId);
+                    return createChatModel(config);
+                });
+    }
+
+    private ChatClient createAdminFallbackClient(final String selectorId, final AiCommonConfig fallbackConfig) {
+        final String fallbackCacheKey = selectorId + ":fallback";
+        return chatClientCache.computeIfAbsent(fallbackCacheKey, () -> {
+            LOG.info("Creating and caching admin fallback model for selector: {}", selectorId);
+            return createChatModel(fallbackConfig);
+        });
+    }
+
+    private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) {
+        LOG.info("Creating non-cached dynamic fallback model.");
+        return ChatClient.builder(createChatModel(fallbackConfig)).build();
+    }
+
+    private ChatModel createChatModel(final AiCommonConfig config) {
+        LOG.info("Creating chat model with config: {}", config);
+        final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider());
+        if (Objects.isNull(provider)) {
+            throw new IllegalArgumentException("Invalid AI model provider in config: " + config.getProvider());

Review Comment:
   The error message 'Invalid AI model provider in config: ' + 
config.getProvider() could expose sensitive configuration details in logs. 
Consider using a more generic error message or ensuring this information is 
properly sanitized before logging.
   ```suggestion
               throw new IllegalArgumentException("Invalid AI model provider in config.");
   ```
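   A minimal sketch of the suggested behavior as a lookup helper that fails with a fixed message, so the configured value never ends up in the exception text. `ProviderLookup` and its `Provider` enum are hypothetical stand-ins for `AiModelProviderEnum`, whose constants and `getByName()` behavior are not shown in the diff:

```java
import java.util.Optional;

public final class ProviderLookup {

    // Hypothetical stand-in for AiModelProviderEnum; real constants may differ.
    public enum Provider {
        OPEN_AI("openAI"),
        DEEP_SEEK("deepSeek");

        private final String providerName;

        Provider(final String providerName) {
            this.providerName = providerName;
        }

        // Case-insensitive lookup; a null name simply yields Optional.empty().
        public static Optional<Provider> byName(final String name) {
            for (Provider provider : values()) {
                if (provider.providerName.equalsIgnoreCase(name)) {
                    return Optional.of(provider);
                }
            }
            return Optional.empty();
        }
    }

    private ProviderLookup() {
    }

    // Throws with a generic message that does not echo the configured value.
    public static Provider requireProvider(final String configuredName) {
        return Provider.byName(configuredName)
                .orElseThrow(() -> new IllegalArgumentException("Invalid AI model provider in config."));
    }
}
```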



##########
shenyu-common/src/main/java/org/apache/shenyu/common/dto/convert/rule/AiProxyHandle.java:
##########
@@ -218,24 +242,212 @@ public boolean equals(final Object o) {
                 && Objects.equals(model, that.model)
                 && Objects.equals(temperature, that.temperature)
                 && Objects.equals(maxTokens, that.maxTokens)
-                && Objects.equals(stream, that.stream);
+                && Objects.equals(stream, that.stream)
+                && Objects.equals(fallbackConfig, that.fallbackConfig);
     }
-    
+
     @Override
     public int hashCode() {
-        return Objects.hash(provider, baseUrl, apiKey, model, temperature, maxTokens, stream);
+        return Objects.hash(provider, baseUrl, apiKey, model, temperature, maxTokens, stream, fallbackConfig);
     }
-    
+
     @Override
     public String toString() {
         return "AiProxyHandle{"
                 + "provider='" + provider + '\''
                 + ", baseUrl='" + baseUrl + '\''
-                + ", apiKey='" + apiKey + '\''
+                + ", apiKey='" + maskApiKey(apiKey) + '\''
                 + ", model='" + model + '\''
                 + ", temperature=" + temperature
                 + ", maxTokens=" + maxTokens
                 + ", stream=" + stream
+                + ", fallbackConfig=" + fallbackConfig
                 + '}';
     }
-}
+
+    public static String maskApiKey(final String apiKey) {
+        if (Objects.isNull(apiKey) || apiKey.length() <= 7) {
+            return apiKey;
+        }
+        if (apiKey.isEmpty()) {
+            return apiKey;
+        }

Review Comment:
   There's a potential null pointer exception. The method checks for null and 
length <= 7 on line 269, but then checks `apiKey.isEmpty()` on line 272 without 
ensuring apiKey is not null first. The isEmpty() check should be moved before 
the length check or combined with the null check.
   ```suggestion
           if (Objects.isNull(apiKey) || apiKey.isEmpty()) {
               return apiKey;
           }
   ```
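   As far as the hunk shows, the first `if` already returns for a null `apiKey` (`||` short-circuits before `length()` runs), and an empty string has length 0, so the `isEmpty()` branch is unreachable rather than an NPE risk; it is still dead code worth folding into one guard. A minimal sketch of a single-guard version, assuming a masking format (the real method body is cut off in the diff) that keeps only the last four characters; `ApiKeyMasking` is a hypothetical holder class:

```java
import java.util.Objects;

public final class ApiKeyMasking {

    private ApiKeyMasking() {
    }

    public static String maskApiKey(final String apiKey) {
        // Single guard: `||` short-circuits, so length() is never called on null,
        // and an empty key (length 0) is already covered by the length check.
        // Mirrors the original threshold: keys of 7 characters or fewer are returned unchanged.
        if (Objects.isNull(apiKey) || apiKey.length() <= 7) {
            return apiKey;
        }
        // Hypothetical masking format: hide everything except the last four characters.
        return "****" + apiKey.substring(apiKey.length() - 4);
    }
}
```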



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
