This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 0f5d08667eef22f22adc03c2b5bc8bcff5dc5ba4 Author: WenjinXie <[email protected]> AuthorDate: Fri Jan 9 17:28:49 2026 +0800 [api][runtime][java] Introduce Long-Term Memory. --- .../agents/api/memory/BaseLongTermMemory.java | 133 +++++++++ .../agents/api/memory/LongTermMemoryOptions.java | 52 ++++ .../apache/flink/agents/api/memory/MemorySet.java | 159 ++++++++++ .../flink/agents/api/memory/MemorySetItem.java | 94 ++++++ .../api/memory/compaction/CompactionStrategy.java | 39 +++ .../memory/compaction/SummarizationStrategy.java | 96 ++++++ .../org/apache/flink/agents/api/prompt/Prompt.java | 41 +++ .../flink/agents/api/memory/MemorySetTest.java | 43 +++ .../agents/runtime/memory/CompactionFunctions.java | 213 ++++++++++++++ .../runtime/memory/InteranlBaseLongTermMemory.java | 31 ++ .../runtime/memory/VectorStoreLongTermMemory.java | 326 +++++++++++++++++++++ 11 files changed, 1227 insertions(+) diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java b/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java new file mode 100644 index 0000000..cb3b9c5 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.memory;
+
+import org.apache.flink.agents.api.memory.compaction.CompactionStrategy;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base interface for long-term memory management. It provides operations to create, retrieve,
+ * delete, and search memory sets, which are collections of memory items. A memory set can store
+ * items of a specific type (e.g., String or ChatMessage) and has a capacity limit. When the
+ * capacity is exceeded, a compaction strategy is applied to manage the memory set size.
+ *
+ * <p>The interface extends {@link AutoCloseable}, so an implementation can be managed with
+ * try-with-resources.
+ */
+public interface BaseLongTermMemory extends AutoCloseable {
+
+    /**
+     * Gets an existing memory set or creates a new one if it doesn't exist.
+     *
+     * @param name the name of the memory set
+     * @param itemType the type of items stored in the memory set
+     * @param capacity the maximum number of items the memory set can hold
+     * @param strategy the compaction strategy to use when the capacity is exceeded
+     * @return the existing or newly created memory set
+     * @throws Exception if the memory set cannot be created or retrieved
+     */
+    MemorySet getOrCreateMemorySet(
+            String name, Class<?> itemType, int capacity, CompactionStrategy strategy)
+            throws Exception;
+
+    /**
+     * Gets an existing memory set by name.
+     *
+     * @param name the name of the memory set to retrieve
+     * @return the memory set with the given name
+     * @throws Exception if the memory set does not exist or cannot be retrieved
+     */
+    MemorySet getMemorySet(String name) throws Exception;
+
+    /**
+     * Deletes a memory set by name.
+     *
+     * @param name the name of the memory set to delete
+     * @return {@code true} if the memory set was successfully deleted, {@code false} if it
+     *     didn't exist
+     * @throws Exception if the deletion operation fails
+     */
+    boolean deleteMemorySet(String name) throws Exception;
+
+    /**
+     * Gets the number of items in the memory set.
+     *
+     * @param memorySet the memory set to count items in
+     * @return the number of items in the memory set
+     * @throws Exception if the size cannot be determined
+     */
+    long size(MemorySet memorySet) throws Exception;
+
+    /**
+     * Adds items to the memory set. If IDs are not provided, they will be automatically generated.
+     * This method may trigger compaction if the memory set capacity is exceeded.
+     *
+     * @param memorySet the memory set to add items to
+     * @param memoryItems the items to be added to the memory set
+     * @param ids optional list of IDs for the items. If null or shorter than memoryItems, IDs will
+     *     be auto-generated for missing items
+     * @param metadatas optional list of metadata maps for the items. Each metadata map corresponds
+     *     to an item at the same index
+     * @return list of IDs of the added items
+     * @throws Exception if items cannot be added to the memory set
+     */
+    List<String> add(
+            MemorySet memorySet,
+            List<?> memoryItems,
+            @Nullable List<String> ids,
+            @Nullable List<Map<String, Object>> metadatas)
+            throws Exception;
+
+    /**
+     * Retrieves memory items from the memory set. If no IDs are provided, all items in the memory
+     * set are returned.
+     *
+     * @param memorySet the memory set to retrieve items from
+     * @param ids optional list of item IDs to retrieve. If null, all items are returned
+     * @return list of memory set items. If ids is provided, returns items matching those IDs. If
+     *     ids is null, returns all items in the memory set
+     * @throws Exception if items cannot be retrieved from the memory set
+     */
+    List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) throws Exception;
+
+    /**
+     * Deletes memory items from the memory set. If no IDs are provided, all items in the memory set
+     * are deleted.
+     *
+     * @param memorySet the memory set to delete items from
+     * @param ids optional list of item IDs to delete. If null, all items in the memory set are
+     *     deleted
+     * @throws Exception if items cannot be deleted from the memory set
+     */
+    void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception;
+
+    /**
+     * Performs semantic search on the memory set to find items related to the query string.
+     *
+     * @param memorySet the memory set to search in
+     * @param query the query string for semantic search
+     * @param limit the maximum number of items to return
+     * @param extraArgs additional arguments for the search operation (e.g., filters, distance
+     *     metrics). Unlike the optional parameters of the other methods, this one is not
+     *     annotated {@code @Nullable}; pass an empty map when there are no extra arguments
+     * @return list of memory set items that are most relevant to the query, ordered by relevance
+     * @throws Exception if the search operation fails
+     */
+    List<MemorySetItem> search(
+            MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs)
+            throws Exception;
+}
diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java b/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java
new file mode 100644
index 0000000..eab7f50
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/memory/LongTermMemoryOptions.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.memory;
+
+import org.apache.flink.agents.api.configuration.ConfigOption;
+
+/**
+ * Configuration options for long-term memory.
+ *
+ * <p>NOTE(review): the third {@link ConfigOption} constructor argument is presumably the default
+ * value of the option; confirm against {@code ConfigOption}.
+ */
+public class LongTermMemoryOptions {
+    /** The storage backends that can hold long-term memory. */
+    public enum LongTermMemoryBackend {
+        EXTERNAL_VECTOR_STORE("external_vector_store");
+
+        // String form of the backend, exposed through getValue().
+        private final String value;
+
+        LongTermMemoryBackend(String value) {
+            this.value = value;
+        }
+
+        /** Returns the string form of this backend. */
+        public String getValue() {
+            return value;
+        }
+    }
+
+    /** The backend for long-term memory. */
+    public static final ConfigOption<LongTermMemoryBackend> BACKEND =
+            new ConfigOption<>("long-term-memory.backend", LongTermMemoryBackend.class, null);
+
+    /** The name of the vector store to serve as the backend for long-term memory. */
+    public static final ConfigOption<String> EXTERNAL_VECTOR_STORE_NAME =
+            new ConfigOption<>("long-term-memory.external-vector-store-name", String.class, null);
+
+    /** Whether to execute compaction asynchronously. */
+    public static final ConfigOption<Boolean> ASYNC_COMPACTION =
+            new ConfigOption<>("long-term-memory.async-compaction", Boolean.class, false);
+
+    /** The thread count of the executor used for async compaction. */
+    public static final ConfigOption<Integer> THREAD_COUNT =
+            new ConfigOption<>("long-term-memory.async-compaction.thread-count", Integer.class, 16);
+}
diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java
new file mode 100644
index 0000000..3807f5e
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySet.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.flink.agents.api.memory; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class MemorySet { + private final String name; + private final Class<?> itemType; + private final int capacity; + private final CompactionStrategy strategy; + private @JsonIgnore BaseLongTermMemory ltm; + + @JsonCreator + public MemorySet( + @JsonProperty("name") String name, + @JsonProperty("itemType") Class<?> itemType, + @JsonProperty("capacity") int capacity, + @JsonProperty("strategy") CompactionStrategy strategy) { + this.name = name; + this.itemType = itemType; + this.capacity = capacity; + this.strategy = strategy; + } + + /** + * Gets the number of items in this memory set. + * + * @return the number of items in the memory set + * @throws Exception if the size cannot be determined + */ + public long size() throws Exception { + return this.ltm.size(this); + } + + /** + * Adds items to this memory set. If IDs are not provided, they will be automatically generated. + * This method may trigger compaction if the memory set capacity is exceeded. + * + * @param memoryItems the items to be added to the memory set + * @param ids optional list of IDs for the items. If null or shorter than memoryItems, IDs will + * be auto-generated for missing items + * @param metadatas optional list of metadata maps for the items. 
Each metadata map corresponds + * to an item at the same index + * @return list of IDs of the added items + * @throws Exception if items cannot be added to the memory set + */ + public List<String> add( + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + return this.ltm.add(this, memoryItems, ids, metadatas); + } + + /** + * Retrieves memory items from this memory set. If no IDs are provided, all items in the memory + * set are returned. + * + * @param ids optional list of item IDs to retrieve. If null, all items are returned + * @return list of memory set items. If ids is provided, returns items matching those IDs. If + * ids is null, returns all items in the memory set + * @throws Exception if items cannot be retrieved from the memory set + */ + public List<MemorySetItem> get(@Nullable List<String> ids) throws Exception { + return this.ltm.get(this, ids); + } + + /** + * Performs semantic search on this memory set to find items related to the query string. + * + * @param query the query string for semantic search + * @param limit the maximum number of items to return + * @param extraArgs optional additional arguments for the search operation (e.g., filters, + * distance metrics). If null, an empty map is used + * @return list of memory set items that are most relevant to the query, ordered by relevance + * @throws Exception if the search operation fails + */ + public List<MemorySetItem> search( + String query, int limit, @Nullable Map<String, Object> extraArgs) throws Exception { + return this.ltm.search( + this, query, limit, extraArgs == null ? 
Collections.emptyMap() : extraArgs); + } + + public void setLtm(BaseLongTermMemory ltm) { + this.ltm = ltm; + } + + public String getName() { + return name; + } + + public Class<?> getItemType() { + return itemType; + } + + public int getCapacity() { + return capacity; + } + + public CompactionStrategy getStrategy() { + return strategy; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + MemorySet memorySet = (MemorySet) o; + return capacity == memorySet.capacity + && Objects.equals(name, memorySet.name) + && Objects.equals(itemType, memorySet.itemType) + && Objects.equals(strategy, memorySet.strategy); + } + + @Override + public int hashCode() { + return Objects.hash(name, itemType, capacity, strategy); + } + + @Override + public String toString() { + return "MemorySet{" + + "name='" + + name + + '\'' + + ", itemType=" + + itemType + + ", capacity=" + + capacity + + ", strategy=" + + strategy + + '}'; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java new file mode 100644 index 0000000..b822fbc --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/memory/MemorySetItem.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.memory;
+
+import java.time.LocalDateTime;
+import java.util.Map;
+
+/**
+ * Immutable view of a single item stored in a memory set, together with its bookkeeping data
+ * (id, compaction flag, timestamps and metadata).
+ */
+public class MemorySetItem {
+    private final String memorySetName;
+    private final String id;
+    private final Object value;
+    private final boolean compacted;
+    // Typed as Object because its shape depends on `compacted`: CompactionFunctions#summarize
+    // casts it to DateTimeRange when isCompacted() is true and to LocalDateTime otherwise.
+    private final Object createdTime;
+    private final LocalDateTime lastAccessedTime;
+    private final Map<String, Object> metadata;
+
+    /**
+     * Creates an item. All arguments are stored as given; no copies or checks are made.
+     *
+     * @param memorySetName the name of the memory set the item belongs to
+     * @param id the id of the item
+     * @param value the stored value
+     * @param compacted whether the item is the result of a compaction
+     * @param createdTime the creation time; see the note on the {@code createdTime} field for the
+     *     expected runtime type
+     * @param lastAccessedTime the time the item was last accessed
+     * @param metadata the metadata attached to the item
+     */
+    public MemorySetItem(
+            String memorySetName,
+            String id,
+            Object value,
+            boolean compacted,
+            Object createdTime,
+            LocalDateTime lastAccessedTime,
+            Map<String, Object> metadata) {
+        this.memorySetName = memorySetName;
+        this.id = id;
+        this.value = value;
+        this.compacted = compacted;
+        this.createdTime = createdTime;
+        this.lastAccessedTime = lastAccessedTime;
+        this.metadata = metadata;
+    }
+
+    public String getMemorySetName() {
+        return memorySetName;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    public Object getValue() {
+        return value;
+    }
+
+    public boolean isCompacted() {
+        return compacted;
+    }
+
+    /**
+     * Returns the creation time as it was passed to the constructor. Callers have to cast it; see
+     * the note on the {@code createdTime} field.
+     */
+    public Object getCreatedTime() {
+        return createdTime;
+    }
+
+    public LocalDateTime getLastAccessedTime() {
+        return lastAccessedTime;
+    }
+
+    /** Returns the metadata map passed to the constructor (the same instance, not a copy). */
+    public Map<String, Object> getMetadata() {
+        return metadata;
+    }
+
+    /** A pair of start and end timestamps. */
+    public static class DateTimeRange {
+        private final LocalDateTime start;
+        private final LocalDateTime end;
+
+        public DateTimeRange(LocalDateTime start, LocalDateTime end) {
+            this.start = start;
+            this.end = end;
+        }
+
+        public LocalDateTime getStart() {
+            return start;
+        }
+
+        public LocalDateTime getEnd() {
+            return end;
+        }
+    }
+}
diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionStrategy.java b/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionStrategy.java
new file mode 100644
index 0000000..3a87f6f
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/memory/compaction/CompactionStrategy.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.api.memory.compaction;
+
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+
+/**
+ * Describes how a memory set is compacted once its capacity is exceeded.
+ *
+ * <p>The {@code @JsonTypeInfo(use = Id.CLASS)} annotation makes Jackson record the concrete
+ * implementation class in the JSON, so a value declared as {@code CompactionStrategy} can be
+ * deserialized back to the right implementation.
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
+public interface CompactionStrategy {
+    /** The kinds of compaction strategy. */
+    enum Type {
+        SUMMARIZATION("summarization");
+
+        // String form of the type, exposed through getValue().
+        private final String value;
+
+        Type(String value) {
+            this.value = value;
+        }
+
+        /** Returns the string form of this type. */
+        public String getValue() {
+            return value;
+        }
+    }
+
+    /** Returns the kind of this strategy. */
+    Type type();
+}
diff --git a/api/src/main/java/org/apache/flink/agents/api/memory/compaction/SummarizationStrategy.java b/api/src/main/java/org/apache/flink/agents/api/memory/compaction/SummarizationStrategy.java
new file mode 100644
index 0000000..ff41e16
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/memory/compaction/SummarizationStrategy.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.flink.agents.api.memory.compaction; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import javax.annotation.Nullable; + +import java.util.Objects; + +public class SummarizationStrategy implements CompactionStrategy { + private final String model; + + @JsonTypeInfo( + use = JsonTypeInfo.Id.CLASS, + include = JsonTypeInfo.As.PROPERTY, + property = "@class") + private final Object prompt; + + private final int limit; + + public SummarizationStrategy(String model, int limit) { + this(model, null, limit); + } + + @JsonCreator + public SummarizationStrategy( + @JsonProperty("model") String model, + @Nullable @JsonProperty("prompt") Object prompt, + @JsonProperty("limit") int limit) { + this.model = model; + this.prompt = prompt; + this.limit = limit; + } + + @Override + public Type type() { + return Type.SUMMARIZATION; + } + + public String getModel() { + return model; + } + + public Object getPrompt() { + return prompt; + } + + public int getLimit() { + return limit; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + SummarizationStrategy that = (SummarizationStrategy) o; + return limit == that.limit + && Objects.equals(model, that.model) + && Objects.equals(prompt, that.prompt); + } + + @Override + public int hashCode() { + return Objects.hash(model, prompt, limit); + } + + @Override + public String toString() { + return "SummarizationStrategy{" + + "model='" + + model + + '\'' + + ", prompt=" + + prompt + + ", limit=" + + limit + + '}'; + } +} diff --git a/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java b/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java index f688a1e..549c474 100644 --- a/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java +++ b/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java @@ 
-188,6 +188,23 @@ public abstract class Prompt extends SerializableResource { .collect(Collectors.toList())); } + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + LocalPrompt that = (LocalPrompt) o; + return Objects.equals(template, that.template); + } + + @Override + public int hashCode() { + return Objects.hashCode(template); + } + + @Override + public String toString() { + return "LocalPrompt{" + "template=" + template + '}'; + } + /** Format template string with keyword arguments */ private static String format(String template, Map<String, String> kwargs) { if (template == null) { @@ -247,6 +264,18 @@ public abstract class Prompt extends SerializableResource { public String toString() { return "StringTemplate{content='" + content + "'}"; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + StringTemplate that = (StringTemplate) o; + return Objects.equals(content, that.content); + } + + @Override + public int hashCode() { + return Objects.hashCode(content); + } } /** Messages template implementation. 
*/ @@ -279,6 +308,18 @@ public abstract class Prompt extends SerializableResource { public String toString() { return "MessagesTemplate{messages=" + messages.size() + " items}"; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + MessagesTemplate that = (MessagesTemplate) o; + return Objects.equals(messages, that.messages); + } + + @Override + public int hashCode() { + return Objects.hashCode(messages); + } } } } diff --git a/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java b/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java new file mode 100644 index 0000000..c0c8e87 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/memory/MemorySetTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.api.memory; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class MemorySetTest { + @Test + public void testJsonSerialization() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + MemorySet memorySet = + new MemorySet( + "test", + ChatMessage.class, + 100, + new SummarizationStrategy( + "testModel", Prompt.fromText("Test prompt"), 100)); + String jsonValue = mapper.writeValueAsString(memorySet); + + MemorySet deserialized = mapper.readValue(jsonValue, MemorySet.class); + Assertions.assertEquals(memorySet, deserialized); + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java new file mode 100644 index 0000000..5d52f43 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.formatter; +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.mapper; + +public class CompactionFunctions { + private static final Logger LOG = LoggerFactory.getLogger(CompactionFunctions.class); + + private static Prompt DEFAULT_ANALYSIS_PROMPT = + Prompt.fromText( + "<role>\n" + + "Context Summarize Assistant\n" + + "</role>\n" + + "\n" + + "<primary_objective>\n" + + "Your sole objective in this task is to summarize the context above.\n" + + "</primary_objective>\n" + + "\n" + + "<objective_information>\n" + + "You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you should extract important topics. Notice,\n" + + "**The topics must no more than {limit}**. 
Afterwards, you should generate summarization for each topic, and record indices of the messages the summary was derived from. " + + "**There are {count} messages totally, indexed from 0 to {end}, DO NOT omit any message, even if irrelevant**. The messages involved in each topic must not overlap, and their union must equal the entire set of messages.\n" + + "</objective_information>\n" + + "\n" + + "<output_example>\n" + + "You must always respond with valid json format in this format:\n" + + "{\"topic1\": {\"summarization\": \"User ask what is 1 * 2, and the result is 3.\", \"messages\": [0,1,2,3]},\n" + + " ...\n" + + " \"topic4\": {\"summarization\": \"User ask what's the weather tomorrow, llm use the search_weather, and the answer is snow.\", \"messages\": [9,10,11,12]}\n" + + "}\n" + + "</output_example>"); + + /** + * Generate summarization of the items in the memory set. + * + * <p>This method will add the summarization to memory set, and delete original items involved + * in summarization. + * + * @param ltm The long term memory the memory set belongs to. + * @param memorySet The memory set to be summarized. + * @param ctx The runner context used to retrieve needed resources. + * @param ids The ids of items to be summarized. If not provided, all items will be involved in + * summarization. Optional. 
+ */ + @SuppressWarnings("unchecked") + public static void summarize( + BaseLongTermMemory ltm, + MemorySet memorySet, + RunnerContext ctx, + @Nullable List<String> ids) + throws Exception { + SummarizationStrategy strategy = (SummarizationStrategy) memorySet.getStrategy(); + + List<MemorySetItem> items = ltm.get(memorySet, ids); + ChatMessage response = generateSummarization(items, memorySet.getItemType(), strategy, ctx); + + LOG.debug("Items to be summarized: {}\n, Summarization: {}", items, response.getContent()); + + Map<String, Map<String, Object>> topics = + mapper.readValue(response.getContent(), Map.class); + + for (Map<String, Object> topic : topics.values()) { + String summarization = (String) topic.get("summarization"); + List<Integer> indices = (List<Integer>) topic.get("messages"); + + if (strategy.getLimit() == 1) { + indices = IntStream.range(0, items.size()).boxed().collect(Collectors.toList()); + } + + Object item; + if (memorySet.getItemType() == ChatMessage.class) { + item = new ChatMessage(MessageRole.USER, summarization); + } else { + item = summarization; + } + + List<LocalDateTime> created_times = new ArrayList<>(); + List<LocalDateTime> lastAccessedTimes = new ArrayList<>(); + List<String> itemIds = new ArrayList<>(); + LocalDateTime start = LocalDateTime.MAX; + LocalDateTime end = LocalDateTime.MAX; + LocalDateTime lastAccessed = LocalDateTime.MIN; + for (int index : indices) { + if (items.get(index).isCompacted()) { + MemorySetItem.DateTimeRange range = + ((MemorySetItem.DateTimeRange) items.get(index).getCreatedTime()); + start = start.isBefore(range.getStart()) ? start : range.getStart(); + end = end.isAfter(range.getEnd()) ? end : range.getEnd(); + } else { + LocalDateTime point = (LocalDateTime) items.get(index).getCreatedTime(); + start = start.isBefore(point) ? start : point; + end = end.isAfter(point) ? end : point; + } + + LocalDateTime point = items.get(index).getLastAccessedTime(); + lastAccessed = lastAccessed.isAfter(point) ? 
lastAccessed : point; + + itemIds.add(items.get(index).getId()); + } + + ltm.delete(memorySet, itemIds); + + ltm.add( + memorySet, + Collections.singletonList(item), + null, + Collections.singletonList( + Map.of( + "compacted", + true, + "created_time_start", + start.format(formatter), + "created_time_end", + end.format(formatter), + "last_accessed_time", + lastAccessed.format(formatter)))); + } + } + + // TODO: support batched summarize. + private static ChatMessage generateSummarization( + List<MemorySetItem> items, + Class<?> itemType, + SummarizationStrategy strategy, + RunnerContext ctx) + throws Exception { + List<ChatMessage> messages = new ArrayList<>(); + if (itemType == ChatMessage.class) { + for (MemorySetItem item : items) { + messages.add((ChatMessage) item.getValue()); + } + } else { + for (MemorySetItem item : items) { + messages.add(new ChatMessage(MessageRole.USER, String.valueOf(item.getValue()))); + } + } + + BaseChatModelSetup model = + (BaseChatModelSetup) ctx.getResource(strategy.getModel(), ResourceType.CHAT_MODEL); + + Object prompt = strategy.getPrompt(); + if (prompt != null) { + if (prompt instanceof String) { + prompt = ctx.getResource((String) prompt, ResourceType.PROMPT); + } + + Map<String, String> variables = new HashMap<>(); + for (ChatMessage msg : messages) { + for (Map.Entry<String, Object> pair : msg.getExtraArgs().entrySet()) { + variables.put(pair.getKey(), String.valueOf(pair.getValue())); + } + } + + List<ChatMessage> promptMsg = + ((Prompt) prompt).formatMessages(MessageRole.USER, variables); + messages.addAll(promptMsg); + } else { + messages.addAll( + DEFAULT_ANALYSIS_PROMPT.formatMessages( + MessageRole.SYSTEM, + Map.of( + "limit", + String.valueOf(strategy.getLimit()), + "count", + String.valueOf(items.size()), + "end", + String.valueOf(items.size() - 1)))); + } + + return model.chat(messages); + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/InteranlBaseLongTermMemory.java 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/InteranlBaseLongTermMemory.java
new file mode 100644
index 0000000..b13e2d7
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/InteranlBaseLongTermMemory.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.memory;
+
+import org.apache.flink.agents.api.memory.BaseLongTermMemory;
+
+/** Internal extension of {@link BaseLongTermMemory} for operations hidden from users. */
+public interface InteranlBaseLongTermMemory extends BaseLongTermMemory {
+    /**
+     * Switches the context for subsequent memory operations. This lets one memory instance
+     * serve different keys by isolating data based on the provided key.
+     *
+     * @param key the context key
+     */
+    void switchContext(String key);
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java
new file mode 100644
index 0000000..3d9f6a8
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.WeakHashMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements InteranlBaseLongTermMemory { + public static final ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + 
private final RunnerContext ctx; + private final boolean asyncCompaction; + + private final String jobId; + private Map<String, AtomicBoolean> inCompaction; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + if (this.asyncCompaction) { + inCompaction = new WeakHashMap<>(); + } + } + + @Override + public void switchContext(String key) { + this.key = key; + } + + private BaseVectorStore store() throws Exception { + if (vectorStore instanceof String) { + vectorStore = ctx.getResource((String) vectorStore, ResourceType.VECTOR_STORE); + } + return (BaseVectorStore) vectorStore; + } + + @Override + public MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception { + MemorySet memorySet = new MemorySet(name, itemType, capacity, strategy); + ((CollectionManageableVectorStore) this.store()) + .getOrCreateCollection( + this.nameMangling(name), + Map.of("memory_set", mapper.writeValueAsString(memorySet))); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public MemorySet getMemorySet(String name) throws Exception { + Collection collection = + ((CollectionManageableVectorStore) this.store()) + .getCollection(this.nameMangling(name)); + MemorySet memorySet = + mapper.readValue( + (String) collection.getMetadata().get("memory_set"), MemorySet.class); + memorySet.setLtm(this); + return memorySet; + } + + @Override + public boolean deleteMemorySet(String name) throws Exception { + Collection collection = 
+ ((CollectionManageableVectorStore) this.store()) + .deleteCollection(this.nameMangling(name)); + return collection != null; + } + + @Override + public long size(MemorySet memorySet) throws Exception { + return this.store().size(this.nameMangling(memorySet.getName())); + } + + @Override + public List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception { + if (ids == null || ids.isEmpty()) { + ids = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + ids.add(UUID.randomUUID().toString()); + } + } + + String timestamp = LocalDateTime.now().format(formatter); + Map<String, Object> metadata = + Map.of( + "compacted", + false, + "created_time", + timestamp, + "last_accessed_time", + timestamp); + + List<Map<String, Object>> mergedMetadatas = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.add(new HashMap<>(metadata)); + } + + if (metadatas != null && !metadatas.isEmpty()) { + for (int i = 0; i < memoryItems.size(); i++) { + mergedMetadatas.get(i).putAll(metadatas.get(i)); + } + } + + List<Document> documents = new ArrayList<>(); + for (int i = 0; i < memoryItems.size(); i++) { + documents.add( + new Document( + mapper.writeValueAsString(memoryItems.get(i)), + mergedMetadatas.get(i), + ids.get(i))); + } + + List<String> itemIds = + this.store() + .add( + documents, + this.nameMangling(memorySet.getName()), + Collections.emptyMap()); + + if (memorySet.size() >= memorySet.getCapacity()) { + if (this.asyncCompaction) { + String name = this.nameMangling(memorySet.getName()); + AtomicBoolean isCompacting = + this.inCompaction.computeIfAbsent(name, k -> new AtomicBoolean(false)); + if (isCompacting.compareAndSet(false, true)) { + CompletableFuture.runAsync( + () -> { + try { + asyncCompact(memorySet, isCompacting); + } catch (Exception e) { + throw new RuntimeException(e); + } + }, + this.workerExecutor()) + 
.exceptionally( + e -> { + throw new RuntimeException( + String.format( + "Compaction for %s failed", + this.nameMangling(memorySet.getName())), + e); + }); + } + } else { + this.compact(memorySet); + } + } + + return itemIds; + } + + // TODO: get the entire set at once may cause OOM, should support batched get. + @Override + public List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) + throws Exception { + List<Document> documents = + this.store() + .get(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + return this.convertToItems(memorySet, documents); + } + + @Override + public void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception { + this.store().delete(ids, this.nameMangling(memorySet.getName()), Collections.emptyMap()); + } + + @Override + public List<MemorySetItem> search( + MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs) + throws Exception { + VectorStoreQuery vectorStoreQuery = + new VectorStoreQuery( + query, limit, this.nameMangling(memorySet.getName()), extraArgs); + VectorStoreQueryResult result = this.store().query(vectorStoreQuery); + return this.convertToItems(memorySet, result.getDocuments()); + } + + private String nameMangling(String name) { + return String.join("-", this.jobId, this.key, name); + } + + private List<MemorySetItem> convertToItems(MemorySet memorySet, List<Document> documents) + throws JsonProcessingException { + List<MemorySetItem> items = new ArrayList<>(); + for (Document doc : documents) { + Map<String, Object> metadata = doc.getMetadata(); + boolean compacted = (boolean) metadata.remove("compacted"); + Object createdTime; + if (compacted) { + createdTime = + new MemorySetItem.DateTimeRange( + LocalDateTime.parse( + (String) metadata.remove("created_time_start"), formatter), + LocalDateTime.parse( + (String) metadata.remove("created_time_end"), formatter)); + } else { + createdTime = + LocalDateTime.parse((String) 
metadata.remove("created_time"), formatter); + } + MemorySetItem item = + new MemorySetItem( + memorySet.getName(), + doc.getId(), + memorySet.getItemType() == String.class + ? doc.getContent() + : mapper.readValue(doc.getContent(), memorySet.getItemType()), + compacted, + createdTime, + LocalDateTime.parse( + (String) metadata.remove("last_accessed_time"), formatter), + metadata); + items.add(item); + } + return items; + } + + private void compact(MemorySet memorySet) throws Exception { + CompactionStrategy strategy = memorySet.getStrategy(); + if (strategy.type() == CompactionStrategy.Type.SUMMARIZATION) { + summarize(this, memorySet, ctx, null); + } else { + throw new RuntimeException( + String.format("Unknown compaction strategy: %s", strategy.type())); + } + } + + private void asyncCompact(MemorySet memorySet, AtomicBoolean inCompaction) throws Exception { + CompactionStrategy strategy = memorySet.getStrategy(); + if (strategy.type() == CompactionStrategy.Type.SUMMARIZATION) { + summarize(this, memorySet, ctx, null); + } else { + throw new RuntimeException( + String.format("Unknown compaction strategy: %s", strategy.type())); + } + inCompaction.set(false); + } + + private ExecutorService workerExecutor() { + // TODO: shutdown executor when close. + if (lazyCompactExecutor == null) { + int nThreads = ctx.getConfig().get(LongTermMemoryOptions.THREAD_COUNT); + lazyCompactExecutor = + Executors.newFixedThreadPool( + nThreads, + new ExecutorThreadFactory( + Thread.currentThread().getName() + "-ltm-compact-worker")); + } + return lazyCompactExecutor; + } + + @Override + public void close() throws Exception { + if (lazyCompactExecutor != null) { + ExecutorUtils.gracefulShutdown(180, TimeUnit.SECONDS, lazyCompactExecutor); + lazyCompactExecutor = null; + } + } +}
