Copilot commented on code in PR #6060: URL: https://github.com/apache/shenyu/pull/6060#discussion_r2280397566
########## shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-common/src/main/java/org/apache/shenyu/plugin/ai/common/spring/ai/factory/OpenAiModelFactory.java: ########## @@ -42,7 +42,7 @@ public ChatModel createAiModel(final AiCommonConfig config) { .build(); OpenAiChatOptions.Builder model = OpenAiChatOptions.builder().model(config.getModel()); Optional.ofNullable(config.getTemperature()).ifPresent(model::temperature); - Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxTokens); + Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxCompletionTokens); Review Comment: The change from `maxTokens` to `maxCompletionTokens` appears to be an API breaking change. This should be documented or verified that it's compatible with the Spring AI version being used, as this could cause runtime errors if the API doesn't support this method. ```suggestion Optional.ofNullable(config.getMaxTokens()).ifPresent(model::maxTokens); ``` ########## shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyExecutorService.java: ########## @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced.service; + +import org.apache.shenyu.plugin.ai.common.strategy.SimpleModelFallbackStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.retry.NonTransientAiException; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.retry.Retry; + +import java.time.Duration; +import java.util.Optional; + +/** + * AI proxy executor service. + */ +@Service +public class AiProxyExecutorService { + + private static final Logger LOG = LoggerFactory.getLogger(AiProxyExecutorService.class); + + /** + * Execute the AI call with retry and fallback. + * + * @param mainClient the main chat client + * @param fallbackClientOpt the optional fallback chat client + * @param requestBody the request body + * @return a Mono containing the ChatResponse + */ + public Mono<ChatResponse> execute(final ChatClient mainClient, final Optional<ChatClient> fallbackClientOpt, final String requestBody) { + final Mono<ChatResponse> mainCall = doChatCall(mainClient, requestBody); + + return mainCall Review Comment: [nitpick] The retry configuration is hardcoded (3 retries, 1 second backoff). Consider making these values configurable through the AiCommonConfig or application properties to allow for different retry strategies based on deployment environment or specific AI provider requirements. ########## shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java: ########## @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shenyu.plugin.ai.proxy.enhanced; + +import org.apache.shenyu.common.constant.Constants; +import org.apache.shenyu.common.dto.RuleData; +import org.apache.shenyu.common.dto.SelectorData; +import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle; +import org.apache.shenyu.common.enums.AiModelProviderEnum; +import org.apache.shenyu.common.enums.PluginEnum; +import org.apache.shenyu.common.utils.JsonUtils; +import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig; +import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry; +import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache; +import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService; +import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService; +import org.apache.shenyu.plugin.api.ShenyuPluginChain; +import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils; +import org.apache.shenyu.plugin.base.AbstractShenyuPlugin; +import org.apache.shenyu.plugin.base.utils.CacheKeyUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.MediaType; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import java.util.Optional; + +/** + * AI proxy plugin. + * This plugin is used to proxy requests to AI services. + */ +public class AiProxyPlugin extends AbstractShenyuPlugin { + + private static final Logger LOG = LoggerFactory.getLogger(AiProxyPlugin.class); + + private final AiModelFactoryRegistry aiModelFactoryRegistry; + + private final AiProxyConfigService aiProxyConfigService; + + private final AiProxyExecutorService aiProxyExecutorService; + + private final ChatClientCache chatClientCache; + + private final AiProxyPluginHandler aiProxyPluginHandler; + + public AiProxyPlugin(final AiModelFactoryRegistry aiModelFactoryRegistry, + final AiProxyConfigService aiProxyConfigService, + final AiProxyExecutorService aiProxyExecutorService, + final ChatClientCache chatClientCache, + final AiProxyPluginHandler aiProxyPluginHandler) { + this.aiModelFactoryRegistry = aiModelFactoryRegistry; + this.aiProxyConfigService = aiProxyConfigService; + this.aiProxyExecutorService = aiProxyExecutorService; + this.chatClientCache = chatClientCache; + this.aiProxyPluginHandler = aiProxyPluginHandler; + } + + @Override + protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain, + final SelectorData selector, final RuleData rule) { + final AiProxyHandle selectorHandle = aiProxyPluginHandler.getSelectorCachedHandle() + .obtainHandle(CacheKeyUtils.INST.getKey(selector.getId(), Constants.DEFAULT_RULE)); + + return DataBufferUtils.join(exchange.getRequest().getBody()) + .flatMap(dataBuffer -> { + final String requestBody = dataBuffer.toString(StandardCharsets.UTF_8); + DataBufferUtils.release(dataBuffer); + + final AiCommonConfig primaryConfig = aiProxyConfigService.resolvePrimaryConfig(selectorHandle); + + if (Boolean.TRUE.equals(primaryConfig.getStream())) { + return handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); + } else { + return handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle); + } + }); + } + + private Mono<Void> handleStreamRequest(final ServerWebExchange exchange, final SelectorData selector, + final String requestBody, final AiCommonConfig primaryConfig, + final AiProxyHandle selectorHandle) { + final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); + final Optional<ChatClient> fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody); + final ServerHttpResponse response = exchange.getResponse(); + response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM); + + final Flux<ChatResponse> chatResponseFlux = aiProxyExecutorService.executeStream(mainClient, fallbackClient, requestBody); + + final Flux<DataBuffer> sseFlux = chatResponseFlux.map(chatResponse -> { + final String json = JsonUtils.toJson(chatResponse); + final String sseData = "data: " + json + "\n\n"; + return response.bufferFactory().wrap(sseData.getBytes(StandardCharsets.UTF_8)); + }); + + return response.writeWith(sseFlux); + } + + private Mono<Void> handleNonStreamRequest(final ServerWebExchange exchange, final SelectorData selector, + final String requestBody, final AiCommonConfig primaryConfig, + final AiProxyHandle selectorHandle) { + final ChatClient mainClient = createMainChatClient(selector.getId(), primaryConfig); + final Optional<ChatClient> fallbackClient = resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody); + + return aiProxyExecutorService.execute(mainClient, fallbackClient, requestBody) + .flatMap(response -> { + byte[] jsonBytes = JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8); + return WebFluxResultUtils.result(exchange, jsonBytes); + }); + } + + private Optional<ChatClient> resolveFallbackClient(final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle, + final String selectorId, final String requestBody) { + return aiProxyConfigService + .resolveDynamicFallbackConfig(primaryConfig, requestBody) + .map(this::createDynamicFallbackClient) + .or(() -> aiProxyConfigService.resolveAdminFallbackConfig(primaryConfig, selectorHandle) + .map(adminFallbackConfig -> createAdminFallbackClient(selectorId, adminFallbackConfig))); + } + + private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config) { + return chatClientCache.computeIfAbsent(selectorId, + () -> { + LOG.info("Creating and caching main model for selector: {}", selectorId); + return createChatModel(config); + }); + } + + private ChatClient createAdminFallbackClient(final String selectorId, final AiCommonConfig fallbackConfig) { + final String fallbackCacheKey = selectorId + ":fallback"; + return chatClientCache.computeIfAbsent(fallbackCacheKey, () -> { + LOG.info("Creating and caching admin fallback model for selector: {}", selectorId); + return createChatModel(fallbackConfig); + }); + } + + private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) { + LOG.info("Creating non-cached dynamic fallback model."); + return ChatClient.builder(createChatModel(fallbackConfig)).build(); + } + + private ChatModel createChatModel(final AiCommonConfig config) { + LOG.info("Creating chat model with config: {}", config); + final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider()); + if (Objects.isNull(provider)) { + throw new IllegalArgumentException("Invalid AI model provider in config: " + config.getProvider()); Review Comment: The error message 'Invalid AI model provider in config: ' + config.getProvider() could expose sensitive configuration details in logs. Consider using a more generic error message or ensuring this information is properly sanitized before logging. ```suggestion throw new IllegalArgumentException("Invalid AI model provider in config."); ``` ########## shenyu-common/src/main/java/org/apache/shenyu/common/dto/convert/rule/AiProxyHandle.java: ########## @@ -218,24 +242,212 @@ public boolean equals(final Object o) { && Objects.equals(model, that.model) && Objects.equals(temperature, that.temperature) && Objects.equals(maxTokens, that.maxTokens) - && Objects.equals(stream, that.stream); + && Objects.equals(stream, that.stream) + && Objects.equals(fallbackConfig, that.fallbackConfig); } - + @Override public int hashCode() { - return Objects.hash(provider, baseUrl, apiKey, model, temperature, maxTokens, stream); + return Objects.hash(provider, baseUrl, apiKey, model, temperature, maxTokens, stream, fallbackConfig); } - + @Override public String toString() { return "AiProxyHandle{" + "provider='" + provider + '\'' + ", baseUrl='" + baseUrl + '\'' - + ", apiKey='" + apiKey + '\'' + + ", apiKey='" + maskApiKey(apiKey) + '\'' + ", model='" + model + '\'' + ", temperature=" + temperature + ", maxTokens=" + maxTokens + ", stream=" + stream + + ", fallbackConfig=" + fallbackConfig + '}'; } -} + + public static String maskApiKey(final String apiKey) { + if (Objects.isNull(apiKey) || apiKey.length() <= 7) { + return apiKey; + } + if (apiKey.isEmpty()) { + return apiKey; + } Review Comment: There's a potential null pointer exception. The method checks for null and length <= 7 on line 269, but then checks `apiKey.isEmpty()` on line 272 without ensuring apiKey is not null first. The isEmpty() check should be moved before the length check or combined with the null check. ```suggestion if (Objects.isNull(apiKey) || apiKey.isEmpty()) { return apiKey; } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: notifications-unsubscr...@shenyu.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org