xintongsong commented on code in PR #425: URL: https://github.com/apache/flink-agents/pull/425#discussion_r2680660061
########## api/src/main/java/org/apache/flink/agents/api/memory/BaseLongTermMemory.java: ########## @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.memory; + +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Map; + +/** + * Base interface for long-term memory management. It provides operations to create, retrieve, + * delete, and search memory sets, which are collections of memory items. A memory set can store + * items of a specific type (e.g., String or ChatMessage) and has a capacity limit. When the + * capacity is exceeded, a compaction strategy is applied to manage the memory set size. + */ +public interface BaseLongTermMemory extends AutoCloseable { + + /** + * Gets an existing memory set or creates a new one if it doesn't exist. + * + * @param name the name of the memory set + * @param itemType the type of items stored in the memory set + * @param capacity the maximum number of items the memory set can hold + * @param strategy the compaction strategy to use when the capacity is exceeded + * @return the existing or newly created memory set + * @throws Exception if the memory set cannot be created or retrieved + */ + MemorySet getOrCreateMemorySet( + String name, Class<?> itemType, int capacity, CompactionStrategy strategy) + throws Exception; + + /** + * Gets an existing memory set by name. + * + * @param name the name of the memory set to retrieve + * @return the memory set with the given name + * @throws Exception if the memory set does not exist or cannot be retrieved + */ + MemorySet getMemorySet(String name) throws Exception; + + /** + * Deletes a memory set by name. + * + * @param name the name of the memory set to delete + * @return true if the memory set was successfully deleted, false if it didn't exist + * @throws Exception if the deletion operation fails + */ + boolean deleteMemorySet(String name) throws Exception; + + /** + * Gets the number of items in the memory set. + * + * @param memorySet the memory set to count items in + * @return the number of items in the memory set + * @throws Exception if the size cannot be determined + */ + long size(MemorySet memorySet) throws Exception; + + /** + * Adds items to the memory set. If IDs are not provided, they will be automatically generated. + * This method may trigger compaction if the memory set capacity is exceeded. + * + * @param memorySet the memory set to add items to + * @param memoryItems the items to be added to the memory set + * @param ids optional list of IDs for the items. If null or shorter than memoryItems, IDs will + * be auto-generated for missing items + * @param metadatas optional list of metadata maps for the items. Each metadata map corresponds + * to an item at the same index + * @return list of IDs of the added items + * @throws Exception if items cannot be added to the memory set + */ + List<String> add( + MemorySet memorySet, + List<?> memoryItems, + @Nullable List<String> ids, + @Nullable List<Map<String, Object>> metadatas) + throws Exception; + + /** + * Retrieves memory items from the memory set. If no IDs are provided, all items in the memory + * set are returned. + * + * @param memorySet the memory set to retrieve items from + * @param ids optional list of item IDs to retrieve. If null, all items are returned + * @return list of memory set items. If ids is provided, returns items matching those IDs. If + * ids is null, returns all items in the memory set + * @throws Exception if items cannot be retrieved from the memory set + */ + List<MemorySetItem> get(MemorySet memorySet, @Nullable List<String> ids) throws Exception; + + /** + * Deletes memory items from the memory set. If no IDs are provided, all items in the memory set + * are deleted. + * + * @param memorySet the memory set to delete items from + * @param ids optional list of item IDs to delete. If null, all items in the memory set are + * deleted + * @throws Exception if items cannot be deleted from the memory set + */ + void delete(MemorySet memorySet, @Nullable List<String> ids) throws Exception; + + /** + * Performs semantic search on the memory set to find items related to the query string. + * + * @param memorySet the memory set to search in + * @param query the query string for semantic search + * @param limit the maximum number of items to return + * @param extraArgs additional arguments for the search operation (e.g., filters, distance + * metrics) + * @return list of memory set items that are most relevant to the query, ordered by relevance + * @throws Exception if the search operation fails + */ + List<MemorySetItem> search( + MemorySet memorySet, String query, int limit, Map<String, Object> extraArgs) + throws Exception; + + /** + * Switches the context for the long-term memory operations. This allows the same memory + * instance to be used for different key by isolating data based on the provided job ID and key. + * + * @param key the context key + */ + void switchContext(String key); Review Comment: This should not be exposed to users. ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); Review Comment: Never used. ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/VectorStoreLongTermMemory.java: ########## @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.LongTermMemoryOptions; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.CompactionStrategy; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.vectorstores.BaseVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore; +import org.apache.flink.agents.api.vectorstores.CollectionManageableVectorStore.Collection; +import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.api.vectorstores.VectorStoreQuery; +import org.apache.flink.agents.api.vectorstores.VectorStoreQueryResult; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.util.ExecutorUtils; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.apache.flink.agents.runtime.memory.CompactionFunctions.summarize; + +public class VectorStoreLongTermMemory implements BaseLongTermMemory { + private static final Logger LOG = LoggerFactory.getLogger(VectorStoreLongTermMemory.class); + + public static final ObjectMapper mapper = new ObjectMapper(); + public static final DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME; + + private final RunnerContext ctx; + private final boolean asyncCompaction; + + private String jobId; + private String key; + private transient ExecutorService lazyCompactExecutor; + private Object vectorStore; + + public VectorStoreLongTermMemory(RunnerContext ctx, Object vectorStore, String jobId) { + this(ctx, vectorStore, jobId, null); + } + + @VisibleForTesting + public VectorStoreLongTermMemory( + RunnerContext ctx, Object vectorStore, String jobId, String key) { + this.ctx = ctx; + this.vectorStore = vectorStore; + this.jobId = jobId; + this.key = key; + this.asyncCompaction = ctx.getConfig().get(LongTermMemoryOptions.ASYNC_COMPACTION); + } + + @Override + public void switchContext(String key) { + this.key = key; Review Comment: How is this handled in python? ########## runtime/src/main/java/org/apache/flink/agents/runtime/memory/CompactionFunctions.java: ########## @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.memory; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.memory.BaseLongTermMemory; +import org.apache.flink.agents.api.memory.MemorySet; +import org.apache.flink.agents.api.memory.MemorySetItem; +import org.apache.flink.agents.api.memory.compaction.SummarizationStrategy; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.formatter; +import static org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory.mapper; + +public class CompactionFunctions { + private static final Logger LOG = LoggerFactory.getLogger(CompactionFunctions.class); + + private static Prompt DEFAULT_ANALYSIS_PROMPT = + Prompt.fromText( + "<role>\n" + + "Context Summarize Assistant\n" + + "</role>\n" + + "\n" + + "<primary_objective>\n" + + "Your sole objective in this task is to summarize the context above.\n" + + "</primary_objective>\n" + + "\n" + + "<objective_information>\n" + + "You're nearing the total number of input tokens you can accept, so you need compact the context. To achieve this objective, you should extract important topics. Notice,\n" + + "**The topics must no more than {limit}**. Afterwards, you should generate summarization for each topic, and record indices of the messages the summary was derived from. " + + "**There are {count} messages totally, indexed from 0 to {end}, DO NOT omit any message, even if irrelevant**. The messages involved in each topic must not overlap, and their union must equal the entire set of messages.\n" + + "</objective_information>\n" + + "\n" + + "<output_example>\n" + + "You must always respond with valid json format in this format:\n" + + "{\"topic1\": {\"summarization\": \"User ask what is 1 * 2, and the result is 3.\", \"messages\": [0,1,2,3]},\n" + + " ...\n" + + " \"topic4\": {\"summarization\": \"User ask what's the weather tomorrow, llm use the search_weather, and the answer is snow.\", \"messages\": [9,10,11,12]}\n" + + "}\n" + + "</output_example>"); + + /** + * Generate summarization of the items in the memory set. + * + * <p>This method will add the summarization to memory set, and delete original items involved + * in summarization. + * + * @param ltm The long term memory the memory set belongs to. + * @param memorySet The memory set to be summarized. + * @param ctx The runner context used to retrieve needed resources. + * @param ids The ids of items to be summarized. If not provided, all items will be involved in + * summarization. Optional. + */ + @SuppressWarnings("unchecked") + public static void summarize( Review Comment: It's weird that we need to duplicate this for python & java. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
