This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 4012835456310317273ce4c4c61fdc49af213ca9 Author: WenjinXie <[email protected]> AuthorDate: Fri Dec 26 11:12:45 2025 +0800 [plan] Support error handling strategy for chat action. fix [plan] Support error handling strategy for chat action in java. modify default max retries not manually clear sensory memory. fix fix --- .../org/apache/flink/agents/api/agents/Agent.java | 18 +++ ...nfigOptions.java => AgentExecutionOptions.java} | 13 +- .../flink/agents/api/agents/OutputSchema.java | 134 ++++++++++++++++ .../apache/flink/agents/api/agents/ReActAgent.java | 169 ++------------------- .../flink/agents/api/event/ChatRequestEvent.java | 17 ++- .../flink/agents/api/agents/ReActAgentTest.java | 5 +- .../agents/integration/test/ReActAgentTest.java | 10 +- plan/pom.xml | 16 ++ .../flink/agents/plan/actions/ChatModelAction.java | 148 ++++++++++++++---- python/flink_agents/api/agents/agent.py | 2 + python/flink_agents/api/agents/react_agent.py | 121 ++------------- python/flink_agents/api/agents/types.py | 67 ++++++++ python/flink_agents/api/core_options.py | 24 +++ python/flink_agents/api/events/chat_event.py | 4 + .../e2e_tests_integration/react_agent_test.py | 10 +- .../flink_agents/plan/actions/chat_model_action.py | 94 +++++++++--- .../flink_agents/plan/tests/resources/action.json | 2 +- 17 files changed, 519 insertions(+), 335 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java index ae046b0..b05dc09 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/Agent.java @@ -109,4 +109,22 @@ public class Agent { } return this; } + + public enum ErrorHandlingStrategy { + FAIL("fail"), + RETRY("retry"), + IGNORE("ignore"); + + private final String value; + + ErrorHandlingStrategy(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } + + public static String STRUCTURED_OUTPUT = "structured_output"; } diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java similarity index 72% rename from api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java rename to api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java index d3edee7..64880a5 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgentConfigOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java @@ -20,12 +20,13 @@ package org.apache.flink.agents.api.agents; import org.apache.flink.agents.api.configuration.ConfigOption; -/** Config Options for {@link ReActAgent}. */ -public class ReActAgentConfigOptions { - /** The option specifies the error handling strategy for react agent. */ - public static final ConfigOption<ReActAgent.ErrorHandlingStrategy> ERROR_HANDLING_STRATEGY = +public class AgentExecutionOptions { + public static final ConfigOption<Agent.ErrorHandlingStrategy> ERROR_HANDLING_STRATEGY = new ConfigOption<>( "error-handling-strategy", - ReActAgent.ErrorHandlingStrategy.class, - ReActAgent.ErrorHandlingStrategy.FAIL); + Agent.ErrorHandlingStrategy.class, + Agent.ErrorHandlingStrategy.FAIL); + + public static final ConfigOption<Integer> MAX_RETRIES = + new ConfigOption<>("max-retries", Integer.class, 3); } diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java b/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java new file mode 100644 index 0000000..54fbcc3 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/agents/OutputSchema.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.agents; + +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Helper class for {@link RowTypeInfo} serialization. + * + * <p>Currently, only support row contains basic type. + */ +@VisibleForTesting +@JsonSerialize(using = OutputSchema.OutputSchemaJsonSerializer.class) +@JsonDeserialize(using = OutputSchema.OutputSchemaJsonDeserializer.class) +public class OutputSchema { + private final RowTypeInfo schema; + + public OutputSchema(RowTypeInfo schema) { + this.schema = schema; + for (TypeInformation<?> info : schema.getFieldTypes()) { + if (!info.isBasicType()) { + throw new IllegalArgumentException( + "Currently, output schema only support row contains basic type."); + } + } + } + + public RowTypeInfo getSchema() { + return schema; + } + + public static class OutputSchemaJsonSerializer extends StdSerializer<OutputSchema> { + + protected OutputSchemaJsonSerializer() { + super(OutputSchema.class); + } + + @Override + public void serialize( + OutputSchema schema, + JsonGenerator jsonGenerator, + SerializerProvider serializerProvider) + throws IOException { + RowTypeInfo typeInfo = schema.getSchema(); + jsonGenerator.writeStartObject(); + + jsonGenerator.writeFieldName("fieldNames"); + jsonGenerator.writeStartArray(); + for (String name : typeInfo.getFieldNames()) { + jsonGenerator.writeString(name); + } + jsonGenerator.writeEndArray(); + + // TODO: support type information which is not basic. + jsonGenerator.writeFieldName("types"); + jsonGenerator.writeStartArray(); + for (TypeInformation<?> info : typeInfo.getFieldTypes()) { + jsonGenerator.writeObject(info.getTypeClass()); + } + jsonGenerator.writeEndArray(); + + jsonGenerator.writeEndObject(); + } + } + + public static class OutputSchemaJsonDeserializer extends StdDeserializer<OutputSchema> { + private static final ObjectMapper mapper = new ObjectMapper(); + + protected OutputSchemaJsonDeserializer() { + super(OutputSchema.class); + } + + @Override + public OutputSchema deserialize( + JsonParser jsonParser, DeserializationContext deserializationContext) + throws IOException, JacksonException { + JsonNode node = jsonParser.getCodec().readTree(jsonParser); + List<String> fieldNames = new ArrayList<>(); + node.get("fieldNames").forEach(fieldNameNode -> fieldNames.add(fieldNameNode.asText())); + List<TypeInformation<?>> types = new ArrayList<>(); + node.get("types") + .forEach( + typeNode -> { + try { + types.add( + BasicTypeInfo.getInfoFor( + mapper.treeToValue(typeNode, Class.class))); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + }); + + return new OutputSchema( + new RowTypeInfo( + types.toArray(new TypeInformation[0]), + fieldNames.toArray(new String[0]))); + } + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index c073f54..278e356 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -18,19 +18,9 @@ package org.apache.flink.agents.api.agents; -import com.fasterxml.jackson.core.JacksonException; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.databind.deser.std.StdDeserializer; -import com.fasterxml.jackson.databind.ser.std.StdSerializer; import org.apache.commons.lang3.ClassUtils; import org.apache.flink.agents.api.InputEvent; import org.apache.flink.agents.api.OutputEvent; @@ -43,9 +33,6 @@ import org.apache.flink.agents.api.event.ChatResponseEvent; import org.apache.flink.agents.api.prompt.Prompt; import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; -import org.apache.flink.annotation.VisibleForTesting; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.types.Row; import org.slf4j.Logger; @@ -53,7 +40,6 @@ import org.slf4j.LoggerFactory; import javax.annotation.Nullable; -import java.io.IOException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; @@ -106,16 +92,14 @@ public class ReActAgent extends Agent { try { Method method = - this.getClass() - .getMethod("stopAction", ChatResponseEvent.class, RunnerContext.class); - this.addAction(new Class[] {ChatResponseEvent.class}, method, actionConfig); + this.getClass().getMethod("startAction", InputEvent.class, RunnerContext.class); + this.addAction(new Class[] {InputEvent.class}, method, actionConfig); } catch (NoSuchMethodException e) { throw new IllegalStateException( "Can't find the method stopAction, this must be a bug."); } } - @Action(listenEvents = {InputEvent.class}) public static void startAction(InputEvent event, RunnerContext ctx) { Object input = event.getInput(); @@ -177,151 +161,22 @@ public class ReActAgent extends Agent { inputMessages.addAll(0, instruct); } - ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, inputMessages)); - } - - public static void stopAction(ChatResponseEvent event, RunnerContext ctx) - throws JsonProcessingException { - Object output = String.valueOf(event.getResponse().getContent()); - Object outputSchema = ctx.getActionConfigValue("output_schema"); - if (outputSchema != null) { - ErrorHandlingStrategy strategy = - ctx.getConfig().get(ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY); - try { - if (outputSchema instanceof Class) { - output = mapper.readValue(String.valueOf(output), (Class<?>) outputSchema); - } else if (outputSchema instanceof OutputSchema) { - RowTypeInfo info = ((OutputSchema) outputSchema).getSchema(); - Map<String, Object> fields = - mapper.readValue(String.valueOf(output), Map.class); - output = Row.withNames(); - for (String name : info.getFieldNames()) { - ((Row) output).setField(name, fields.get(name)); - } - } - } catch (Exception e) { - if (strategy == ErrorHandlingStrategy.FAIL) { - throw e; - } else if (strategy == ErrorHandlingStrategy.IGNORE) { - LOG.warn( - "The response of llm {} doesn't match schema constraint, ignoring.", - output); - return; - } - } - } - - ctx.sendEvent(new OutputEvent(output)); - } - - /** - * Helper class for {@link RowTypeInfo} serialization. - * - * <p>Currently, only support row contains basic type. - */ - @VisibleForTesting - @JsonSerialize(using = OutputSchemaJsonSerializer.class) - @JsonDeserialize(using = OutputSchemaJsonDeserializer.class) - public static class OutputSchema { - private final RowTypeInfo schema; - - public OutputSchema(RowTypeInfo schema) { - this.schema = schema; - for (TypeInformation<?> info : schema.getFieldTypes()) { - if (!info.isBasicType()) { - throw new IllegalArgumentException( - "Currently, output schema only support row contains basic type."); - } - } - } - - public RowTypeInfo getSchema() { - return schema; - } - } - - public static class OutputSchemaJsonSerializer extends StdSerializer<OutputSchema> { - - protected OutputSchemaJsonSerializer() { - super(OutputSchema.class); - } - - @Override - public void serialize( - OutputSchema schema, - JsonGenerator jsonGenerator, - SerializerProvider serializerProvider) - throws IOException { - RowTypeInfo typeInfo = schema.getSchema(); - jsonGenerator.writeStartObject(); - - jsonGenerator.writeFieldName("fieldNames"); - jsonGenerator.writeStartArray(); - for (String name : typeInfo.getFieldNames()) { - jsonGenerator.writeString(name); - } - jsonGenerator.writeEndArray(); - - // TODO: support type information which is not basic. - jsonGenerator.writeFieldName("types"); - jsonGenerator.writeStartArray(); - for (TypeInformation<?> info : typeInfo.getFieldTypes()) { - jsonGenerator.writeObject(info.getTypeClass()); - } - jsonGenerator.writeEndArray(); - - jsonGenerator.writeEndObject(); - } - } - - public static class OutputSchemaJsonDeserializer extends StdDeserializer<OutputSchema> { - private static final ObjectMapper mapper = new ObjectMapper(); - - protected OutputSchemaJsonDeserializer() { - super(OutputSchema.class); - } - - @Override - public OutputSchema deserialize( - JsonParser jsonParser, DeserializationContext deserializationContext) - throws IOException, JacksonException { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - List<String> fieldNames = new ArrayList<>(); - node.get("fieldNames").forEach(fieldNameNode -> fieldNames.add(fieldNameNode.asText())); - List<TypeInformation<?>> types = new ArrayList<>(); - node.get("types") - .forEach( - typeNode -> { - try { - types.add( - BasicTypeInfo.getInfoFor( - mapper.treeToValue(typeNode, Class.class))); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - }); - - return new OutputSchema( - new RowTypeInfo( - types.toArray(new TypeInformation[0]), - fieldNames.toArray(new String[0]))); - } + ctx.sendEvent(new ChatRequestEvent(DEFAULT_CHAT_MODEL, inputMessages, outputSchema)); } - public enum ErrorHandlingStrategy { - FAIL("fail"), - IGNORE("ignore"); - - private final String value; + @Action(listenEvents = {ChatResponseEvent.class}) + public static void stopAction(ChatResponseEvent event, RunnerContext ctx) { + ChatMessage response = event.getResponse(); - ErrorHandlingStrategy(String value) { - this.value = value; + Object output; + if (response.getExtraArgs().containsKey(STRUCTURED_OUTPUT)) { + output = response.getExtraArgs().get(STRUCTURED_OUTPUT); + } else { + output = String.valueOf(response.getContent()); } - public String getValue() { - return value; - } + ctx.sendEvent(new OutputEvent(output)); } } diff --git a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java index e5277d3..cba7b52 100644 --- a/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java +++ b/api/src/main/java/org/apache/flink/agents/api/event/ChatRequestEvent.java @@ -21,15 +21,25 @@ package org.apache.flink.agents.api.event; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.chat.messages.ChatMessage; +import javax.annotation.Nullable; + import java.util.List; +/** Event representing a request for chat. */ public class ChatRequestEvent extends Event { private final String model; private final List<ChatMessage> messages; + private final @Nullable Object outputSchema; - public ChatRequestEvent(String model, List<ChatMessage> messages) { + public ChatRequestEvent( + String model, List<ChatMessage> messages, @Nullable Object outputSchema) { this.model = model; this.messages = messages; + this.outputSchema = outputSchema; + } + + public ChatRequestEvent(String model, List<ChatMessage> messages) { + this(model, messages, null); } public String getModel() { @@ -39,4 +49,9 @@ public class ChatRequestEvent extends Event { public List<ChatMessage> getMessages() { return messages; } + + @Nullable + public Object getOutputSchema() { + return outputSchema; + } } diff --git a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java index f5851e1..db237e7 100644 --- a/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/agents/ReActAgentTest.java @@ -36,10 +36,9 @@ public class ReActAgentTest { BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO }, new String[] {"a", "b"}); - ReActAgent.OutputSchema schema = new ReActAgent.OutputSchema(typeInfo); + OutputSchema schema = new OutputSchema(typeInfo); String json = mapper.writeValueAsString(schema); - ReActAgent.OutputSchema deserialized = - mapper.readValue(json, ReActAgent.OutputSchema.class); + OutputSchema deserialized = mapper.readValue(json, OutputSchema.class); Assertions.assertEquals(typeInfo, deserialized.getSchema()); } } diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java index 648b1b1..b4adee8 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java @@ -21,7 +21,6 @@ package org.apache.flink.agents.integration.test; import org.apache.flink.agents.api.AgentsExecutionEnvironment; import org.apache.flink.agents.api.agents.Agent; import org.apache.flink.agents.api.agents.ReActAgent; -import org.apache.flink.agents.api.agents.ReActAgentConfigOptions; import org.apache.flink.agents.api.annotation.ToolParam; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; @@ -50,6 +49,8 @@ import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.List; +import static org.apache.flink.agents.api.agents.AgentExecutionOptions.ERROR_HANDLING_STRATEGY; +import static org.apache.flink.agents.api.agents.AgentExecutionOptions.MAX_RETRIES; import static org.apache.flink.agents.integration.test.OllamaPreparationUtils.pullModel; public class ReActAgentTest { @@ -110,11 +111,8 @@ public class ReActAgentTest { ReActAgentTest.class.getMethod( "multiply", Double.class, Double.class))); - agentsEnv - .getConfig() - .set( - ReActAgentConfigOptions.ERROR_HANDLING_STRATEGY, - ReActAgent.ErrorHandlingStrategy.IGNORE); + agentsEnv.getConfig().set(ERROR_HANDLING_STRATEGY, ReActAgent.ErrorHandlingStrategy.RETRY); + agentsEnv.getConfig().set(MAX_RETRIES, 3); // Declare the ReAct agent. Agent agent = getAgent(); diff --git a/plan/pom.xml b/plan/pom.xml index 2b77646..f6b5df5 100644 --- a/plan/pom.xml +++ b/plan/pom.xml @@ -68,6 +68,22 @@ under the License. <version>2.0.2</version> <scope>test</scope> </dependency> + <!-- LOG --> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + <version>${slf4j.version}</version> + </dependency> + <dependency> + <groupId>org.apache.logging.log4j</groupId> + <artifactId>log4j-core</artifactId> + <version>${log4j2.version}</version> + </dependency> + <dependency> + <groupId>org.apache.logging.log4j</groupId> + <artifactId>log4j-slf4j-impl</artifactId> + <version>${log4j2.version}</version> + </dependency> </dependencies> <profiles> diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 0fb96e3..aa51fa4 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -17,7 +17,12 @@ */ package org.apache.flink.agents.plan.actions; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.agents.AgentExecutionOptions; +import org.apache.flink.agents.api.agents.OutputSchema; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; @@ -30,15 +35,28 @@ import org.apache.flink.agents.api.event.ToolResponseEvent; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.types.Row; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; import java.util.*; +import static org.apache.flink.agents.api.agents.Agent.STRUCTURED_OUTPUT; + /** Built-in action for processing chat request and tool call result. */ public class ChatModelAction { + private static final Logger LOG = LoggerFactory.getLogger(ChatModelAction.class); + private static final String TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"; private static final String TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"; private static final String INITIAL_REQUEST_ID = "initialRequestId"; private static final String MODEL = "model"; + private static final String OUTPUT_SCHEMA = "outputSchema"; + + private static final ObjectMapper mapper = new ObjectMapper(); public static Action getChatModelAction() throws Exception { return new Action( @@ -76,22 +94,13 @@ public class ChatModelAction { return messageContext; } - @SuppressWarnings("unchecked") - private static void clearToolCallContext(MemoryObject sensoryMem, UUID initialRequestId) - throws Exception { - if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) { - Map<UUID, Object> toolCallContext = - (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); - if (toolCallContext.containsKey(initialRequestId)) { - toolCallContext.remove(initialRequestId); - sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); - } - } - } - @SuppressWarnings("unchecked") private static void saveToolRequestEventContext( - MemoryObject sensoryMem, UUID toolRequestEventId, UUID initialRequestId, String model) + MemoryObject sensoryMem, + UUID toolRequestEventId, + UUID initialRequestId, + String model, + Object outputSchema) throws Exception { Map<UUID, Object> toolRequestEventContext; if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) { @@ -100,20 +109,22 @@ public class ChatModelAction { } else { toolRequestEventContext = new HashMap<>(); } - toolRequestEventContext.put( - toolRequestEventId, Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model)); + Map<String, Object> context = new HashMap<>(); + context.put(INITIAL_REQUEST_ID, initialRequestId); + context.put(MODEL, model); + if (outputSchema != null) { + context.put(OUTPUT_SCHEMA, outputSchema); + } + toolRequestEventContext.put(toolRequestEventId, context); sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); } @SuppressWarnings("unchecked") - private static Map<String, Object> removeToolRequestEventContext( + private static Map<String, Object> getToolRequestEventContext( MemoryObject sensoryMem, UUID requestId) throws Exception { Map<UUID, Object> toolRequestEventContext = (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); - Map<String, Object> context = - (Map<String, Object>) toolRequestEventContext.remove(requestId); - sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); - return context; + return (Map<String, Object>) toolRequestEventContext.remove(requestId); } private static void handleToolCalls( @@ -121,6 +132,7 @@ public class ChatModelAction { UUID initialRequestId, String model, List<ChatMessage> messages, + Object outputSchema, RunnerContext ctx) throws Exception { updateToolCallContext( @@ -132,11 +144,38 @@ public class ChatModelAction { ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, response.getToolCalls()); saveToolRequestEventContext( - ctx.getSensoryMemory(), toolRequestEvent.getId(), initialRequestId, model); + ctx.getSensoryMemory(), + toolRequestEvent.getId(), + initialRequestId, + model, + outputSchema); ctx.sendEvent(toolRequestEvent); } + @SuppressWarnings("unchecked") + private static ChatMessage generateStructuredOutput(ChatMessage response, Object outputSchema) + throws JsonProcessingException { + String output = response.getContent(); + Object structuredOutput; + if (outputSchema instanceof Class) { + structuredOutput = mapper.readValue(String.valueOf(output), (Class<?>) outputSchema); + } else if (outputSchema instanceof OutputSchema) { + RowTypeInfo info = ((OutputSchema) outputSchema).getSchema(); + Map<String, Object> fields = mapper.readValue(String.valueOf(output), Map.class); + structuredOutput = Row.withNames(); + for (String name : info.getFieldNames()) { + ((Row) structuredOutput).setField(name, fields.get(name)); + } + } else { + throw new RuntimeException( + String.format("Unsupported output schema %s.", outputSchema)); + } + Map<String, Object> extraArgs = new HashMap<>(); + extraArgs.put(STRUCTURED_OUTPUT, structuredOutput); + return new ChatMessage(response.getRole(), output, extraArgs); + } + /** * Chat with chat model. * @@ -148,26 +187,69 @@ public class ChatModelAction { * @param ctx The runner context this function executed in. */ public static void chat( - UUID initialRequestId, String model, List<ChatMessage> messages, RunnerContext ctx) + UUID initialRequestId, + String model, + List<ChatMessage> messages, + @Nullable Object outputSchema, + RunnerContext ctx) throws Exception { BaseChatModelSetup chatModel = (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); - ChatMessage response = chatModel.chat(messages, Map.of()); + Agent.ErrorHandlingStrategy strategy = + ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY); + int numRetries = 0; + if (strategy == Agent.ErrorHandlingStrategy.RETRY) { + numRetries = + ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES) > 0 + ? ctx.getConfig().get(AgentExecutionOptions.MAX_RETRIES) + : 0; + } - if (!response.getToolCalls().isEmpty()) { - handleToolCalls(response, initialRequestId, model, messages, ctx); - } else { - // clean tool call context - clearToolCallContext(ctx.getSensoryMemory(), initialRequestId); + ChatMessage response = null; + for (int attempt = 0; attempt < numRetries + 1; attempt++) { + try { + response = chatModel.chat(messages, Map.of()); + // only generate structured output for final response. + if (outputSchema != null && response.getToolCalls().isEmpty()) { + response = generateStructuredOutput(response, outputSchema); + } + } catch (Exception e) { + if (strategy == Agent.ErrorHandlingStrategy.IGNORE) { + LOG.warn( + "Chat request {} failed with error: {}, ignored.", initialRequestId, e); + return; + } else if (strategy == Agent.ErrorHandlingStrategy.RETRY) { + if (attempt == numRetries) { + throw e; + } + LOG.warn( + "Chat request {} failed with error: {}, retrying {} / {}.", + initialRequestId, + e, + attempt, + numRetries); + } else { + LOG.debug( + "Chat request {} failed, the input chat messages are {}.", + initialRequestId, + messages); + throw e; + } + } + } + + if (!Objects.requireNonNull(response).getToolCalls().isEmpty()) { + handleToolCalls(response, initialRequestId, model, messages, outputSchema, ctx); + } else { ctx.sendEvent(new ChatResponseEvent(initialRequestId, response)); } } private static void processChatRequest(ChatRequestEvent event, RunnerContext ctx) throws Exception { - chat(event.getId(), event.getModel(), event.getMessages(), ctx); + chat(event.getId(), event.getModel(), event.getMessages(), event.getOutputSchema(), ctx); } private static void processToolResponse(ToolResponseEvent event, RunnerContext ctx) @@ -175,11 +257,11 @@ public class ChatModelAction { MemoryObject sensoryMem = ctx.getSensoryMemory(); // get tool request context from memory - Map<String, Object> context = - removeToolRequestEventContext(sensoryMem, event.getRequestId()); + Map<String, Object> context = getToolRequestEventContext(sensoryMem, event.getRequestId()); UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID); String model = (String) context.get(MODEL); + Object outputSchema = context.get(OUTPUT_SCHEMA); Map<String, ToolResponse> responses = event.getResponses(); Map<String, Boolean> success = event.getSuccess(); @@ -212,7 +294,7 @@ public class ChatModelAction { Collections.emptyList(), toolResponseMessages); - chat(initialRequestId, model, messages, ctx); + chat(initialRequestId, model, messages, outputSchema, ctx); } /** diff --git a/python/flink_agents/api/agents/agent.py b/python/flink_agents/api/agents/agent.py index 0324dac..3619338 100644 --- a/python/flink_agents/api/agents/agent.py +++ b/python/flink_agents/api/agents/agent.py @@ -26,6 +26,8 @@ from flink_agents.api.resource import ( ) from flink_agents.api.tools.mcp import MCPServer +STRUCTURED_OUTPUT = "structured_output" + class Agent(ABC): """Base class for defining agent logic. diff --git a/python/flink_agents/api/agents/react_agent.py b/python/flink_agents/api/agents/react_agent.py index 63d319e..f540953 100644 --- a/python/flink_agents/api/agents/react_agent.py +++ b/python/flink_agents/api/agents/react_agent.py @@ -15,24 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -import importlib -import json -import logging -from enum import Enum -from typing import Any, cast +from typing import cast from pydantic import ( BaseModel, - ConfigDict, - model_serializer, - model_validator, ) from pyflink.common import Row -from pyflink.common.typeinfo import BasicType, BasicTypeInfo, RowTypeInfo +from pyflink.common.typeinfo import RowTypeInfo -from flink_agents.api.agents.agent import Agent +from flink_agents.api.agents.agent import STRUCTURED_OUTPUT, Agent +from flink_agents.api.agents.types import OutputSchema from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.api.configuration import ConfigOption from flink_agents.api.decorators import action from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent from flink_agents.api.events.event import InputEvent, OutputEvent @@ -46,68 +39,6 @@ _DEFAULT_USER_PROMPT = "_default_user_prompt" _OUTPUT_SCHEMA = "_output_schema" -class ErrorHandlingStrategy(Enum): - """Error handling strategy for ReActAgent.""" - - FAIL = "fail" - IGNORE = "ignore" - - -class ReActAgentOptions: - """Config options for ReActAgent.""" - - ERROR_HANDLING_STRATEGY = ConfigOption( - key="error-handling-strategy", - config_type=ErrorHandlingStrategy, - default=ErrorHandlingStrategy.FAIL, - ) - - -class OutputSchema(BaseModel): - """Util class to help serialize and deserialize output schema json.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - output_schema: type[BaseModel] | RowTypeInfo - - @model_serializer - def __custom_serializer(self) -> dict[str, Any]: - if isinstance(self.output_schema, RowTypeInfo): - data = { - "output_schema": { - "names": self.output_schema.get_field_names(), - "types": [ - type._basic_type.value - for type in self.output_schema.get_field_types() - ], - }, - } - else: - data = { - "output_schema": { - "module": self.output_schema.__module__, - "class": self.output_schema.__name__, - } - } - return data - - @model_validator(mode="before") - def __custom_deserialize(self) -> "OutputSchema": - output_schema = self["output_schema"] - if isinstance(output_schema, dict): - if "names" in output_schema: - self["output_schema"] = RowTypeInfo( - field_types=[ - BasicTypeInfo(BasicType(type)) - for type in output_schema["types"] - ], - field_names=output_schema["names"], - ) - else: - module = importlib.import_module(output_schema["module"]) - self["output_schema"] = getattr(module, output_schema["class"]) - return self - - class ReActAgent(Agent): """Built-in implementation of ReAct agent which is based on the function call ability of llm. @@ -204,13 +135,12 @@ class ReActAgent(Agent): self._resources[ResourceType.PROMPT][_DEFAULT_USER_PROMPT] = prompt self.add_action( - name="stop_action", - events=[ChatResponseEvent], - func=self.stop_action, + name="start_action", + events=[InputEvent], + func=self.start_action, output_schema=OutputSchema(output_schema=output_schema), ) - @action(InputEvent) @staticmethod def start_action(event: InputEvent, ctx: RunnerContext) -> None: """Start action to format user input and send chat request event.""" @@ -257,44 +187,25 @@ class ReActAgent(Agent): instruct = schema_prompt.format_messages() usr_msgs = instruct + usr_msgs + output_schema = ctx.get_action_config_value(key="output_schema") + ctx.send_event( ChatRequestEvent( model=_DEFAULT_CHAT_MODEL, messages=usr_msgs, + output_schema=output_schema, ) ) + @action(ChatResponseEvent) @staticmethod def stop_action(event: ChatResponseEvent, ctx: RunnerContext) -> None: """Stop action to output result.""" - output = event.response.content - # parse llm response to target schema. - output_schema = ctx.get_action_config_value(key="output_schema") + response = event.response - error_handling_strategy = ctx.config.get( - ReActAgentOptions.ERROR_HANDLING_STRATEGY - ) - try: - if output_schema: - output_schema = output_schema.output_schema - output = json.loads(output.strip()) - if isinstance(output_schema, type) and issubclass( - output_schema, BaseModel - ): - output = output_schema.model_validate(output) - elif isinstance(output_schema, RowTypeInfo): - field_names = output_schema.get_field_names() - values = {} - for field_name in field_names: - values[field_name] = output[field_name] - output = Row(**values) - except Exception: - if error_handling_strategy == ErrorHandlingStrategy.IGNORE: - logging.warning( - f"The response of llm {output} doesn't match schema constraint, ignoring." - ) - return - elif error_handling_strategy == ErrorHandlingStrategy.FAIL: - raise + if STRUCTURED_OUTPUT in response.extra_args: + output = response.extra_args[STRUCTURED_OUTPUT] + else: + output = response.content ctx.send_event(OutputEvent(output=output)) diff --git a/python/flink_agents/api/agents/types.py b/python/flink_agents/api/agents/types.py new file mode 100644 index 0000000..2fd10b8 --- /dev/null +++ b/python/flink_agents/api/agents/types.py @@ -0,0 +1,67 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import importlib +from typing import Any + +from pydantic import BaseModel, ConfigDict, model_serializer, model_validator +from pyflink.common.typeinfo import BasicType, BasicTypeInfo, RowTypeInfo + + +class OutputSchema(BaseModel): + """Util class to help serialize and deserialize output schema json.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + output_schema: type[BaseModel] | RowTypeInfo + + @model_serializer + def __custom_serializer(self) -> dict[str, Any]: + if isinstance(self.output_schema, RowTypeInfo): + data = { + "output_schema": { + "names": self.output_schema.get_field_names(), + "types": [ + type._basic_type.value + for type in self.output_schema.get_field_types() + ], + }, + } + else: + data = { + "output_schema": { + "module": self.output_schema.__module__, + "class": self.output_schema.__name__, + } + } + return data + + @model_validator(mode="before") + def __custom_deserialize(self) -> "OutputSchema": + output_schema = self["output_schema"] + if isinstance(output_schema, dict): + if "names" in output_schema: + self["output_schema"] = RowTypeInfo( + field_types=[ + BasicTypeInfo(BasicType(type)) + for type in output_schema["types"] + ], + field_names=output_schema["names"], + ) + else: + module = importlib.import_module(output_schema["module"]) + self["output_schema"] = getattr(module, output_schema["class"]) + return self diff --git a/python/flink_agents/api/core_options.py b/python/flink_agents/api/core_options.py index d9ee456..9d8f7ea 100644 --- a/python/flink_agents/api/core_options.py +++ b/python/flink_agents/api/core_options.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +from enum import Enum from typing import Any from pyflink.java_gateway import get_gateway @@ -69,6 +70,17 @@ class AgentConfigOptionsMeta(type): return python_option +class ErrorHandlingStrategy(Enum): + """Error handling strategy for Agent. + + Currently, only works for chat action. + """ + + RETRY = "retry" + FAIL = "fail" + IGNORE = "ignore" + + class AgentConfigOptions(metaclass=AgentConfigOptionsMeta): """CoreOptions to manage core configuration parameters for Flink Agents.""" @@ -77,3 +89,15 @@ class AgentConfigOptions(metaclass=AgentConfigOptionsMeta): config_type=str, default=None, ) + + ERROR_HANDLING_STRATEGY = ConfigOption( + key="error-handling-strategy", + config_type=ErrorHandlingStrategy, + default=ErrorHandlingStrategy.FAIL, + ) + + MAX_RETRIES = ConfigOption( + key="max-retries", + config_type=int, + default=3, + ) diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py index 2f4dfb1..2fb4266 100644 --- a/python/flink_agents/api/events/chat_event.py +++ b/python/flink_agents/api/events/chat_event.py @@ -18,6 +18,7 @@ from typing import List from uuid import UUID +from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.chat_message import ChatMessage from flink_agents.api.events.event import Event @@ -31,10 +32,13 @@ class ChatRequestEvent(Event): The name of the chat model to be chatted with. messages : List[ChatMessage] The input to the chat model. + output_schema: OutputSchema | None + The expected output schema of the chat model final response. Optional. """ model: str messages: List[ChatMessage] + output_schema: OutputSchema | None = None class ChatResponseEvent(Event): diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py index 630a7f3..7e688b5 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py @@ -28,11 +28,10 @@ from pyflink.datastream import KeySelector, StreamExecutionEnvironment from pyflink.table import DataTypes, Schema, StreamTableEnvironment, TableDescriptor from flink_agents.api.agents.react_agent import ( - ErrorHandlingStrategy, ReActAgent, - ReActAgentOptions, ) from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.core_options import AgentConfigOptions, ErrorHandlingStrategy from flink_agents.api.execution_environment import AgentsExecutionEnvironment from flink_agents.api.prompts.prompt import Prompt from flink_agents.api.resource import ResourceDescriptor @@ -79,8 +78,9 @@ client = pull_model(OLLAMA_MODEL) def test_react_agent_on_local_runner() -> None: # noqa: D103 env = AgentsExecutionEnvironment.get_execution_environment() env.get_config().set( - ReActAgentOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.IGNORE + AgentConfigOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY ) + env.get_config().set(AgentConfigOptions.MAX_RETRIES, 3) # register resource to execution environment ( @@ -155,9 +155,11 @@ def test_react_agent_on_remote_runner(tmp_path: Path) -> None: # noqa: D103 ) env.get_config().set( - ReActAgentOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.IGNORE + AgentConfigOptions.ERROR_HANDLING_STRATEGY, ErrorHandlingStrategy.RETRY ) + env.get_config().set(AgentConfigOptions.MAX_RETRIES, 3) + # register resource to execution environment ( env.add_resource( diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 7299f13..0db7e5a 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -16,10 +16,19 @@ # limitations under the License. ################################################################################# import copy +import json +import logging from typing import TYPE_CHECKING, Dict, List, cast from uuid import UUID +from pydantic import BaseModel +from pyflink.common import Row +from pyflink.common.typeinfo import RowTypeInfo + +from flink_agents.api.agents.agent import STRUCTURED_OUTPUT +from flink_agents.api.agents.react_agent import OutputSchema from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.core_options import AgentConfigOptions, ErrorHandlingStrategy from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent from flink_agents.api.events.event import Event from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent @@ -35,6 +44,8 @@ if TYPE_CHECKING: _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT" _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT" +_logger = logging.getLogger(__name__) + # ============================================================================ # Helper Functions for Tool Call Context Management @@ -69,39 +80,29 @@ def _update_tool_call_context( sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) return tool_call_context[initial_request_id] - -def _clear_tool_call_context( - sensory_memory: MemoryObject, initial_request_id: UUID -) -> None: - """Clear tool call context for a specific request ID.""" - context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {} - if initial_request_id in context: - context.pop(initial_request_id) - sensory_memory.set(_TOOL_CALL_CONTEXT, context) - - def _save_tool_request_event_context( sensory_memory: MemoryObject, tool_request_event_id: UUID, initial_request_id: UUID, model: str, + output_schema: OutputSchema | None, ) -> None: """Save the context for a specific tool request event.""" context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} context[tool_request_event_id] = { "initial_request_id": initial_request_id, "model": model, + "output_schema": output_schema, } sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context) -def _remove_tool_request_event_context( +def _get_tool_request_event_context( sensory_memory: MemoryObject, request_id: UUID ) -> Dict: """Get and remove the context for a specific tool request event.""" context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} removed_context = context.pop(request_id, {}) - sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, removed_context) return removed_context @@ -110,6 +111,7 @@ def _handle_tool_calls( initial_request_id: UUID, model: str, messages: List[ChatMessage], + output_schema: OutputSchema | None, ctx: RunnerContext, ) -> None: """Handle tool calls in chat response.""" @@ -124,16 +126,41 @@ def _handle_tool_calls( # save tool request event context _save_tool_request_event_context( - ctx.sensory_memory, tool_request_event.id, initial_request_id, model + ctx.sensory_memory, + tool_request_event.id, + initial_request_id, + model, + output_schema, ) ctx.send_event(tool_request_event) +def _generate_structured_output( + response: ChatMessage, output_schema: OutputSchema +) -> ChatMessage: + """Deserialize output to expected output schema.""" + output_schema = output_schema.output_schema + output = json.loads(response.content.strip()) + + if isinstance(output_schema, type) and issubclass(output_schema, BaseModel): + output = output_schema.model_validate(output) + elif isinstance(output_schema, RowTypeInfo): + field_names = output_schema.get_field_names() + values = {} + for field_name in field_names: + values[field_name] = output[field_name] + output = Row(**values) + response.extra_args[STRUCTURED_OUTPUT] = output + + return response + + def chat( initial_request_id: UUID, model: str, messages: List[ChatMessage], + output_schema: OutputSchema | None, ctx: RunnerContext, ) -> None: """Chat with llm. @@ -146,16 +173,43 @@ def chat( "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL) ) + error_handling_strategy = ctx.config.get(AgentConfigOptions.ERROR_HANDLING_STRATEGY) + num_retries = 0 + if error_handling_strategy == ErrorHandlingStrategy.RETRY: + num_retries = max(0, ctx.config.get(AgentConfigOptions.MAX_RETRIES)) + # TODO: support async execution of chat. - response = chat_model.chat(messages) + response = None + for attempt in range(num_retries + 1): + try: + response = chat_model.chat(messages) + if output_schema is not None and len(response.tool_calls) == 0: + response = _generate_structured_output(response, output_schema) + except Exception as e: # noqa: PERF203 + if error_handling_strategy == ErrorHandlingStrategy.IGNORE: + _logger.warning( + f"Chat request {initial_request_id} failed with error: {e}, ignored." + ) + return + elif error_handling_strategy == ErrorHandlingStrategy.RETRY: + if attempt == num_retries: + raise + _logger.warning( + f"Chat request {initial_request_id} failed with error: {e}, retrying {attempt} / {num_retries}." + ) + else: + _logger.debug( + f"Chat request {initial_request_id} failed, the input chat messages are {messages}." + ) + raise if ( len(response.tool_calls) > 0 ): # generate tool request event according tool calls in response - _handle_tool_calls(response, initial_request_id, model, messages, ctx) + _handle_tool_calls( + response, initial_request_id, model, messages, output_schema, ctx + ) else: # if there is no tool call generated, return chat response directly - _clear_tool_call_context(ctx.sensory_memory, initial_request_id) - ctx.send_event( ChatResponseEvent( request_id=initial_request_id, @@ -170,6 +224,7 @@ def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None: initial_request_id=event.id, model=event.model, messages=event.messages, + output_schema=event.output_schema, ctx=ctx, ) @@ -180,7 +235,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None request_id = event.request_id # get correspond tool request event context - tool_request_event_context = _remove_tool_request_event_context( + tool_request_event_context = _get_tool_request_event_context( sensory_memory, request_id ) initial_request_id = tool_request_event_context["initial_request_id"] @@ -206,6 +261,7 @@ def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None initial_request_id=initial_request_id, model=tool_request_event_context["model"], messages=messages, + output_schema=tool_request_event_context["output_schema"], ctx=ctx, ) diff --git a/python/flink_agents/plan/tests/resources/action.json b/python/flink_agents/plan/tests/resources/action.json index 2c52d83..e6de190 100644 --- a/python/flink_agents/plan/tests/resources/action.json +++ b/python/flink_agents/plan/tests/resources/action.json @@ -11,7 +11,7 @@ "config": { "__config_type__": "python", "output_schema": [ - "flink_agents.api.agents.react_agent", + "flink_agents.api.agents.types", "OutputSchema", { "output_schema": {
