This is an automated email from the ASF dual-hosted git repository. wenjin272 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 3a828071ba61687f46a0e482df1a1fc7de217316 Author: WenjinXie <[email protected]> AuthorDate: Tue May 5 01:37:10 2026 +0800 [api][runtime] Implement Mem0LongTermMemory on the Java side Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> --- .../agents/api/memory/BaseLongTermMemory.java | 101 +++----- .../agents/api/memory/LongTermMemoryOptions.java | 35 +-- .../apache/flink/agents/api/memory/MemorySet.java | 120 ++++------ .../flink/agents/api/memory/MemorySetItem.java | 64 ++--- .../api/memory/compaction/CompactionConfig.java | 92 -------- .../flink/agents/api/memory/MemorySetTest.java | 10 +- .../pom.xml | 11 + .../resource/test/Mem0LongTermMemoryAgent.java | 260 +++++++++++++++++++++ .../resource/test/Mem0LongTermMemoryTest.java | 182 +++++++++++++++ .../apache/flink/agents/plan/actions/Utils.java | 3 + python/flink_agents/runtime/python_java_utils.py | 37 +++ runtime/pom.xml | 8 + .../agents/runtime/context/RunnerContextImpl.java | 16 +- .../agents/runtime/memory/Mem0LongTermMemory.java | 182 +++++++++++++++ .../runtime/operator/ActionExecutionOperator.java | 67 +++++- .../runtime/python/utils/PythonActionExecutor.java | 4 + .../runtime/memory/Mem0LongTermMemoryTest.java | 240 +++++++++++++++++++ 17 files changed, 1107 insertions(+), 325 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java b/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java index 9a303ce7..4e35e490 100644 --- a/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java +++ b/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java @@ -17,117 +17,88 @@ */ package org.apache.flink.agents.api.memory; -import org.apache.flink.agents.api.memory.compaction.CompactionConfig; - import javax.annotation.Nullable; import java.util.List; import java.util.Map; /** - * Base interface for long-term memory management. 
It provides operations to create, retrieve, - * delete, and search memory sets, which are collections of memory items. A memory set can store - * items of a specific type (e.g., String or ChatMessage) and has a capacity limit. When the - * capacity is exceeded, compaction will be applied to manage the memory set size. + * Base interface for long-term memory management. Provides operations to create, retrieve, delete, + * and search memory sets, which are named collections of memory items. */ public interface BaseLongTermMemory extends AutoCloseable { /** - * Gets an existing memory set or creates a new one if it doesn't exist. + * Gets the memory set by name. If it does not exist, the backend creates it. * * @param name the name of the memory set - * @param itemType the type of items stored in the memory set - * @param capacity the maximum number of items the memory set can hold - * @param compactionConfig the compaction config to use when the capacity is exceeded - * @return the existing or newly created memory set - * @throws Exception if the memory set cannot be created or retrieved - */ - MemorySet getOrCreateMemorySet( - String name, Class<?> itemType, int capacity, CompactionConfig compactionConfig) - throws Exception; - - /** - * Gets an existing memory set by name. - * - * @param name the name of the memory set to retrieve - * @return the memory set with the given name - * @throws Exception if the memory set does not exist or cannot be retrieved + * @return the memory set */ MemorySet getMemorySet(String name) throws Exception; /** - * Deletes a memory set by name. + * Deletes the memory set. * * @param name the name of the memory set to delete - * @return true if the memory set was successfully deleted, false if it didn't exist - * @throws Exception if the deletion operation fails + * @return true if the memory set was successfully deleted */ boolean deleteMemorySet(String name) throws Exception; /** - * Gets the number of items in the memory set. 
- * - * @param memorySet the memory set to count items in - * @return the number of items in the memory set - * @throws Exception if the size cannot be determined - */ - long size(MemorySet memorySet) throws Exception; - - /** - * Adds items to the memory set. If IDs are not provided, they will be automatically generated. - * This method may trigger compaction if the memory set capacity is exceeded. + * Adds items to the memory set. The backend may auto-generate IDs. * * @param memorySet the memory set to add items to - * @param memoryItems the items to be added to the memory set - * @param ids optional list of IDs for the items. If null or shorter than memoryItems, IDs will - * be auto-generated for missing items - * @param metadatas optional list of metadata maps for the items. Each metadata map corresponds - * to an item at the same index + * @param memoryItems the items to add + * @param metadatas optional list of metadata maps, one per item * @return list of IDs of the added items - * @throws Exception if items cannot be added to the memory set */ List<String> add( MemorySet memorySet, - List<?> memoryItems, - @Nullable List<String> ids, + List<String> memoryItems, @Nullable List<Map<String, Object>> metadatas) throws Exception; /** - * Retrieves memory items from the memory set. If no IDs are provided, all items in the memory - * set are returned. + * Retrieves memory items. When {@code ids} is provided, {@code filters} and {@code limit} are + * ignored. * - * @param memorySet the memory set to retrieve items from - * @param ids optional list of item IDs to retrieve. If null, all items are returned - * @return list of memory set items. If ids is provided, returns items matching those IDs. 
If - * ids is null, returns all items in the memory set - * @throws Exception if items cannot be retrieved from the memory set + * @param memorySet the memory set to retrieve from + * @param ids optional list of item IDs to retrieve + * @param filters optional metadata filters + * @param limit maximum number of items to return; defaults to 100 when {@code null} + * @return list of matching memory items */ - List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) throws Exception; + List<MemorySetItem> get( + MemorySet memorySet, + @Nullable List<String> ids, + @Nullable Map<String, Object> filters, + @Nullable Integer limit) + throws Exception; /** - * Deletes memory items from the memory set. If no IDs are provided, all items in the memory set - * are deleted. + * Deletes memory items. If {@code ids} is null, all items in the set are deleted. * * @param memorySet the memory set to delete items from - * @param ids optional list of item IDs to delete. If null, all items in the memory set are - * deleted - * @throws Exception if items cannot be deleted from the memory set + * @param ids optional list of item IDs to delete */ void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception; /** - * Performs semantic search on the memory set to find items related to the query string. + * Performs semantic search on the memory set. 
* * @param memorySet the memory set to search in * @param query the query string for semantic search - * @param limit the maximum number of items to return - * @param extraArgs additional arguments for the search operation (e.g., filters, distance - * metrics) - * @return list of memory set items that are most relevant to the query, ordered by relevance - * @throws Exception if the search operation fails + * @param limit maximum number of items to return + * @param filters optional metadata filters + * @param extraArgs backend-specific extra arguments forwarded as keyword arguments to the + * underlying search call (mirrors Python's {@code **kwargs}) + * @return list of memory items most relevant to the query, ordered by relevance */ List<MemorySetItem> search( - MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs) + MemorySet memorySet, + String query, + int limit, + @Nullable Map<String, Object> filters, + Map<String, Object> extraArgs) throws Exception; } diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java b/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java index caafeb82..85717cb9 100644 --- a/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java @@ -19,34 +19,19 @@ package org.apache.flink.agents.api.memory; import org.apache.flink.agents.api.configuration.ConfigOption; +/** Config options for long-term memory. */ public class LongTermMemoryOptions { - public enum LongTermMemoryBackend { - EXTERNAL_VECTOR_STORE("external_vector_store"); - private final String value; + /** Config options for the Mem0-based long-term memory backend. 
*/ + public static class Mem0 { + public static final ConfigOption<String> CHAT_MODEL_SETUP = + new ConfigOption<>("long-term-memory.mem0.chat-model-setup", String.class, null); - LongTermMemoryBackend(String value) { - this.value = value; - } + public static final ConfigOption<String> EMBEDDING_MODEL_SETUP = + new ConfigOption<>( + "long-term-memory.mem0.embedding-model-setup", String.class, null); - public String getValue() { - return value; - } + public static final ConfigOption<String> VECTOR_STORE = + new ConfigOption<>("long-term-memory.mem0.vector-store", String.class, null); } - - /** The backend for long-term memory. */ - public static final ConfigOption<LongTermMemoryBackend> BACKEND = - new ConfigOption<>("long-term-memory.backend", LongTermMemoryBackend.class, null); - - /** The name of the vector store to server as the backend for long-term memory. */ - public static final ConfigOption<String> EXTERNAL_VECTOR_STORE_NAME = - new ConfigOption<>("long-term-memory.external-vector-store-name", String.class, null); - - /** Whether execute compaction asynchronously . */ - public static final ConfigOption<Boolean> ASYNC_COMPACTION = - new ConfigOption<>("long-term-memory.async-compaction", Boolean.class, true); - - /** The thread count of executor for async compaction. 
*/ - public static final ConfigOption<Integer> THREAD_COUNT = - new ConfigOption<>("long-term-memory.async-compaction.thread-count", Integer.class, 16); } diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java index 9e827dbd..2cfdb87d 100644 --- a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java +++ b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java @@ -20,91 +20,80 @@ package org.apache.flink.agents.api.memory; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; -import org.apache.flink.agents.api.memory.compaction.CompactionConfig; import javax.annotation.Nullable; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +/** + * Represents a long term memory set, a named collection of memory items. Acts as a thin proxy that + * delegates all operations to the bound {@link BaseLongTermMemory}. + */ public class MemorySet { private final String name; - private final Class<?> itemType; - private final int capacity; - private final CompactionConfig compactionConfig; private @JsonIgnore BaseLongTermMemory ltm; @JsonCreator - public MemorySet( - @JsonProperty("name") String name, - @JsonProperty("itemType") Class<?> itemType, - @JsonProperty("capacity") int capacity, - @JsonProperty("compactionConfig") CompactionConfig compactionConfig) { + public MemorySet(@JsonProperty("name") String name) { this.name = name; - this.itemType = itemType; - this.capacity = capacity; - this.compactionConfig = compactionConfig; } /** - * Gets the number of items in this memory set. + * Adds items to this memory set. The backend may auto-generate IDs. 
* - * @return the number of items in the memory set - * @throws Exception if the size cannot be determined + * @param memoryItems the items to add + * @param metadatas optional list of metadata maps, one per item + * @return list of IDs of the added items */ - public long size() throws Exception { - return this.ltm.size(this); + public List<String> add(List<String> memoryItems, @Nullable List<Map<String, Object>> metadatas) + throws Exception { + return this.ltm.add(this, memoryItems, metadatas); } /** - * Adds items to this memory set. If IDs are not provided, they will be automatically generated. - * This method may trigger compaction if the memory set capacity is exceeded. + * Retrieves memory items. When {@code ids} is provided, {@code filters} and {@code limit} are + * ignored. * - * @param memoryItems the items to be added to the memory set - * @param ids optional list of IDs for the items. If null or shorter than memoryItems, IDs will - * be auto-generated for missing items - * @param metadatas optional list of metadata maps for the items. Each metadata map corresponds - * to an item at the same index - * @return list of IDs of the added items - * @throws Exception if items cannot be added to the memory set + * @param ids optional list of item IDs to retrieve + * @param filters optional metadata filters + * @param limit maximum number of items to return + * @return list of matching memory items */ - public List<String> add( - List<?> memoryItems, + public List<MemorySetItem> get( @Nullable List<String> ids, - @Nullable List<Map<String, Object>> metadatas) + @Nullable Map<String, Object> filters, + @Nullable Integer limit) throws Exception { - return this.ltm.add(this, memoryItems, ids, metadatas); + return this.ltm.get(this, ids, filters, limit); } /** - * Retrieves memory items from this memory set. If no IDs are provided, all items in the memory - * set are returned. + * Performs semantic search on this memory set. 
* - * @param ids optional list of item IDs to retrieve. If null, all items are returned - * @return list of memory set items. If ids is provided, returns items matching those IDs. If - * ids is null, returns all items in the memory set - * @throws Exception if items cannot be retrieved from the memory set + * @param query the query string for semantic search + * @param limit maximum number of items to return + * @param filters optional metadata filters + * @param extraArgs backend-specific extra arguments; pass an empty map when none are needed + * @return list of memory items most relevant to the query, ordered by relevance */ - public List<MemorySetItem> get(@Nullable List<String> ids) throws Exception { - return this.ltm.get(this, ids); + public List<MemorySetItem> search( + String query, + int limit, + @Nullable Map<String, Object> filters, + Map<String, Object> extraArgs) + throws Exception { + return this.ltm.search(this, query, limit, filters, extraArgs); } /** - * Performs semantic search on this memory set to find items related to the query string. + * Deletes memory items. If {@code ids} is null, all items in the set are deleted. * - * @param query the query string for semantic search - * @param limit the maximum number of items to return - * @param extraArgs optional additional arguments for the search operation (e.g., filters, - * distance metrics). If null, an empty map is used - * @return list of memory set items that are most relevant to the query, ordered by relevance - * @throws Exception if the search operation fails + * @param ids optional list of item IDs to delete */ - public List<MemorySetItem> search( - String query, int limit, @Nullable Map<String, Object> extraArgs) throws Exception { - return this.ltm.search( - this, query, limit, extraArgs == null ? 
Collections.emptyMap() : extraArgs); + public void delete(@Nullable List<String> ids) throws Exception { + this.ltm.delete(this, ids); } public void setLtm(BaseLongTermMemory ltm) { @@ -115,45 +104,20 @@ public class MemorySet { return name; } - public Class<?> getItemType() { - return itemType; - } - - public int getCapacity() { - return capacity; - } - - public CompactionConfig getCompactionConfig() { - return compactionConfig; - } - @Override public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; MemorySet memorySet = (MemorySet) o; - return capacity == memorySet.capacity - && Objects.equals(name, memorySet.name) - && Objects.equals(itemType, memorySet.itemType) - && Objects.equals(compactionConfig, memorySet.compactionConfig); + return Objects.equals(name, memorySet.name); } @Override public int hashCode() { - return Objects.hash(name, itemType, capacity, compactionConfig); + return Objects.hash(name); } @Override public String toString() { - return "MemorySet{" - + "name='" - + name - + '\'' - + ", itemType=" - + itemType - + ", capacity=" - + capacity - + ", compactionConfig=" - + compactionConfig - + '}'; + return "MemorySet{name='" + name + "'}"; } } diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java index b822fbc0..64183f94 100644 --- a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java +++ b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java @@ -17,33 +17,33 @@ */ package org.apache.flink.agents.api.memory; +import javax.annotation.Nullable; + import java.time.LocalDateTime; import java.util.Map; +/** Represents a long term memory item. 
*/ public class MemorySetItem { private final String memorySetName; private final String id; - private final Object value; - private final boolean compacted; - private final Object createdTime; - private final LocalDateTime lastAccessedTime; - private final Map<String, Object> metadata; + private final String value; + private final @Nullable LocalDateTime createdAt; + private final @Nullable LocalDateTime updatedAt; + private final @Nullable Map<String, Object> additionalMetadata; public MemorySetItem( String memorySetName, String id, - Object value, - boolean compacted, - Object createdTime, - LocalDateTime lastAccessedTime, - Map<String, Object> metadata) { + String value, + @Nullable LocalDateTime createdAt, + @Nullable LocalDateTime updatedAt, + @Nullable Map<String, Object> additionalMetadata) { this.memorySetName = memorySetName; this.id = id; this.value = value; - this.compacted = compacted; - this.createdTime = createdTime; - this.lastAccessedTime = lastAccessedTime; - this.metadata = metadata; + this.createdAt = createdAt; + this.updatedAt = updatedAt; + this.additionalMetadata = additionalMetadata; } public String getMemorySetName() { @@ -54,41 +54,19 @@ public class MemorySetItem { return id; } - public Object getValue() { + public String getValue() { return value; } - public boolean isCompacted() { - return compacted; - } - - public Object getCreatedTime() { - return createdTime; - } - - public LocalDateTime getLastAccessedTime() { - return lastAccessedTime; + public @Nullable LocalDateTime getCreatedAt() { + return createdAt; } - public Map<String, Object> getMetadata() { - return metadata; + public @Nullable LocalDateTime getUpdatedAt() { + return updatedAt; } - public static class DateTimeRange { - private final LocalDateTime start; - private final LocalDateTime end; - - public DateTimeRange(LocalDateTime start, LocalDateTime end) { - this.start = start; - this.end = end; - } - - public LocalDateTime getStart() { - return start; - } - - public 
LocalDateTime getEnd() { - return end; - } + public @Nullable Map<String, Object> getAdditionalMetadata() { + return additionalMetadata; } } diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionConfig.java b/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionConfig.java deleted file mode 100644 index fa399184..00000000 --- a/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionConfig.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.agents.api.memory.compaction; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import javax.annotation.Nullable; - -import java.util.Objects; - -/** Configuration for long-term memory compaction. 
*/ -public class CompactionConfig { - private final String model; - - @JsonTypeInfo( - use = JsonTypeInfo.Id.CLASS, - include = JsonTypeInfo.As.PROPERTY, - property = "@class") - private final Object prompt; - - private final int limit; - - public CompactionConfig(String model, int limit) { - this(model, null, limit); - } - - @JsonCreator - public CompactionConfig( - @JsonProperty("model") String model, - @Nullable @JsonProperty("prompt") Object prompt, - @JsonProperty("limit") int limit) { - this.model = model; - this.prompt = prompt; - this.limit = limit; - } - - public String getModel() { - return model; - } - - public Object getPrompt() { - return prompt; - } - - public int getLimit() { - return limit; - } - - @Override - public boolean equals(Object o) { - if (o == null || getClass() != o.getClass()) return false; - CompactionConfig that = (CompactionConfig) o; - return limit == that.limit - && Objects.equals(model, that.model) - && Objects.equals(prompt, that.prompt); - } - - @Override - public int hashCode() { - return Objects.hash(model, prompt, limit); - } - - @Override - public String toString() { - return "CompactionConfig{" - + "model='" - + model - + '\'' - + ", prompt=" - + prompt - + ", limit=" - + limit - + '}'; - } -} diff --git a/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java b/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java index dc5f7cb9..2714adc7 100644 --- a/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java @@ -18,9 +18,6 @@ package org.apache.flink.agents.api.memory; import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.flink.agents.api.chat.messages.ChatMessage; -import org.apache.flink.agents.api.memory.compaction.CompactionConfig; -import org.apache.flink.agents.api.prompt.Prompt; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -28,12 +25,7 @@ 
public class MemorySetTest { @Test public void testJsonSerialization() throws Exception { ObjectMapper mapper = new ObjectMapper(); - MemorySet memorySet = - new MemorySet( - "test", - ChatMessage.class, - 100, - new CompactionConfig("testModel", Prompt.fromText("Test prompt"), 100)); + MemorySet memorySet = new MemorySet("test"); String jsonValue = mapper.writeValueAsString(memorySet); MemorySet deserialized = mapper.readValue(jsonValue, MemorySet.class); diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml index f61d91ad..2d19a8e1 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/pom.xml @@ -35,6 +35,17 @@ <artifactId>flink-agents-integrations-embedding-models-ollama</artifactId> <version>${project.version}</version> </dependency> + <!-- Required by Mem0LongTermMemoryTest: OpenAI-compatible chat model + ES vector store. 
--> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-agents-integrations-chat-models-openai</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-agents-integrations-vector-stores-elasticsearch</artifactId> + <version>${project.version}</version> + </dependency> <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-streaming-java</artifactId> diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java new file mode 100644 index 00000000..449a08de --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryAgent.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.resource.test; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.ChatModelConnection; +import org.apache.flink.agents.api.annotation.ChatModelSetup; +import org.apache.flink.agents.api.annotation.EmbeddingModelConnection; +import org.apache.flink.agents.api.annotation.EmbeddingModelSetup; +import org.apache.flink.agents.api.annotation.VectorStore; +import org.apache.flink.agents.api.context.DurableCallable; +import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceName; +import org.apache.flink.api.java.functions.KeySelector; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * E2E test agent that mirrors the Python {@code long_term_memory_test.py}: streams {@link + * ItemData}, asynchronously appends each fact to a Mem0-backed long-term memory under a per-name + * key, and emits a per-record {@link OutputEvent} with the timestamps and (when applicable) the + * full retrieved item set. + * + * <p>All resources are declared as native Java implementations (Ollama chat / embedding, + * Elasticsearch vector store). 
Python's mem0 adapter consumes them through the cross-language + * bridge: {@code ctx.get_resource(name, type)} on the Python side returns a Java*Impl wrapper that + * delegates back into Java via Pemja. + * + * <p>The test driving this agent must (1) pull the Ollama models and (2) provide ES connection env + * vars ({@code ES_HOST}, {@code ES_INDEX}, {@code ES_DIMS}, {@code ES_VECTOR_FIELD}, optional + * {@code ES_USERNAME}/{@code ES_PASSWORD}); see {@link Mem0LongTermMemoryTest}. + */ +public class Mem0LongTermMemoryAgent extends Agent { + + public static final String CHAT_MODEL = "qwen3.6-plus"; + public static final String OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; + public static final String MEMORY_SET_NAME = "test_ltm"; + + /** Mirrors the Python e2e: dashscope-hosted OpenAI-compatible endpoint, env-overridable. */ + private static final String DEFAULT_BASE_URL = "https://coding.dashscope.aliyuncs.com/v1"; + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + /** Per-name fact emitted by the test source. */ + public static class ItemData { + public String name; + public String fact; + + public ItemData() {} + + public ItemData(String name, String fact) { + this.name = name; + this.fact = fact; + } + } + + /** KeySelector — partitions the stream by username. */ + public static class ItemDataKeySelector implements KeySelector<ItemData, String> { + @Override + public String getKey(ItemData value) { + return value.name; + } + } + + /** Carrier event — passes the in-progress {@code Record} between the two actions. 
*/ + public static class MyEvent extends Event { + private Map<String, Object> value; + + public MyEvent() {} + + public MyEvent(Map<String, Object> value) { + this.value = value; + } + + public Map<String, Object> getValue() { + return value; + } + + public void setValue(Map<String, Object> value) { + this.value = value; + } + } + + @ChatModelConnection + public static ResourceDescriptor openaiConnection() { + String baseUrl = System.getenv().getOrDefault("ACTION_BASE_URL", DEFAULT_BASE_URL); + String apiKey = System.getenv("ACTION_API_KEY"); + return ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_CONNECTION) + .addInitialArgument("api_key", apiKey) + .addInitialArgument("api_base_url", baseUrl) + .addInitialArgument("request_timeout", 300) + .build(); + } + + @ChatModelSetup + public static ResourceDescriptor openaiQwen3() { + return ResourceDescriptor.Builder.newBuilder( + ResourceName.ChatModel.OPENAI_COMPLETIONS_SETUP) + .addInitialArgument("connection", "openaiConnection") + .addInitialArgument("model", CHAT_MODEL) + .addInitialArgument("extract_reasoning", true) + .addInitialArgument("think", false) + .build(); + } + + @EmbeddingModelConnection + public static ResourceDescriptor ollamaEmbeddingConnection() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.EmbeddingModel.OLLAMA_CONNECTION) + .addInitialArgument("host", "http://localhost:11434") + .addInitialArgument("timeout", 240) + .build(); + } + + @EmbeddingModelSetup + public static ResourceDescriptor ollamaNomicEmbedText() { + return ResourceDescriptor.Builder.newBuilder(ResourceName.EmbeddingModel.OLLAMA_SETUP) + .addInitialArgument("connection", "ollamaEmbeddingConnection") + .addInitialArgument("model", OLLAMA_EMBEDDING_MODEL) + .build(); + } + + @VectorStore + public static ResourceDescriptor esLtmStore() { + ResourceDescriptor.Builder builder = + ResourceDescriptor.Builder.newBuilder( + ResourceName.VectorStore.ELASTICSEARCH_VECTOR_STORE) + 
.addInitialArgument("embedding_model", "ollamaNomicEmbedText") + .addInitialArgument("host", System.getenv("ES_HOST")) + .addInitialArgument( + "collection", + UUID.randomUUID().toString().substring(0, 8) + "-context"); + String username = System.getenv("ES_USERNAME"); + String password = System.getenv("ES_PASSWORD"); + if (username != null && password != null) { + builder.addInitialArgument("username", username) + .addInitialArgument("password", password); + } + return builder.build(); + } + + @Action(listenEvents = {InputEvent.class}) + public static void addItems(InputEvent event, RunnerContext ctx) throws Exception { + ItemData input = (ItemData) event.getInput(); + BaseLongTermMemory ltm = ctx.getLongTermMemory(); + + String timestampBeforeAdd = Instant.now().toString(); + MemorySet memorySet = ltm.getMemorySet(MEMORY_SET_NAME); + ctx.durableExecuteAsync( + new DurableCallable<Void>() { + @Override + public String getId() { + return "mem0-add-" + input.name; + } + + @Override + public Class<Void> getResultClass() { + return Void.class; + } + + @Override + public Void call() throws Exception { + memorySet.add(List.of(input.fact), null); + return null; + } + }); + String timestampAfterAdd = Instant.now().toString(); + + MemoryObject countObj = ctx.getShortTermMemory().get("count"); + int count = countObj == null ? 
1 : ((Number) countObj.getValue()).intValue(); + ctx.getShortTermMemory().set("count", count + 1); + + Map<String, Object> record = new HashMap<>(); + record.put("name", input.name); + record.put("count", count); + record.put("timestamp_before_add", timestampBeforeAdd); + record.put("timestamp_after_add", timestampAfterAdd); + ctx.sendEvent(new MyEvent(record)); + } + + @Action(listenEvents = {MyEvent.class}) + public static void retrieveItems(MyEvent event, RunnerContext ctx) throws Exception { + Map<String, Object> record = event.getValue(); + record.put("timestamp_second_action", Instant.now().toString()); + + MemorySet memorySet = ctx.getLongTermMemory().getMemorySet(MEMORY_SET_NAME); + String name = (String) record.get("name"); + int count = ((Number) record.get("count")).intValue(); + @SuppressWarnings({"unchecked", "rawtypes"}) + List<MemorySetItem> items = + ctx.durableExecuteAsync( + new DurableCallable<List>() { + @Override + public String getId() { + return "mem0-get-" + name + "-" + count; + } + + @Override + public Class<List> getResultClass() { + return List.class; + } + + @Override + public List<MemorySetItem> call() throws Exception { + return memorySet.get(null, null, null); + } + }); + if (("alice".equals(name) || "bob".equals(name)) && count == 2) { + // Serialise the items as a JSON string. Embedding the raw List<Map<String,Object>> + // here trips up Flink's Kryo deep-copy when chained operators forward the record: + // CollectionSerializer.copy -> ArrayList.add(NPE) + // Stringifying sidesteps the Kryo path entirely (the test parses it back). + List<Map<String, Object>> serialised = new ArrayList<>(); + for (MemorySetItem item : items) { + Map<String, Object> m = new HashMap<>(); + m.put("id", item.getId()); + m.put("value", item.getValue()); + m.put( + "created_at", + item.getCreatedAt() == null ? null : item.getCreatedAt().toString()); + m.put( + "updated_at", + item.getUpdatedAt() == null ? 
null : item.getUpdatedAt().toString()); + serialised.add(m); + } + record.put("items_json", MAPPER.writeValueAsString(serialised)); + } + ctx.sendEvent(new OutputEvent(record)); + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java new file mode 100644 index 00000000..d11b7dfe --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/Mem0LongTermMemoryTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.resource.test; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.AgentsExecutionEnvironment; +import org.apache.flink.agents.api.configuration.AgentConfigOptions; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.time.Instant; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.Mem0LongTermMemoryAgent.OLLAMA_EMBEDDING_MODEL; + +/** + * End-to-end test for {@link org.apache.flink.agents.runtime.memory.Mem0LongTermMemory}, mirroring + * the Python {@code long_term_memory_test.py}. Streams four facts (alice/bob interleaved) through + * {@link Mem0LongTermMemoryAgent}, exercising the full cross-language path: Java agent → Java + * Mem0LongTermMemory wrapper → Python mem0 instance → mem0 adapter calling back into Java for chat + * / embedding / vector store work. 
+ * + * <p>Skipped when any prerequisite is missing: + * + * <ul> + * <li>Ollama daemon serving {@code nomic-embed-text} (the embedding model) + * <li>{@code ACTION_API_KEY} env var (and optionally {@code ACTION_BASE_URL}) for the + * OpenAI-compatible chat model — mirrors the Python e2e test's setup + * <li>{@code python} on PATH with {@code mem0ai} and {@code flink_agents} installed + * <li>Elasticsearch reachable via the {@code ES_HOST} env var + * </ul> + */ +public class Mem0LongTermMemoryTest { + + private final boolean embeddingReady; + private final boolean pythonReady; + private final boolean esConfigured; + private final boolean apiKeySet; + + public Mem0LongTermMemoryTest() throws IOException { + embeddingReady = pullModel(OLLAMA_EMBEDDING_MODEL); + pythonReady = isPythonAvailable(); + esConfigured = System.getenv("ES_HOST") != null; + apiKeySet = System.getenv("ACTION_API_KEY") != null; + } + + @Test + @Disabled("Using mem0 in java depends on the pemja fix.") + public void testMem0LongTermMemory() throws Exception { + Assumptions.assumeTrue( + embeddingReady, + "Ollama is not reachable or the embedding model could not be pulled"); + Assumptions.assumeTrue( + pythonReady, + "`python` executable not found on PATH; this test requires Python with mem0ai installed"); + Assumptions.assumeTrue(esConfigured, "Elasticsearch env var (ES_HOST) is not set"); + Assumptions.assumeTrue( + apiKeySet, + "ACTION_API_KEY env var is not set; required for the OpenAI-compatible chat model"); + + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + + DataStream<Mem0LongTermMemoryAgent.ItemData> inputStream = + env.fromElements( + new Mem0LongTermMemoryAgent.ItemData( + "alice", "My favorite fruit is watermelon."), + new Mem0LongTermMemoryAgent.ItemData("bob", "I like swimming."), + new Mem0LongTermMemoryAgent.ItemData( + "bob", "I'm a vegetarian and allergic to nuts."), + new Mem0LongTermMemoryAgent.ItemData( + 
"alice", "My favorite fruit is bananas.")); + + AgentsExecutionEnvironment agentsEnv = + AgentsExecutionEnvironment.getExecutionEnvironment(env); + agentsEnv.getConfig().set(AgentConfigOptions.JOB_IDENTIFIER, "LTM_TEST_JOB"); + agentsEnv.getConfig().set(LongTermMemoryOptions.Mem0.CHAT_MODEL_SETUP, "openaiQwen3"); + agentsEnv + .getConfig() + .set(LongTermMemoryOptions.Mem0.EMBEDDING_MODEL_SETUP, "ollamaNomicEmbedText"); + agentsEnv.getConfig().set(LongTermMemoryOptions.Mem0.VECTOR_STORE, "esLtmStore"); + + DataStream<Object> outputStream = + agentsEnv + .fromDataStream( + inputStream, new Mem0LongTermMemoryAgent.ItemDataKeySelector()) + .apply(new Mem0LongTermMemoryAgent()) + .toDataStream(); + + CloseableIterator<Object> results = outputStream.collectAsync(); + agentsEnv.execute(); + + checkResult(results); + } + + private static boolean isPythonAvailable() { + try { + Process p = new ProcessBuilder("python", "--version").start(); + return p.waitFor() == 0; + } catch (Exception e) { + return false; + } + } + + @SuppressWarnings("unchecked") + private void checkResult(CloseableIterator<Object> results) throws Exception { + Map<String, Map<String, Object>> records = new HashMap<>(); + for (int i = 0; i < 4; i++) { + Assertions.assertTrue( + results.hasNext(), "Expected 4 records, only got " + records.size()); + Map<String, Object> record = (Map<String, Object>) results.next(); + String name = (String) record.get("name"); + int count = ((Number) record.get("count")).intValue(); + records.put(name + "." + count, record); + } + results.close(); + + // alice's second pass must contain the items list, with mem0 having merged + // the watermelon/banana facts into a single entry that mentions bananas. 
+ Map<String, Object> aliceTwo = records.get("alice.2"); + Assertions.assertNotNull(aliceTwo, "Missing alice.2 record"); + String itemsJson = (String) aliceTwo.get("items_json"); + Assertions.assertNotNull(itemsJson, "alice.2 record must carry the items_json payload"); + List<Map<String, Object>> items = + new ObjectMapper() + .readValue(itemsJson, new TypeReference<List<Map<String, Object>>>() {}); + Assertions.assertEquals(1, items.size(), "Expected mem0 to merge alice's facts into one"); + + Map<String, Object> item = items.get(0); + String value = (String) item.get("value"); + Assertions.assertTrue( + value.toLowerCase().contains("banana"), + "alice's surviving item should reflect the bananas update, was: " + value); + + String createdAt = (String) item.get("created_at"); + String updatedAt = (String) item.get("updated_at"); + Assertions.assertNotNull(createdAt, "created_at must be populated"); + Assertions.assertNotNull(updatedAt, "updated_at must be populated"); + // The agent serialises these from MemorySetItem's LocalDateTime fields, so they have + // no trailing 'Z'/offset and must be parsed as LocalDateTime, not Instant. + Assertions.assertTrue( + LocalDateTime.parse(createdAt).isBefore(LocalDateTime.parse(updatedAt)), + "updated_at should be strictly after created_at when mem0 merged the facts"); + + // Async add must not block other keys: alice.1 starts before bob.1 finishes. 
+ Map<String, Object> aliceOne = records.get("alice.1"); + Map<String, Object> bobOne = records.get("bob.1"); + Assertions.assertNotNull(aliceOne, "Missing alice.1 record"); + Assertions.assertNotNull(bobOne, "Missing bob.1 record"); + Assertions.assertTrue( + Instant.parse((String) aliceOne.get("timestamp_before_add")) + .isBefore(Instant.parse((String) bobOne.get("timestamp_after_add"))), + "alice.1 should start its add before bob.1 finishes (async non-blocking)"); + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java index 8789b292..1f1a3c66 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/Utils.java @@ -22,12 +22,15 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; +import java.util.List; import java.util.Properties; import java.util.StringTokenizer; public final class Utils { private static final Logger LOG = LoggerFactory.getLogger(Utils.class); private static final String DEFAULT_VALUE = "<unknown>"; + public static final List<String> requiredVersions = + List.of("1.20.3", "2.0.1", "2.1.1", "2.2.0"); static final Versions INSTANCE = new Versions(); diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index dd3a61ea..a7ab6f09 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -24,6 +24,7 @@ import cloudpickle from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.events.event import Event, InputEvent +from flink_agents.api.memory.long_term_memory import MemorySet, MemorySetItem from flink_agents.api.resource import Resource, ResourceType, get_resource_class from flink_agents.api.tools.tool import Tool, ToolMetadata from flink_agents.api.tools.utils import ( @@ -292,6 
+293,42 @@ def get_mode_value(query: VectorStoreQuery) -> str: return query.mode.value +def get_long_term_memory(ctx: Any) -> Any: + """Return ``ctx.long_term_memory`` (or ``None``). Used by the Java side to + avoid relying on Pemja's ``PyObject.getAttr`` semantics for attributes that + may be ``None`` or wrapped Pydantic BaseModel instances. + """ + return ctx.long_term_memory + + +def to_python_memory_set(name: str) -> MemorySet: + """Build a Python ``MemorySet`` from its name. Used by the Java + ``Mem0LongTermMemory`` wrapper to forward calls into Python ``Mem0LongTermMemory``, + which expects a ``MemorySet`` instance but only reads its ``name`` field. + """ + return MemorySet(name=name) + + +def mem0_items_to_java( + items: typing.List[MemorySetItem], +) -> typing.List[Dict[str, Any]]: + """Convert a list of ``MemorySetItem`` to plain dicts so the Java side can + consume them without PyObject reflection. Datetimes are serialised to ISO 8601 + strings; ``None`` fields are preserved. + """ + return [ + { + "memory_set_name": it.memory_set_name, + "id": it.id, + "value": it.value, + "created_at": it.created_at.isoformat() if it.created_at else None, + "updated_at": it.updated_at.isoformat() if it.updated_at else None, + "additional_metadata": it.additional_metadata, + } + for it in items + ] + + def call_method(obj: Any, method_name: str, kwargs: Dict[str, Any]) -> Any: """Calls a method on `obj` by name and passes in positional and keyword arguments. diff --git a/runtime/pom.xml b/runtime/pom.xml index 384eeea8..39679389 100644 --- a/runtime/pom.xml +++ b/runtime/pom.xml @@ -141,6 +141,14 @@ under the License. <version>${log4j2.version}</version> <scope>provided</scope> </dependency> + <!-- Required so RunnerContextImpl's ObjectMapper can serialise the + Java 8 date/time types that show up in DurableCallable results + (e.g. MemorySetItem.createdAt / updatedAt). 
--> + <dependency> + <groupId>com.fasterxml.jackson.datatype</groupId> + <artifactId>jackson-datatype-jsr310</artifactId> + <version>${jackson.version}</version> + </dependency> </dependencies> <build> diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index c479c3a7..b0603ecb 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.context; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.configuration.ReadableConfiguration; import org.apache.flink.agents.api.context.DurableCallable; @@ -27,7 +28,6 @@ import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.memory.BaseLongTermMemory; -import org.apache.flink.agents.api.memory.LongTermMemoryOptions; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; @@ -39,7 +39,6 @@ import org.apache.flink.agents.runtime.actionstate.CallResult; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; import org.apache.flink.agents.runtime.memory.InteranlBaseLongTermMemory; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; -import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import 
org.apache.flink.util.Preconditions; import org.slf4j.Logger; @@ -60,7 +59,8 @@ import java.util.concurrent.Callable; */ public class RunnerContextImpl implements RunnerContext { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final ObjectMapper OBJECT_MAPPER = + new ObjectMapper().registerModule(new JavaTimeModule()); public static class MemoryContext { private final CachedMemoryStore sensoryMemStore; @@ -118,14 +118,10 @@ public class RunnerContextImpl implements RunnerContext { this.mailboxThreadChecker = mailboxThreadChecker; this.agentPlan = agentPlan; this.resourceCache = resourceCache; + } - LongTermMemoryOptions.LongTermMemoryBackend backend = - this.getConfig().get(LongTermMemoryOptions.BACKEND); - if (backend == LongTermMemoryOptions.LongTermMemoryBackend.EXTERNAL_VECTOR_STORE) { - String vectorStoreName = - this.getConfig().get(LongTermMemoryOptions.EXTERNAL_VECTOR_STORE_NAME); - ltm = new VectorStoreLongTermMemory(this, vectorStoreName, jobIdentifier); - } + public void setLongTermMemory(InteranlBaseLongTermMemory ltm) { + this.ltm = ltm; } public void switchActionContext(String actionName, MemoryContext memoryContext, String key) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemory.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemory.java new file mode 100644 index 00000000..2847812a --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemory.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import pemja.core.object.PyObject; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Java-side thin wrapper around the Python {@code Mem0LongTermMemory} instance. All public methods + * forward to Python via {@link PythonResourceAdapter#callMethod}; Python-only return types ({@code + * MemorySetItem}) are converted into Java POJOs by the {@code mem0_items_to_java} helper in {@code + * python_java_utils.py}. + */ +public class Mem0LongTermMemory implements InteranlBaseLongTermMemory { + + private static final String TO_PYTHON_MEMORY_SET = "python_java_utils.to_python_memory_set"; + private static final String MEM0_ITEMS_TO_JAVA = "python_java_utils.mem0_items_to_java"; + + private final PythonResourceAdapter adapter; + private final PyObject pyMem0; + + public Mem0LongTermMemory(PythonResourceAdapter adapter, PyObject pyMem0) { + this.adapter = adapter; + this.pyMem0 = pyMem0; + } + + @Override + public MemorySet getMemorySet(String name) { + // Mirrors Python's `Mem0LongTermMemory.get_memory_set`: a pure factory that + // returns a new MemorySet bound to this ltm; no Python call is needed. 
+ MemorySet ms = new MemorySet(name); + ms.setLtm(this); + return ms; + } + + @Override + public boolean deleteMemorySet(String name) { + return (Boolean) adapter.callMethod(pyMem0, "delete_memory_set", Map.of("name", name)); + } + + @Override + @SuppressWarnings("unchecked") + public List<String> add( + MemorySet memorySet, + List<String> memoryItems, + @Nullable List<Map<String, Object>> metadatas) { + Map<String, Object> kwargs = new HashMap<>(); + kwargs.put("memory_set", buildPyMemorySet(memorySet)); + kwargs.put("memory_items", memoryItems); + if (metadatas != null) { + kwargs.put("metadatas", metadatas); + } + return (List<String>) adapter.callMethod(pyMem0, "add", kwargs); + } + + @Override + public List<MemorySetItem> get( + MemorySet memorySet, + @Nullable List<String> ids, + @Nullable Map<String, Object> filters, + @Nullable Integer limit) { + Map<String, Object> kwargs = new HashMap<>(); + kwargs.put("memory_set", buildPyMemorySet(memorySet)); + if (ids != null) { + kwargs.put("ids", ids); + } + if (filters != null) { + kwargs.put("filters", filters); + } + if (limit != null) { + kwargs.put("limit", limit); + } + Object pyItems = adapter.callMethod(pyMem0, "get", kwargs); + return convertItems(pyItems); + } + + @Override + public void delete(MemorySet memorySet, @Nullable List<String> ids) { + Map<String, Object> kwargs = new HashMap<>(); + kwargs.put("memory_set", buildPyMemorySet(memorySet)); + if (ids != null) { + kwargs.put("ids", ids); + } + adapter.callMethod(pyMem0, "delete", kwargs); + } + + @Override + public List<MemorySetItem> search( + MemorySet memorySet, + String query, + int limit, + @Nullable Map<String, Object> filters, + Map<String, Object> extraArgs) { + Map<String, Object> kwargs = new HashMap<>(extraArgs); + kwargs.put("memory_set", buildPyMemorySet(memorySet)); + kwargs.put("query", query); + kwargs.put("limit", limit); + if (filters != null) { + kwargs.put("filters", filters); + } + Object pyItems = adapter.callMethod(pyMem0, 
"search", kwargs); + return convertItems(pyItems); + } + + @Override + public void switchContext(String key) { + adapter.callMethod(pyMem0, "switch_context", Map.of("key", key)); + } + + @Override + public void close() { + adapter.callMethod(pyMem0, "close", Map.of()); + } + + private Object buildPyMemorySet(MemorySet memorySet) { + return adapter.invoke(TO_PYTHON_MEMORY_SET, memorySet.getName()); + } + + @SuppressWarnings("unchecked") + private List<MemorySetItem> convertItems(Object pyItems) { + Object converted = adapter.invoke(MEM0_ITEMS_TO_JAVA, pyItems); + List<Map<String, Object>> dicts = + converted == null ? List.of() : (List<Map<String, Object>>) converted; + List<MemorySetItem> items = new ArrayList<>(dicts.size()); + for (Map<String, Object> dict : dicts) { + items.add(dictToItem(dict)); + } + return items; + } + + @SuppressWarnings("unchecked") + private MemorySetItem dictToItem(Map<String, Object> dict) { + return new MemorySetItem( + (String) dict.get("memory_set_name"), + (String) dict.get("id"), + (String) dict.get("value"), + parseTimestamp(dict.get("created_at")), + parseTimestamp(dict.get("updated_at")), + (Map<String, Object>) dict.get("additional_metadata")); + } + + private static @Nullable LocalDateTime parseTimestamp(@Nullable Object iso) { + if (iso == null) { + return null; + } + String s = (String) iso; + try { + // mem0's ISO 8601 timestamps include a UTC offset (e.g. "...+00:00"). + return OffsetDateTime.parse(s).toLocalDateTime(); + } catch (DateTimeParseException e) { + // Fall back to parsing offset-less timestamps for backends that omit them. 
+ return LocalDateTime.parse(s); + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index e5015a3a..c2228872 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -28,6 +28,7 @@ import org.apache.flink.agents.api.logger.EventLogger; import org.apache.flink.agents.api.logger.EventLoggerConfig; import org.apache.flink.agents.api.logger.EventLoggerFactory; import org.apache.flink.agents.api.logger.EventLoggerOpenParams; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; @@ -49,6 +50,7 @@ import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.eventlog.FileEventLogger; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; +import org.apache.flink.agents.runtime.memory.Mem0LongTermMemory; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; @@ -94,6 +96,7 @@ import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pemja.core.PythonInterpreter; +import pemja.core.object.PyObject; import java.lang.reflect.Field; import java.util.ArrayList; @@ -106,6 +109,8 @@ import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTIO import static org.apache.flink.agents.api.configuration.AgentConfigOptions.BASE_LOG_DIR; import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.JOB_IDENTIFIER; import static org.apache.flink.agents.api.configuration.AgentConfigOptions.PRETTY_PRINT; +import static org.apache.flink.agents.plan.actions.Utils.requiredVersions; +import static org.apache.flink.agents.plan.actions.Utils.supportAsync; import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkState; @@ -170,6 +175,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // RunnerContext for Java Actions private transient RunnerContextImpl runnerContext; + // Long-term memory backed by Mem0; non-null only when LongTermMemoryOptions.Mem0 is configured. + private transient Mem0LongTermMemory ltm; + // We need to check whether the current thread is the mailbox thread using the mailbox // processor. // TODO: This is a temporary workaround. 
 In the future, we should add an interface in @@ -634,7 +642,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT instanceof PythonResourceProvider)); - if (containPythonAction || containPythonResource) { + boolean mem0Configured = isMem0Configured(); + + if (containPythonAction || containPythonResource || mem0Configured) { LOG.debug("Begin initialize PythonEnvironmentManager."); PythonDependencyInfo dependencyInfo = PythonDependencyInfo.create( @@ -661,15 +671,63 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT this.jobIdentifier); javaResourceAdapter = new JavaResourceAdapter(this::getResource, pythonInterpreter); - if (containPythonResource) { + if (containPythonResource || mem0Configured) { initPythonResourceAdapter(); } - if (containPythonAction) { + if (containPythonAction || mem0Configured) { initPythonActionExecutor(); } + if (mem0Configured) { + wireLongTermMemory(); + } } } + + private boolean isMem0Configured() { + // Mirror Python's `_init_long_term_memory`: mem0 is considered configured only when + // all three resource names are present; otherwise the Python side returns None, and + // we should not pay the Python interpreter startup cost. + var config = agentPlan.getConfig(); + boolean configured = + config.get(LongTermMemoryOptions.Mem0.CHAT_MODEL_SETUP) != null + && config.get(LongTermMemoryOptions.Mem0.EMBEDDING_MODEL_SETUP) != null + && config.get(LongTermMemoryOptions.Mem0.VECTOR_STORE) != null; + + // Mem0 will call chat model and embedding model in its own thread executor, this behavior + // is same as the async execution for cross-language resources, and also requires the fix + // in pemja. + if (configured && !supportAsync()) { + throw new RuntimeException( + String.format( + "Using Mem0 based Long-Term Memory in java requires flink version higher " + + "than %s. 
You can upgrade flink or use python api.", + requiredVersions)); + } + return configured; + } + + /** + * Pull the {@code long_term_memory} attribute off the Python {@code FlinkRunnerContext} (which + * {@code create_flink_runner_context} already initialised via {@code _init_long_term_memory}) + * and wrap it as a Java {@link Mem0LongTermMemory}. + */ + private void wireLongTermMemory() { + PyObject pyCtx = pythonActionExecutor.getPythonRunnerContext(); + Object pyLtm = pythonInterpreter.invoke("python_java_utils.get_long_term_memory", pyCtx); + if (pyLtm == null) { + throw new IllegalStateException( + String.format( + "Mem0 long-term memory is configured on the Java side but the Python " + + "runner context returned no long-term memory. Verify that %s, " + + "%s, and %s all reference resources that exist and that the " + + "Python-side _init_long_term_memory succeeded.", + LongTermMemoryOptions.Mem0.CHAT_MODEL_SETUP.getKey(), + LongTermMemoryOptions.Mem0.EMBEDDING_MODEL_SETUP.getKey(), + LongTermMemoryOptions.Mem0.VECTOR_STORE.getKey())); + } + ltm = new Mem0LongTermMemory(pythonResourceAdapter, (PyObject) pyLtm); + } + private void initPythonActionExecutor() throws Exception { pythonActionExecutor = new PythonActionExecutor( @@ -1108,6 +1166,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT this.resourceCache, this.jobIdentifier, continuationActionExecutor); + if (ltm != null) { + runnerContext.setLongTermMemory(ltm); + } } return runnerContext; } else { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java index 659c48a2..7a354c91 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java @@ -90,6 +90,10 @@ public class 
PythonActionExecutor { this.jobIdentifier = jobIdentifier; } + public PyObject getPythonRunnerContext() { + return pythonRunnerContext; + } + public void open() throws Exception { interpreter.exec(PYTHON_IMPORTS); diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemoryTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemoryTest.java new file mode 100644 index 00000000..eee167bd --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/Mem0LongTermMemoryTest.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import pemja.core.object.PyObject; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class Mem0LongTermMemoryTest { + @Mock private PythonResourceAdapter mockAdapter; + @Mock private PyObject mockPyMem0; + @Mock private PyObject mockPyMemorySet; + + private Mem0LongTermMemory ltm; + private AutoCloseable mocks; + + @BeforeEach + void setUp() { + mocks = MockitoAnnotations.openMocks(this); + ltm = new Mem0LongTermMemory(mockAdapter, mockPyMem0); + when(mockAdapter.invoke(eq("python_java_utils.to_python_memory_set"), any())) + .thenReturn(mockPyMemorySet); + } + + @AfterEach + void tearDown() throws Exception { + if (mocks != null) { + mocks.close(); + } + } + + @Test + void testGetMemorySetIsPureFactoryAndBindsLtm() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + + assertThat(ms.getName()).isEqualTo("notes"); + // Adding through the proxy should reach our ltm instance. + when(mockAdapter.callMethod(eq(mockPyMem0), eq("add"), any())).thenReturn(List.of("id1")); + ms.add(List.of("hello"), null); + verify(mockAdapter).callMethod(eq(mockPyMem0), eq("add"), any()); + // get_memory_set itself is a pure factory; it should NOT round-trip to Python. 
+ verify(mockAdapter, org.mockito.Mockito.never()) + .callMethod(eq(mockPyMem0), eq("get_memory_set"), any()); + } + + @Test + void testDeleteMemorySetForwardsAndReturnsBoolean() throws Exception { + when(mockAdapter.callMethod(eq(mockPyMem0), eq("delete_memory_set"), any())) + .thenReturn(Boolean.TRUE); + + boolean deleted = ltm.deleteMemorySet("notes"); + + assertThat(deleted).isTrue(); + verify(mockAdapter) + .callMethod(eq(mockPyMem0), eq("delete_memory_set"), eq(Map.of("name", "notes"))); + } + + @Test + void testAddForwardsKwargsAndReturnsIds() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + when(mockAdapter.callMethod(eq(mockPyMem0), eq("add"), any())) + .thenReturn(List.of("a", "b")); + + List<String> ids = + ltm.add(ms, List.of("hello", "world"), List.of(Map.of("k", "v"), Map.of())); + + assertThat(ids).containsExactly("a", "b"); + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("add"), + argThat( + kwargs -> { + assertThat(kwargs) + .containsKeys( + "memory_set", "memory_items", "metadatas"); + assertThat(kwargs.get("memory_set")).isEqualTo(mockPyMemorySet); + return true; + })); + } + + @Test + void testGetOmitsNullOptionalKwargs() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + when(mockAdapter.callMethod(eq(mockPyMem0), eq("get"), any())).thenReturn(null); + when(mockAdapter.invoke(eq("python_java_utils.mem0_items_to_java"), any())) + .thenReturn(null); + + ltm.get(ms, null, null, null); + + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("get"), + argThat( + kwargs -> { + assertThat(kwargs).containsOnlyKeys("memory_set"); + return true; + })); + } + + @Test + void testGetWithIdsAndFiltersConvertsItems() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + when(mockAdapter.callMethod(eq(mockPyMem0), eq("get"), any())).thenReturn("py_items"); + when(mockAdapter.invoke(eq("python_java_utils.mem0_items_to_java"), eq("py_items"))) + .thenReturn( + List.of( + Map.of( + "memory_set_name", 
"notes", + "id", "id1", + "value", "hello", + "additional_metadata", Map.of("k", "v")))); + + List<MemorySetItem> items = ltm.get(ms, List.of("id1"), Map.of("user_id", "u1"), 50); + + assertThat(items).hasSize(1); + MemorySetItem item = items.get(0); + assertThat(item.getMemorySetName()).isEqualTo("notes"); + assertThat(item.getId()).isEqualTo("id1"); + assertThat(item.getValue()).isEqualTo("hello"); + assertThat(item.getAdditionalMetadata()).containsEntry("k", "v"); + assertThat(item.getCreatedAt()).isNull(); + + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("get"), + argThat( + kwargs -> { + assertThat(kwargs).containsKeys("ids", "filters", "limit"); + assertThat(kwargs.get("limit")).isEqualTo(50); + return true; + })); + } + + @Test + void testDeleteForwardsIds() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + + ltm.delete(ms, List.of("id1", "id2")); + + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("delete"), + argThat( + kwargs -> { + assertThat(kwargs).containsKeys("memory_set", "ids"); + return true; + })); + } + + @Test + void testDeleteWithoutIdsOmitsKwarg() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + + ltm.delete(ms, null); + + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("delete"), + argThat( + kwargs -> { + assertThat(kwargs).containsOnlyKeys("memory_set"); + return true; + })); + } + + @Test + void testSearchForwardsKwargs() throws Exception { + MemorySet ms = ltm.getMemorySet("notes"); + when(mockAdapter.callMethod(eq(mockPyMem0), eq("search"), any())).thenReturn(null); + when(mockAdapter.invoke(eq("python_java_utils.mem0_items_to_java"), any())) + .thenReturn(null); + + ltm.search(ms, "hi", 10, Map.of("user_id", "u1"), Map.of("threshold", 0.7)); + + verify(mockAdapter) + .callMethod( + eq(mockPyMem0), + eq("search"), + argThat( + kwargs -> { + assertThat(kwargs) + .containsKeys( + "memory_set", + "query", + "limit", + "filters", + "threshold"); + 
assertThat(kwargs.get("query")).isEqualTo("hi"); + assertThat(kwargs.get("limit")).isEqualTo(10); + assertThat(kwargs.get("threshold")).isEqualTo(0.7); + return true; + })); + } + + @Test + void testSwitchContextAndCloseForward() { + ltm.switchContext("k1"); + ltm.close(); + + verify(mockAdapter) + .callMethod(eq(mockPyMem0), eq("switch_context"), eq(Map.of("key", "k1"))); + verify(mockAdapter).callMethod(eq(mockPyMem0), eq("close"), eq(Map.of())); + } +}
