This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit ca315e758b5d3e16dbf03b1b2bf8237c5e22c605 Author: WenjinXie <[email protected]> AuthorDate: Fri Jan 9 17:58:43 2026 +0800 [runtime][java] Support use Long-Term Memory in action. --- .../flink/agents/api/context/RunnerContext.java | 11 + .../pom.xml | 15 + .../test/VectorStoreLongTermMemoryAgent.java | 182 +++++++++++ .../test/VectorStoreLongTermMemoryTest.java | 333 +++++++++++++++++++++ .../src/test/resources/input_data.txt | 10 + .../elasticsearch/ElasticsearchVectorStore.java | 4 +- .../agents/runtime/context/RunnerContextImpl.java | 35 ++- .../runtime/operator/ActionExecutionOperator.java | 25 +- .../python/context/PythonRunnerContextImpl.java | 5 +- .../flink/agents/runtime/memory/MemoryRefTest.java | 9 + 10 files changed, 618 insertions(+), 11 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java index 6c1bd02..c748169 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.api.context; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.configuration.ReadableConfiguration; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; @@ -56,6 +57,13 @@ public interface RunnerContext { */ MemoryObject getShortTermMemory() throws Exception; + /** + * Gets the long-term memory. + * + * @return The long-term memory instance + */ + BaseLongTermMemory getLongTermMemory() throws Exception; + /** * Gets the metric group for Flink Agents. * @@ -100,4 +108,7 @@ public interface RunnerContext { * @return the option value of the action config. 
*/ Object getActionConfigValue(String key); + + /** Clean up the resource. */ + void close() throws Exception; } diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml index 2adf275..27e7d7a 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-integration/pom.xml @@ -61,6 +61,16 @@ under the License. <artifactId>flink-clients</artifactId> <version>${flink.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-connector-files</artifactId> + <version>${flink.version}</version> + </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-json</artifactId> + <version>${flink.version}</version> + </dependency> <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-agents-integrations-chat-models-anthropic</artifactId> @@ -91,6 +101,11 @@ under the License. <artifactId>flink-agents-integrations-vector-stores-elasticsearch</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.fasterxml.jackson.datatype</groupId> + <artifactId>jackson-datatype-jsr310</artifactId> + <version>${jackson.version}</version> + </dependency> </dependencies> <profiles> diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryAgent.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryAgent.java new file mode 100644 index 0000000..b1dbf53 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryAgent.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integration.test; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.ChatModelConnection; +import org.apache.flink.agents.api.annotation.ChatModelSetup; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; +import org.apache.flink.agents.api.annotation.VectorStore; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; 
+import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup; +import org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore; +import org.junit.jupiter.api.Assertions; + +import java.util.Collections; +import java.util.List; + +public class VectorStoreLongTermMemoryAgent extends Agent { + public static final ObjectMapper mapper = new ObjectMapper(); + + static { + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); + mapper.registerModule(new JavaTimeModule()); + } + + /** Data model representing a product review. */ + @JsonSerialize + @JsonDeserialize + public static class ProductReview { + private final int id; + private final String review; + + @JsonCreator + public ProductReview(@JsonProperty("id") int id, @JsonProperty("review") String review) { + this.id = id; + this.review = review; + } + + public int getId() { + return id; + } + + public String getReview() { + return review; + } + + @Override + public String toString() { + return String.format("ProductReview{id='%s', review='%s'}", id, review); + } + } + + /** Custom event type for internal agent communication. 
*/ + public static class MyEvent extends Event { + private final int recordId; + + public MyEvent(int recordId) { + this.recordId = recordId; + } + + public int getRecordId() { + return recordId; + } + } + + private static final String CHAT_MODEL = "qwen3:8b"; + public static final String EMBED_MODEL = "nomic-embed-text"; + + @ChatModelConnection + public static ResourceDescriptor chatModelConnection() { + return ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName()) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build(); + } + + @ChatModelSetup + public static ResourceDescriptor ollamaQwen3() { + return ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "chatModelConnection") + .addInitialArgument("model", CHAT_MODEL) + .build(); + } + + @EmbeddingModelConnection + public static ResourceDescriptor embeddingConnection() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 120) + .build(); + } + + @EmbeddingModelSetup + public static ResourceDescriptor embeddingModel() { + return ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "embeddingConnection") + .addInitialArgument("model", EMBED_MODEL) + .build(); + } + + @VectorStore + public static ResourceDescriptor vectorStore() { + return ResourceDescriptor.Builder.newBuilder(ElasticsearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embeddingModel") + .addInitialArgument("host", "localhost:9200") + .addInitialArgument("dims", 768) + .build(); + } + + @Action(listenEvents = {InputEvent.class}) + public static void addItems(InputEvent event, RunnerContext ctx) throws Exception { + BaseLongTermMemory ltm = ctx.getLongTermMemory(); + MemorySet memorySet = 
+ ltm.getOrCreateMemorySet( + "test-ltm", String.class, 5, new SummarizationStrategy("ollamaQwen3", 1)); + ProductReview review = (ProductReview) event.getInput(); + memorySet.add(Collections.singletonList(review.getReview()), null, null); + + MemoryObject stm = ctx.getShortTermMemory(); + + if (stm.isExist("count")) { + int count = (int) stm.get("count").getValue(); + stm.set("count", count + 1); + } else { + stm.set("count", 1); + } + + ctx.sendEvent(new MyEvent(review.id)); + } + + @Action(listenEvents = {MyEvent.class}) + public static void retrieveItems(MyEvent event, RunnerContext ctx) throws Exception { + int count = (int) ctx.getShortTermMemory().get("count").getValue(); + + BaseLongTermMemory ltm = ctx.getLongTermMemory(); + MemorySet memorySet = ltm.getMemorySet("test-ltm"); + + Assertions.assertEquals(String.class, memorySet.getItemType()); + Assertions.assertEquals(count, memorySet.size()); + + List<MemorySetItem> items = memorySet.get(null); + ctx.sendEvent(new OutputEvent(mapper.writeValueAsString(items))); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryTest.java new file mode 100644 index 0000000..fe86164 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/VectorStoreLongTermMemoryTest.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.integration.test; + +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.configuration.AgentConfigOptions; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelConnection; +import org.apache.flink.agents.integrations.chatmodels.ollama.OllamaChatModelSetup; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelConnection; +import org.apache.flink.agents.integrations.embeddingmodels.ollama.OllamaEmbeddingModelSetup; +import org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore; +import org.apache.flink.agents.plan.AgentConfiguration; +import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory; +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.connector.file.src.FileSource; +import 
org.apache.flink.connector.file.src.reader.TextLineInputFormat; +import org.apache.flink.core.fs.Path; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.mockito.Mockito; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Tests for {@link VectorStoreLongTermMemory}. + * + * <p>These tests use {@link ElasticsearchVectorStore} as the Long-Term Memory backend, so an + * Elasticsearch server must be running. See <a + * href="https://www.elastic.co/docs/deploy-manage/deploy/self-managed/install-elasticsearch-docker-basic">Start + * a single-node cluster in Docker</a> for details. + * + * <p>Because {@link ElasticsearchVectorStore} does not support security checks yet, start the + * container with the "-e xpack.security.enabled=false" option.
+ */ +@Disabled("Should setup Elasticsearch server.") +public class VectorStoreLongTermMemoryTest { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemoryTest.class); + + private static final String NAME = "chat-history"; + private final VectorStoreLongTermMemory ltm; + private MemorySet memorySet; + private List<ChatMessage> messages; + + public static Resource getResource(String name, ResourceType type) { + if (type == ResourceType.CHAT_MODEL_CONNECTION) { + return new OllamaChatModelConnection( + ResourceDescriptor.Builder.newBuilder(OllamaChatModelConnection.class.getName()) + .addInitialArgument("endpoint", "http://localhost:11434") + .addInitialArgument("requestTimeout", 240) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.CHAT_MODEL) { + return new OllamaChatModelSetup( + ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName()) + .addInitialArgument("connection", "ollama-connection") + .addInitialArgument("model", "qwen3:8b") + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.EMBEDDING_MODEL_CONNECTION) { + return new OllamaEmbeddingModelConnection( + ResourceDescriptor.Builder.newBuilder( + OllamaEmbeddingModelConnection.class.getName()) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 120) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else if (type == ResourceType.EMBEDDING_MODEL) { + return new OllamaEmbeddingModelSetup( + ResourceDescriptor.Builder.newBuilder(OllamaEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "embed-connection") + .addInitialArgument("model", "nomic-embed-text") + .build(), + VectorStoreLongTermMemoryTest::getResource); + } else { + return new ElasticsearchVectorStore( + ResourceDescriptor.Builder.newBuilder(ElasticsearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embed-setup") + 
.addInitialArgument("host", "localhost:9200") + .addInitialArgument("dims", 768) + .build(), + VectorStoreLongTermMemoryTest::getResource); + } + } + + public VectorStoreLongTermMemoryTest() throws Exception { + RunnerContext ctx = Mockito.mock(RunnerContext.class); + + AgentConfiguration config = new AgentConfiguration(); + config.set(LongTermMemoryOptions.ASYNC_COMPACTION, false); + Mockito.when(ctx.getConfig()).thenReturn(config); + + Mockito.when(ctx.getResource("ollama-connection", ResourceType.CHAT_MODEL_CONNECTION)) + .thenReturn(getResource("ollama-connection", ResourceType.CHAT_MODEL_CONNECTION)); + + Mockito.when(ctx.getResource("ollama-setup", ResourceType.CHAT_MODEL)) + .thenReturn(getResource("ollama-setup", ResourceType.CHAT_MODEL)); + + Mockito.when(ctx.getResource("embed-connection", ResourceType.EMBEDDING_MODEL_CONNECTION)) + .thenReturn( + getResource("embed-connection", ResourceType.EMBEDDING_MODEL_CONNECTION)); + + Mockito.when(ctx.getResource("embed-setup", ResourceType.EMBEDDING_MODEL)) + .thenReturn(getResource("embed-setup", ResourceType.EMBEDDING_MODEL)); + + Mockito.when(ctx.getResource("vector-store", ResourceType.VECTOR_STORE)) + .thenReturn(getResource("vector-store", ResourceType.VECTOR_STORE)); + + ltm = new VectorStoreLongTermMemory(ctx, "vector-store", "job-0001", "0001"); + } + + @BeforeEach + public void prepare(TestInfo info) throws Exception { + messages = new ArrayList<>(); + if (info.getTags().contains("skipBeforeEach")) { + return; + } + memorySet = + ltm.getOrCreateMemorySet( + NAME, + ChatMessage.class, + 100, + new SummarizationStrategy("ollama-setup", null, 1)); + for (int i = 0; i < 10; i++) { + messages.add( + new ChatMessage( + MessageRole.USER, String.format("This is the no.%s message", i))); + } + memorySet.add(messages, null, null); + } + + @AfterEach + public void cleanUp(TestInfo info) throws Exception { + if (info.getTags().contains("skipAfterEach")) { + return; + } + ltm.deleteMemorySet(NAME); + } + + @Test + 
public void testGetMemorySet() throws Exception { + MemorySet retrieved = ltm.getMemorySet(memorySet.getName()); + + Assertions.assertEquals(memorySet, retrieved); + } + + @Test + public void testAddAndGet() throws Exception { + List<MemorySetItem> items = memorySet.get(null); + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + Assertions.assertEquals(messages, retrieved); + } + + @Test + public void testSearch() throws Exception { + List<MemorySetItem> items = memorySet.search("The no.5 message", 1, Collections.emptyMap()); + + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + + Assertions.assertEquals(1, retrieved.size()); + Assertions.assertEquals(messages.get(5), retrieved.get(0)); + } + + @Test + @Tag("skipBeforeEach") + public void testCompact() throws Exception { + memorySet = + ltm.getOrCreateMemorySet( + NAME, + ChatMessage.class, + 8, + new SummarizationStrategy("ollama-setup", null, 2)); + messages.add(ChatMessage.user("What is flink?")); + messages.add( + ChatMessage.assistant( + "Apache Flink is a framework and distributed processing engine for stateful computations over unbounded and bounded data streams. 
Flink has been designed to run in all common cluster environments, perform computations at in-memory speed and at any scale.")); + messages.add(ChatMessage.user("What is flink agents?")); + messages.add( + ChatMessage.assistant( + "Apache Flink Agents is a brand-new sub-project from the Apache Flink community, providing an open-source framework for building event-driven streaming agents.")); + messages.add(ChatMessage.user("What's the whether tomorrow in london?")); + messages.add( + ChatMessage.assistant( + "", + Collections.singletonList( + Map.of( + "id", + "186780f8-c79d-4159-83e3-f65859835b14", + "type", + "function", + "function", + Map.of( + "name", + "get_weather", + "arguments", + Map.of( + "position", + "london", + "time", + "tomorrow")))))); + messages.add(ChatMessage.tool("snow")); + messages.add(ChatMessage.assistant("Tomorrow weather for london is snow.")); + memorySet.add(messages, null, null); + + List<MemorySetItem> items = memorySet.get(null); + List<ChatMessage> retrieved = + items.stream().map(x -> (ChatMessage) x.getValue()).collect(Collectors.toList()); + + Assertions.assertEquals(2, items.size()); + Assertions.assertTrue(items.get(0).isCompacted()); + Assertions.assertTrue(items.get(1).isCompacted()); + Assertions.assertInstanceOf( + MemorySetItem.DateTimeRange.class, items.get(0).getCreatedTime()); + Assertions.assertInstanceOf( + MemorySetItem.DateTimeRange.class, items.get(1).getCreatedTime()); + Assertions.assertEquals(2, memorySet.size()); + } + + @Test + @Tag("skipBeforeEach") + @Tag("skipAfterEach") + public void testUsingLtmInAction() throws Exception { + ElasticsearchVectorStore es = + new ElasticsearchVectorStore( + ResourceDescriptor.Builder.newBuilder( + ElasticsearchVectorStore.class.getName()) + .addInitialArgument("embedding_model", "embed-setup") + .addInitialArgument("host", "localhost:9200") + .addInitialArgument("dims", 768) + .build(), + VectorStoreLongTermMemoryTest::getResource); + try { + // Set up the Flink streaming 
environment and the Agents execution environment. + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + agentsEnv.getConfig().set(AgentConfigOptions.JOB_IDENTIFIER, "ltm_test_job"); + agentsEnv + .getConfig() + .set( + LongTermMemoryOptions.BACKEND, + LongTermMemoryOptions.LongTermMemoryBackend.EXTERNAL_VECTOR_STORE); + agentsEnv + .getConfig() + .set(LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME, "vectorStore"); + agentsEnv.getConfig().set(LongTermMemoryOptions.ASYNC_COMPACTION, true); + + DataStream<String> inputStream = + env.fromSource( + FileSource.forRecordStreamFormat( + new TextLineInputFormat(), + new Path( + Objects.requireNonNull( + this.getClass() + .getClassLoader() + .getResource( + "input_data.txt")) + .getPath())) + .build(), + WatermarkStrategy.noWatermarks(), + "ltm-test-agent"); + DataStream<VectorStoreLongTermMemoryAgent.ProductReview> reviewDataStream = + inputStream.map( + x -> + VectorStoreLongTermMemoryAgent.mapper.readValue( + x, VectorStoreLongTermMemoryAgent.ProductReview.class)); + + // Apply VectorStoreLongTermMemoryAgent, which adds each review to long-term memory. + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + reviewDataStream, + VectorStoreLongTermMemoryAgent.ProductReview::getId) + .apply(new VectorStoreLongTermMemoryAgent()) + .toDataStream(); + + // Print the retrieved memory-set items to stdout. + outputStream.print(); + + // Execute the Flink pipeline.
+ agentsEnv.execute(); + + // check async compaction + LOG.debug(es.get(null, "ltm_test_job-2-test-ltm", Collections.emptyMap()).toString()); + Assertions.assertEquals(1, es.size("ltm_test_job-2-test-ltm")); + } finally { + es.deleteCollection("ltm_test_job-1-test-ltm"); + es.deleteCollection("ltm_test_job-2-test-ltm"); + es.deleteCollection("ltm_test_job-3-test-ltm"); + } + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/input_data.txt b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/input_data.txt new file mode 100644 index 0000000..13efe8e --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/resources/input_data.txt @@ -0,0 +1,10 @@ +{"id":1,"review":"Great product! Works perfectly and lasts a long time.","review_score":3.0} +{"id":2,"review":"The item arrived damaged, and the packaging was poor.","review_score":5.0} +{"id":3,"review":"Highly satisfied with the performance and value for money.","review_score":7.0} +{"id":3,"review":"Not as good as expected. It stopped working after a week.","review_score":8.0} +{"id":1,"review":"Fast shipping and excellent customer service. Would buy again!","review_score":8.0} +{"id":2,"review":"Too complicated to set up. Instructions were unclear.","review_score":8.0} +{"id":2,"review":"Good quality, but overview_scored for what it does.","review_score":8.0} +{"id":1,"review":"Exactly what I needed. Easy to use and very reliable.","review_score":8.0} +{"id":2,"review":"Worst purchase ever. 
Waste of money and time.","review_score":8.0} +{"id":2,"review":"Looks nice and functions well, but could be more durable.","review_score":8.0} \ No newline at end of file diff --git a/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java b/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java index 022f8df..cf09d55 100644 --- a/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java +++ b/integrations/vector-stores/elasticsearch/src/main/java/org/apache/flink/agents/integrations/vectorstores/elasticsearch/ElasticsearchVectorStore.java @@ -19,6 +19,7 @@ package org.apache.flink.agents.integrations.vectorstores.elasticsearch; import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch._types.Refresh; import co.elastic.clients.elasticsearch._types.mapping.DynamicMapping; import co.elastic.clients.elasticsearch._types.mapping.Property; import co.elastic.clients.elasticsearch.core.BulkRequest; @@ -788,7 +789,8 @@ public class ElasticsearchVectorStore extends BaseVectorStore } // Execute bulk request - BulkRequest bulkRequest = BulkRequest.of(br -> br.operations(bulkOperations)); + BulkRequest bulkRequest = + BulkRequest.of(br -> br.operations(bulkOperations).refresh(Refresh.WaitFor)); BulkResponse bulkResponse = this.client.bulk(bulkRequest); // Check for errors diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 3367398..549359f 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -23,12 +23,16 @@ import org.apache.flink.agents.api.configuration.ReadableConfiguration; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.utils.JsonUtils; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.InteranlBaseLongTermMemory; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; +import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.util.Preconditions; @@ -80,19 +84,32 @@ public class RunnerContextImpl implements RunnerContext { protected MemoryContext memoryContext; protected String actionName; + protected InteranlBaseLongTermMemory ltm; public RunnerContextImpl( FlinkAgentsMetricGroupImpl agentMetricGroup, Runnable mailboxThreadChecker, - AgentPlan agentPlan) { + AgentPlan agentPlan, + String jobIdentifier) { this.agentMetricGroup = agentMetricGroup; this.mailboxThreadChecker = mailboxThreadChecker; this.agentPlan = agentPlan; + + LongTermMemoryOptions.LongTermMemoryBackend backend = + this.getConfig().get(LongTermMemoryOptions.BACKEND); + if (backend == LongTermMemoryOptions.LongTermMemoryBackend.EXTERNAL_VECTOR_STORE) { + String vectorStoreName = + this.getConfig().get(LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME); + ltm = new VectorStoreLongTermMemory(this, vectorStoreName, jobIdentifier); + } } - public void switchActionContext(String 
actionName, MemoryContext memoryContext) { + public void switchActionContext(String actionName, MemoryContext memoryContext, String key) { this.actionName = actionName; this.memoryContext = memoryContext; + if (ltm != null) { + ltm.switchContext(key); + } } public MemoryContext getMemoryContext() { @@ -176,6 +193,12 @@ public class RunnerContextImpl implements RunnerContext { memoryContext.getShortTermMemoryUpdates()); } + @Override + public BaseLongTermMemory getLongTermMemory() throws Exception { + Preconditions.checkNotNull(this.ltm); + return this.ltm; + } + @Override public Resource getResource(String name, ResourceType type) throws Exception { if (agentPlan == null) { @@ -202,6 +225,14 @@ public class RunnerContextImpl implements RunnerContext { return agentPlan.getActionConfigValue(actionName, key); } + @Override + public void close() throws Exception { + if (this.ltm != null) { + this.ltm.close(); + this.ltm = null; + } + } + public String getActionName() { return actionName; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 084ad1f..70a49c0 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -438,7 +438,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT } // 2. Invoke the action task. 
- createAndSetRunnerContext(actionTask); + createAndSetRunnerContext(actionTask, key); long sequenceNumber = sequenceNumberKState.value(); boolean isFinished; @@ -587,7 +587,10 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT pythonInterpreter = env.getInterpreter(); pythonRunnerContext = new PythonRunnerContextImpl( - this.metricGroup, this::checkMailboxThread, this.agentPlan); + this.metricGroup, + this::checkMailboxThread, + this.agentPlan, + this.jobIdentifier); javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); if (containPythonResource) { @@ -655,6 +658,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (actionStateStore != null) { actionStateStore.close(); } + if (runnerContext != null) { + runnerContext.close(); + } super.close(); } @@ -785,7 +791,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT } } - private void createAndSetRunnerContext(ActionTask actionTask) { + private void createAndSetRunnerContext(ActionTask actionTask, Object key) { if (actionTask.getRunnerContext() != null) { return; } @@ -812,7 +818,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT new CachedMemoryStore(shortTermMemState)); } - runnerContext.switchActionContext(actionTask.action.getName(), memoryContext); + runnerContext.switchActionContext( + actionTask.action.getName(), memoryContext, String.valueOf(key.hashCode())); actionTask.setRunnerContext(runnerContext); } @@ -924,14 +931,20 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (runnerContext == null) { runnerContext = new RunnerContextImpl( - this.metricGroup, this::checkMailboxThread, this.agentPlan); + this.metricGroup, + this::checkMailboxThread, + this.agentPlan, + this.jobIdentifier); } return runnerContext; } else { if (pythonRunnerContext == null) { pythonRunnerContext = new PythonRunnerContextImpl( - 
this.metricGroup, this::checkMailboxThread, this.agentPlan); + this.metricGroup, + this::checkMailboxThread, + this.agentPlan, + jobIdentifier); } return pythonRunnerContext; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java index 690d412..7df56e5 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java @@ -33,8 +33,9 @@ public class PythonRunnerContextImpl extends RunnerContextImpl { public PythonRunnerContextImpl( FlinkAgentsMetricGroupImpl agentMetricGroup, Runnable mailboxThreadChecker, - AgentPlan agentPlan) { - super(agentMetricGroup, mailboxThreadChecker, agentPlan); + AgentPlan agentPlan, + String jobIdentifier) { + super(agentMetricGroup, mailboxThreadChecker, agentPlan, jobIdentifier); } @Override diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java index 1f27c82..52a5ff3 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java @@ -21,6 +21,7 @@ import org.apache.flink.agents.api.configuration.ReadableConfiguration; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryRef; import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; @@ -73,6 +74,11 @@ public class MemoryRefTest { return memoryObject; } + @Override 
+ public BaseLongTermMemory getLongTermMemory() throws Exception { + return null; + } + @Override public MemoryObject getSensoryMemory() { return null; @@ -110,6 +116,9 @@ public class MemoryRefTest { public Object getActionConfigValue(String key) { return null; } + + @Override + public void close() throws Exception {} } @BeforeEach
