This is an automated email from the ASF dual-hosted git repository.
xiaoyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shenyu.git
The following commit(s) were added to refs/heads/master by this push:
new bbdcbfebb1 feat: Introduce request body size limit, enhance AI proxy
plugin's ChatClient caching with config hash, and refine prompt extraction from
the request body. (#6248)
bbdcbfebb1 is described below
commit bbdcbfebb13202273462a8fbab2502551770358c
Author: aias00 <[email protected]>
AuthorDate: Mon Dec 8 11:07:20 2025 +0800
feat: Introduce request body size limit, enhance AI proxy plugin's
ChatClient caching with config hash, and refine prompt extraction from the
request body. (#6248)
* feat: Introduce request body size limit, enhance AI proxy plugin's
ChatClient caching with config hash, and refine prompt extraction from the
request body.
* Update
shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
Co-authored-by: Copilot <[email protected]>
* Update
shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigService.java
Co-authored-by: Copilot <[email protected]>
* feat: Implement cache eviction strategy in ChatClientCache and enforce
maximum request body size in AiProxyPlugin
* feat: integrate mvnd installation and usage in CI workflows
---------
Co-authored-by: Copilot <[email protected]>
---
.../plugin/ai/proxy/enhanced/AiProxyPlugin.java | 132 +++++++++++++--------
.../ai/proxy/enhanced/cache/ChatClientCache.java | 93 ++++++++++++---
.../enhanced/service/AiProxyConfigService.java | 33 +++++-
.../ai/proxy/enhanced/AiProxyPluginTest.java | 21 ++--
.../enhanced/service/AiProxyConfigServiceTest.java | 75 ++++++++++++
5 files changed, 275 insertions(+), 79 deletions(-)
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
index 916c24dda5..687c51233f 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPlugin.java
@@ -60,6 +60,11 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
private static final Logger LOG =
LoggerFactory.getLogger(AiProxyPlugin.class);
+ /**
+ * Maximum request body size: 5MB.
+ */
+ private static final long MAX_REQUEST_BODY_SIZE_BYTES = 5 * 1024 * 1024L;
+
private final AiModelFactoryRegistry aiModelFactoryRegistry;
private final AiProxyConfigService aiProxyConfigService;
@@ -72,10 +77,10 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
public AiProxyPlugin(
final AiModelFactoryRegistry aiModelFactoryRegistry,
- final AiProxyConfigService aiProxyConfigService,
- final AiProxyExecutorService aiProxyExecutorService,
- final ChatClientCache chatClientCache,
- final AiProxyPluginHandler aiProxyPluginHandler) {
+ final AiProxyConfigService aiProxyConfigService,
+ final AiProxyExecutorService aiProxyExecutorService,
+ final ChatClientCache chatClientCache,
+ final AiProxyPluginHandler aiProxyPluginHandler) {
this.aiModelFactoryRegistry = aiModelFactoryRegistry;
this.aiProxyConfigService = aiProxyConfigService;
this.aiProxyExecutorService = aiProxyExecutorService;
@@ -89,25 +94,34 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
final ShenyuPluginChain chain,
final SelectorData selector,
final RuleData rule) {
- final AiProxyHandle selectorHandle =
- aiProxyPluginHandler
- .getSelectorCachedHandle()
- .obtainHandle(
- CacheKeyUtils.INST.getKey(
- selector.getId(),
Constants.DEFAULT_RULE));
+ final AiProxyHandle selectorHandle = aiProxyPluginHandler
+ .getSelectorCachedHandle()
+ .obtainHandle(
+ CacheKeyUtils.INST.getKey(
+ selector.getId(), Constants.DEFAULT_RULE));
return DataBufferUtils.join(exchange.getRequest().getBody())
.flatMap(dataBuffer -> {
+ // Validate actual body size after reading. NOTE: join() above has already
buffered the whole body, so this does not cap memory; prefer join(body, maxByteCount)
+ final int actualSize = dataBuffer.readableByteCount();
+ if (actualSize > MAX_REQUEST_BODY_SIZE_BYTES) {
+ DataBufferUtils.release(dataBuffer);
+ LOG.warn("[AiProxy] Request body size {} exceeds
maximum allowed size {}",
+ actualSize, MAX_REQUEST_BODY_SIZE_BYTES);
+
exchange.getResponse().setStatusCode(HttpStatus.PAYLOAD_TOO_LARGE);
+ return exchange.getResponse().setComplete();
+ }
+
final String requestBody =
dataBuffer.toString(StandardCharsets.UTF_8);
DataBufferUtils.release(dataBuffer);
- final AiCommonConfig primaryConfig =
-
aiProxyConfigService.resolvePrimaryConfig(selectorHandle);
+ final AiCommonConfig primaryConfig =
aiProxyConfigService.resolvePrimaryConfig(selectorHandle);
// override apiKey by proxy key if provided in header
final HttpHeaders headers =
exchange.getRequest().getHeaders();
final String proxyApiKey =
headers.getFirst(Constants.X_API_KEY);
- final boolean proxyEnabled =
Objects.nonNull(selectorHandle) &&
"true".equalsIgnoreCase(String.valueOf(selectorHandle.getProxyEnabled()));
+ final boolean proxyEnabled =
Objects.nonNull(selectorHandle)
+ &&
"true".equalsIgnoreCase(String.valueOf(selectorHandle.getProxyEnabled()));
if (proxyEnabled) {
// if proxy mode enabled but header missing -> 401
@@ -116,12 +130,13 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
return exchange.getResponse().setComplete();
}
- final String realKey =
-
AiProxyApiKeyCache.getInstance().getRealApiKey(selector.getId(), proxyApiKey);
+ final String realKey =
AiProxyApiKeyCache.getInstance().getRealApiKey(selector.getId(),
+ proxyApiKey);
if (Objects.nonNull(realKey)) {
primaryConfig.setApiKey(realKey);
if (LOG.isDebugEnabled()) {
- LOG.debug("[AiProxy] proxy key hit,
selectorId={}, key={}... (masked)", selector.getId(), proxyApiKey.substring(0,
Math.min(6, proxyApiKey.length())));
+ LOG.debug("[AiProxy] proxy key hit,
selectorId={}, key={}... (masked)",
+ selector.getId(),
proxyApiKey.substring(0, Math.min(6, proxyApiKey.length())));
}
LOG.info("[AiProxy] proxy key hit, cacheSize={}",
AiProxyApiKeyCache.getInstance().size());
} else {
@@ -147,22 +162,22 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
final ChatClient mainClient = createMainChatClient(selector.getId(),
primaryConfig);
- final Optional<ChatClient> fallbackClient =
- resolveFallbackClient(primaryConfig, selectorHandle,
selector.getId(), requestBody);
+ final String prompt = aiProxyConfigService.extractPrompt(requestBody);
+ final Optional<ChatClient> fallbackClient =
resolveFallbackClient(primaryConfig, selectorHandle,
+ selector.getId(), requestBody);
final ServerHttpResponse response = exchange.getResponse();
response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
- final Flux<ChatResponse> chatResponseFlux =
- aiProxyExecutorService.executeStream(mainClient,
fallbackClient, requestBody);
+ final Flux<ChatResponse> chatResponseFlux =
aiProxyExecutorService.executeStream(mainClient, fallbackClient,
+ prompt);
- final Flux<DataBuffer> sseFlux =
- chatResponseFlux.map(
- chatResponse -> {
- final String json = JsonUtils.toJson(chatResponse);
- final String sseData = "data: " + json + "\n\n";
- return response.bufferFactory()
-
.wrap(sseData.getBytes(StandardCharsets.UTF_8));
- });
+ final Flux<DataBuffer> sseFlux = chatResponseFlux.map(
+ chatResponse -> {
+ final String json = JsonUtils.toJson(chatResponse);
+ final String sseData = "data: " + json + "\n\n";
+ return response.bufferFactory()
+ .wrap(sseData.getBytes(StandardCharsets.UTF_8));
+ });
return response.writeWith(sseFlux);
}
@@ -174,15 +189,15 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
final AiCommonConfig primaryConfig,
final AiProxyHandle selectorHandle) {
final ChatClient mainClient = createMainChatClient(selector.getId(),
primaryConfig);
- final Optional<ChatClient> fallbackClient =
- resolveFallbackClient(primaryConfig, selectorHandle,
selector.getId(), requestBody);
+ final String prompt = aiProxyConfigService.extractPrompt(requestBody);
+ final Optional<ChatClient> fallbackClient =
resolveFallbackClient(primaryConfig, selectorHandle,
+ selector.getId(), requestBody);
return aiProxyExecutorService
- .execute(mainClient, fallbackClient, requestBody)
+ .execute(mainClient, fallbackClient, prompt)
.flatMap(
response -> {
- byte[] jsonBytes =
-
JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8);
+ byte[] jsonBytes =
JsonUtils.toJson(response).getBytes(StandardCharsets.UTF_8);
return WebFluxResultUtils.result(exchange,
jsonBytes);
});
}
@@ -202,36 +217,57 @@ public class AiProxyPlugin extends AbstractShenyuPlugin {
return createDynamicFallbackClient(cfg);
})
.or(
- () ->
- aiProxyConfigService
-
.resolveAdminFallbackConfig(primaryConfig, selectorHandle)
- .map(adminFallbackConfig -> {
- LOG.info("[AiProxy] use admin
fallback");
- if (LOG.isDebugEnabled()) {
- LOG.debug("[AiProxy] admin
fallback config: {}", adminFallbackConfig);
- }
- return
createAdminFallbackClient(selectorId, adminFallbackConfig);
- }));
+ () -> aiProxyConfigService
+ .resolveAdminFallbackConfig(primaryConfig,
selectorHandle)
+ .map(adminFallbackConfig -> {
+ LOG.info("[AiProxy] use admin fallback");
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("[AiProxy] admin fallback
config: {}", adminFallbackConfig);
+ }
+ return
createAdminFallbackClient(selectorId, adminFallbackConfig);
+ }));
+ }
+
+ /**
+ * Generate a cache key hash from the config fields, excluding apiKey.
+ * NOTE: configs that differ only in apiKey (or whose hashes collide) share the client built from the first of them.
+ *
+ * @param config the config
+ * @return cache key hash
+ */
+ private int generateConfigCacheKey(final AiCommonConfig config) {
+ return Objects.hash(
+ config.getProvider(),
+ config.getBaseUrl(),
+ config.getModel(),
+ config.getTemperature(),
+ config.getMaxTokens(),
+ config.getStream()
+ // Explicitly exclude apiKey to avoid cache misses when apiKey
changes
+ );
}
private ChatClient createMainChatClient(final String selectorId, final
AiCommonConfig config) {
+ final int configHash = generateConfigCacheKey(config);
+ final String cacheKey = selectorId + "|main_" + configHash;
return chatClientCache.computeIfAbsent(
- selectorId,
+ cacheKey,
() -> {
- LOG.info("Creating and caching main model for selector:
{}", selectorId);
+ LOG.info("Creating and caching main model for selector:
{}, key: {}", selectorId, cacheKey);
return createChatModel(config);
});
}
private ChatClient createAdminFallbackClient(
final String selectorId, final AiCommonConfig fallbackConfig) {
- final String fallbackCacheKey = selectorId + "|adminFallback";
+ final int configHash = generateConfigCacheKey(fallbackConfig);
+ final String fallbackCacheKey = selectorId + "|adminFallback_" +
configHash;
return chatClientCache.computeIfAbsent(
fallbackCacheKey,
() -> {
LOG.info(
- "Creating and caching admin fallback model for
selector: {}",
- selectorId);
+ "Creating and caching admin fallback model for
selector: {}, key: {}",
+ selectorId, fallbackCacheKey);
return createChatModel(fallbackConfig);
});
}
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
index 1d7f3b6338..473ace98fa 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/cache/ChatClientCache.java
@@ -23,37 +23,104 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
/**
* This is ChatClient cache.
*/
public final class ChatClientCache {
-
+
private static final Logger LOG =
LoggerFactory.getLogger(ChatClientCache.class);
-
+
+ private static final int MAX_CACHE_SIZE = getCacheSize();
+
private final Map<String, ChatClient> chatClientMap = new
ConcurrentHashMap<>();
-
+
+ private final AtomicBoolean evictionInProgress = new AtomicBoolean(false);
+
/**
* Instantiates a new Chat client cache.
*/
public ChatClientCache() {
}
+ private static int getCacheSize() {
+ String value =
System.getProperty("shenyu.plugin.ai.proxy.enhanced.cache.maxSize",
+
System.getenv("SHENYU_PLUGIN_AI_PROXY_ENHANCED_CACHE_MAXSIZE"));
+ if (Objects.nonNull(value)) {
+ try {
+ return Integer.parseInt(value);
+ } catch (NumberFormatException e) {
+ LoggerFactory.getLogger(ChatClientCache.class)
+ .warn("[ChatClientCache] Invalid cache size '{}',
using default 500.", value);
+ }
+ }
+ return 500;
+ }
+
/**
* Gets client or compute if absent.
*
- * @param key the key
+ * @param key the key
* @param chatModelSupplier the chat model supplier
* @return the chat client
*/
public ChatClient computeIfAbsent(final String key, final
Supplier<ChatModel> chatModelSupplier) {
+ // Check size before computing, but use synchronized block to prevent
race conditions
+ final int currentSize = chatClientMap.size();
+ if (currentSize > MAX_CACHE_SIZE) {
+ // Use atomic flag to ensure only one thread performs eviction
+ if (evictionInProgress.compareAndSet(false, true)) {
+ try {
+ synchronized (chatClientMap) {
+ // Double-check after acquiring lock
+ if (chatClientMap.size() > MAX_CACHE_SIZE) {
+ evictOldestEntries();
+ }
+ }
+ } finally {
+ evictionInProgress.set(false);
+ }
+ }
+ }
return chatClientMap.computeIfAbsent(key, k ->
ChatClient.builder(chatModelSupplier.get()).build());
}
-
+
+ /**
+ * Evict entries when cache size exceeds limit. Entries are chosen in map
+ * iteration order (not by age); removes max(10, 25% of current size) entries.
+ */
+ private void evictOldestEntries() {
+ final int currentSize = chatClientMap.size();
+ if (currentSize <= MAX_CACHE_SIZE) {
+ return;
+ }
+
+ // Evict 25% of entries, but at least 10 entries
+ final int evictCount = Math.max(10, currentSize / 4);
+ LOG.warn("[ChatClientCache] Cache size {} exceeded limit {}, evicting
{} oldest entries",
+ currentSize, MAX_CACHE_SIZE, evictCount);
+
+ // Since ConcurrentHashMap doesn't maintain insertion order,
+ // we evict entries based on iteration order (which is somewhat
arbitrary but better than clearing all)
+ int removed = 0;
+ for (final String key : chatClientMap.keySet()) {
+ if (removed >= evictCount) {
+ break;
+ }
+ chatClientMap.remove(key);
+ removed++;
+ }
+
+ LOG.info("[ChatClientCache] Evicted {} entries, cache size now: {}",
removed, chatClientMap.size());
+ }
+
/**
- * Removes all cached clients associated with a selector ID (by prefix
matching "selectorId|").
+ * Removes all cached clients associated with a selector ID (by prefix
matching
+ * "selectorId|").
*
* @param selectorId the selector id
*/
@@ -65,7 +132,7 @@ public final class ChatClientCache {
chatClientMap.keySet().removeIf(k -> k.equals(selectorId) ||
k.startsWith(prefix));
LOG.info("[ChatClientCache] invalidate selectorId={} (by prefix)",
selectorId);
}
-
+
/**
* Clear all cached clients.
*/
@@ -73,16 +140,4 @@ public final class ChatClientCache {
chatClientMap.clear();
LOG.info("[ChatClientCache] cleared all cached clients");
}
-
- /**
- * Puts a client directly into the cache.
- *
- * <p>NOTE: This method is intended for testing purposes to inject mock
clients.
- *
- * @param key the key
- * @param client the client to cache
- */
- public void put(final String key, final ChatClient client) {
- chatClientMap.put(key, client);
- }
}
\ No newline at end of file
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigService.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigService.java
index 90485939d0..ce89eb498d 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigService.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/main/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigService.java
@@ -38,7 +38,8 @@ public class AiProxyConfigService {
private static final String FALLBACK_CONFIG = "fallbackConfig";
/**
- * Resolves the primary configuration for the AI call by merging global
and selector-level settings.
+ * Resolves the primary configuration for the AI call by merging global and
+ * selector-level settings.
*
* @param handle the selector handle
* @return the primary AiCommonConfig
@@ -71,7 +72,8 @@ public class AiProxyConfigService {
* @param requestBody the request body
* @return an Optional containing the final fallback AiCommonConfig
*/
- public Optional<AiCommonConfig> resolveDynamicFallbackConfig(final
AiCommonConfig primaryConfig, final String requestBody) {
+ public Optional<AiCommonConfig> resolveDynamicFallbackConfig(final
AiCommonConfig primaryConfig,
+ final String requestBody) {
return extractDynamicFallbackConfig(requestBody)
.map(dynamicConfig -> {
LOG.info("Resolved dynamic fallback config: {}",
dynamicConfig);
@@ -86,7 +88,8 @@ public class AiProxyConfigService {
* @param handle the selector handle
* @return an Optional containing the final fallback AiCommonConfig
*/
- public Optional<AiCommonConfig> resolveAdminFallbackConfig(final
AiCommonConfig primaryConfig, final AiProxyHandle handle) {
+ public Optional<AiCommonConfig> resolveAdminFallbackConfig(final
AiCommonConfig primaryConfig,
+ final AiProxyHandle handle) {
return Optional.ofNullable(handle)
.map(AiProxyHandle::getFallbackConfig)
.map(fallback -> {
@@ -121,4 +124,28 @@ public class AiProxyConfigService {
}
return Optional.empty();
}
+
+ /**
+ * Extract the prompt from the request body, preferring the "prompt" field over "content".
+ *
+ * @param requestBody the request body
+ * @return the extracted prompt, or the original body if it is null, empty, not valid JSON, or has neither field
+ */
+ public String extractPrompt(final String requestBody) {
+ if (Objects.isNull(requestBody) || requestBody.isEmpty()) {
+ return requestBody;
+ }
+ try {
+ JsonNode jsonNode = JsonUtils.toJsonNode(requestBody);
+ if (jsonNode.has("prompt")) {
+ return jsonNode.get("prompt").asText();
+ }
+ if (jsonNode.has("content")) {
+ return jsonNode.get("content").asText();
+ }
+ } catch (Exception e) {
+ // ignore
+ }
+ return requestBody;
+ }
}
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
index edd38cec29..30e84abc08 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/AiProxyPluginTest.java
@@ -155,7 +155,8 @@ public class AiProxyPluginTest {
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig,
REQUEST_BODY)).thenReturn(fallbackConfig);
when(configService.resolveAdminFallbackConfig(primaryConfig,
handle)).thenReturn(fallbackConfig);
- when(executorService.execute(any(), any(),
anyString())).thenReturn(Mono.just(chatResponse));
+ when(configService.extractPrompt(anyString())).thenAnswer(invocation
-> invocation.getArgument(0));
+ when(executorService.execute(any(), any(),
any())).thenReturn(Mono.just(chatResponse));
}
@Test
@@ -172,7 +173,7 @@ public class AiProxyPluginTest {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig,
REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig,
handle);
- verify(executorService).execute(any(), any(), anyString());
+ verify(executorService).execute(any(), any(), any());
}
@Test
@@ -189,7 +190,7 @@ public class AiProxyPluginTest {
StepVerifier.create(plugin.doExecute(exchange,
mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class),
any(Optional.class), anyString());
+ verify(executorService).execute(any(ChatClient.class),
any(Optional.class), any());
}
@Test
@@ -261,7 +262,7 @@ public class AiProxyPluginTest {
StepVerifier.create(plugin.doExecute(exchange,
mock(ShenyuPluginChain.class), selector, rule))
.verifyComplete();
- verify(executorService).execute(any(ChatClient.class),
any(Optional.class), anyString());
+ verify(executorService).execute(any(ChatClient.class),
any(Optional.class), any());
}
@Test
@@ -280,7 +281,8 @@ public class AiProxyPluginTest {
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig,
REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig,
handle)).thenReturn(Optional.of(fallbackConfig));
- when(executorService.execute(any(), any(),
anyString())).thenReturn(Mono.just(chatResponse));
+ when(configService.extractPrompt(anyString())).thenAnswer(invocation
-> invocation.getArgument(0));
+ when(executorService.execute(any(), any(),
any())).thenReturn(Mono.just(chatResponse));
// Execute the test - focus on successful execution rather than cache
verification
StepVerifier.create(plugin.doExecute(exchange,
mock(ShenyuPluginChain.class), selector, rule))
@@ -289,7 +291,7 @@ public class AiProxyPluginTest {
// Verify that the configuration methods were called correctly
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveAdminFallbackConfig(primaryConfig,
handle);
- verify(executorService).execute(any(), any(), anyString());
+ verify(executorService).execute(any(), any(), any());
}
@Test
@@ -305,12 +307,13 @@ public class AiProxyPluginTest {
when(configService.resolvePrimaryConfig(handle)).thenReturn(primaryConfig);
when(configService.resolveDynamicFallbackConfig(primaryConfig,
REQUEST_BODY)).thenReturn(Optional.empty());
when(configService.resolveAdminFallbackConfig(primaryConfig,
handle)).thenReturn(Optional.empty());
+ when(configService.extractPrompt(anyString())).thenAnswer(invocation
-> invocation.getArgument(0));
// Mock registry to return null factory for invalid provider - this
should cause IllegalArgumentException
when(registry.getFactory(any())).thenReturn(null);
// Mock executorService to return a proper Mono to avoid
NullPointerException
- when(executorService.execute(any(), any(),
anyString())).thenReturn(Mono.error(new IllegalArgumentException("AI model
factory not found")));
+ when(executorService.execute(any(), any(),
any())).thenReturn(Mono.error(new IllegalArgumentException("AI model factory
not found")));
StepVerifier.create(plugin.doExecute(exchange,
mock(ShenyuPluginChain.class), selector, rule))
.expectError(IllegalArgumentException.class)
@@ -325,7 +328,7 @@ public class AiProxyPluginTest {
final RuntimeException exception = new RuntimeException("AI execution
failed");
setupSuccessMocks(handle, primaryConfig, Optional.empty());
- when(executorService.execute(any(), any(),
anyString())).thenReturn(Mono.error(exception));
+ when(executorService.execute(any(), any(),
any())).thenReturn(Mono.error(exception));
StepVerifier.create(plugin.doExecute(exchange,
mock(ShenyuPluginChain.class), selector, rule))
.expectErrorMatches(exception::equals)
@@ -334,6 +337,6 @@ public class AiProxyPluginTest {
verify(configService).resolvePrimaryConfig(handle);
verify(configService).resolveDynamicFallbackConfig(primaryConfig,
REQUEST_BODY);
verify(configService).resolveAdminFallbackConfig(primaryConfig,
handle);
- verify(executorService).execute(any(), any(), anyString());
+ verify(executorService).execute(any(), any(), any());
}
}
\ No newline at end of file
diff --git
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigServiceTest.java
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigServiceTest.java
index d4234bf119..79b1a47447 100644
---
a/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigServiceTest.java
+++
b/shenyu-plugin/shenyu-plugin-ai/shenyu-plugin-ai-proxy-enhanced/src/test/java/org/apache/shenyu/plugin/ai/proxy/enhanced/service/AiProxyConfigServiceTest.java
@@ -143,4 +143,79 @@ public class AiProxyConfigServiceTest {
assertFalse(result.isPresent());
}
+
+ @Test
+ void testExtractPromptNullInput() {
+ String result = configService.extractPrompt(null);
+ assertEquals(null, result);
+ }
+
+ @Test
+ void testExtractPromptEmptyInput() {
+ String result = configService.extractPrompt("");
+ assertEquals("", result);
+ }
+
+ @Test
+ void testExtractPromptWithPromptField() {
+ String requestBody = "{\"prompt\": \"test prompt content\"}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals("test prompt content", result);
+ }
+
+ @Test
+ void testExtractPromptWithContentField() {
+ String requestBody = "{\"content\": \"test content\"}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals("test content", result);
+ }
+
+ @Test
+ void testExtractPromptPromptTakesPrecedenceOverContent() {
+ String requestBody = "{\"prompt\": \"prompt value\", \"content\":
\"content value\"}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals("prompt value", result);
+ }
+
+ @Test
+ void testExtractPromptWithFallbackConfigAndPrompt() {
+ String requestBody = "{\"fallbackConfig\": {\"model\":
\"test-model\"}, \"prompt\": \"test prompt\"}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals("test prompt", result);
+ }
+
+ @Test
+ void testExtractPromptWithFallbackConfigAndContent() {
+ String requestBody = "{\"fallbackConfig\": {\"model\":
\"test-model\"}, \"content\": \"test content\"}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals("test content", result);
+ }
+
+ @Test
+ void testExtractPromptWithFallbackConfigButNoPromptOrContent() {
+ String requestBody = "{\"fallbackConfig\": {\"model\":
\"test-model\"}, \"messages\": [{\"role\": \"user\"}]}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals(requestBody, result);
+ }
+
+ @Test
+ void testExtractPromptWithoutPromptOrContent() {
+ String requestBody = "{\"messages\": [{\"role\": \"user\",
\"content\": \"hello\"}]}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals(requestBody, result);
+ }
+
+ @Test
+ void testExtractPromptMalformedJson() {
+ String requestBody = "{\"prompt\": \"test\"";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals(requestBody, result);
+ }
+
+ @Test
+ void testExtractPromptEmptyJson() {
+ String requestBody = "{}";
+ String result = configService.extractPrompt(requestBody);
+ assertEquals(requestBody, result);
+ }
}