This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 551994711be8ddc663aeeb50fc68ffff3ab209f7
Author: WenjinXie <[email protected]>
AuthorDate: Sat Jan 17 02:10:43 2026 +0800

    [runtime] Report token usage metric out of compaction for long-term memory.
---
 python/flink_agents/api/memory/long_term_memory.py |  4 ++
 python/flink_agents/api/runner_context.py          |  5 +-
 .../flink_agents/runtime/flink_runner_context.py   | 47 +++++++++------
 python/flink_agents/runtime/local_runner.py        |  2 +-
 .../runtime/memory/compaction_functions.py         | 16 +++--
 .../memory/internal_base_long_term_memory.py       | 36 ++++++++++++
 .../tests/test_vector_store_long_term_memory.py    |  4 ++
 .../memory/vector_store_long_term_memory.py        | 68 +++++++++++++++++++---
 .../runtime/python/utils/PythonActionExecutor.java |  8 +--
 9 files changed, 151 insertions(+), 39 deletions(-)

diff --git a/python/flink_agents/api/memory/long_term_memory.py b/python/flink_agents/api/memory/long_term_memory.py
index da673c81..bedc5741 100644
--- a/python/flink_agents/api/memory/long_term_memory.py
+++ b/python/flink_agents/api/memory/long_term_memory.py
@@ -340,3 +340,7 @@ class BaseLongTermMemory(ABC, BaseModel):
         Returns:
             Related memory items retrieved.
         """
+
+    @abstractmethod
+    def close(self) -> None:
+        """Logic executed when job close."""
diff --git a/python/flink_agents/api/runner_context.py b/python/flink_agents/api/runner_context.py
index 8b44efeb..4e3bc303 100644
--- a/python/flink_agents/api/runner_context.py
+++ b/python/flink_agents/api/runner_context.py
@@ -90,7 +90,7 @@ class RunnerContext(ABC):
         """
 
     @abstractmethod
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource:
         """Get resource from context.
 
         Parameters
@@ -99,6 +99,9 @@ class RunnerContext(ABC):
             The name of the resource.
         type : ResourceType
             The type of the resource.
+        metric_group: MetricGroup
+            The metric group used for reporting the metric. If not provided,
+            will use the action metric group.
         """
 
     @property
diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py
index dffdaf73..994f095e 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -32,11 +32,15 @@ from flink_agents.api.memory.long_term_memory import (
     LongTermMemoryOptions,
 )
 from flink_agents.api.memory_object import MemoryType
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import Resource, ResourceType
 from flink_agents.api.runner_context import AsyncExecutionResult, RunnerContext
 from flink_agents.plan.agent_plan import AgentPlan
 from flink_agents.runtime.flink_memory_object import FlinkMemoryObject
 from flink_agents.runtime.flink_metric_group import FlinkMetricGroup
+from flink_agents.runtime.memory.internal_base_long_term_memory import (
+    InternalBaseLongTermMemory,
+)
 from flink_agents.runtime.memory.vector_store_long_term_memory import (
     VectorStoreLongTermMemory,
 )
@@ -174,7 +178,7 @@ class FlinkRunnerContext(RunnerContext):
     """
 
     __agent_plan: AgentPlan | None
-    __ltm: BaseLongTermMemory = None
+    __ltm: InternalBaseLongTermMemory = None
 
     def __init__(
         self,
@@ -195,7 +199,7 @@ class FlinkRunnerContext(RunnerContext):
         self.__agent_plan.set_java_resource_adapter(j_resource_adapter)
         self.executor = executor
 
-    def set_long_term_memory(self, ltm: BaseLongTermMemory) -> None:
+    def set_long_term_memory(self, ltm: InternalBaseLongTermMemory) -> None:
         """Set long term memory instance to this context.
 
         Parameters
@@ -224,10 +228,10 @@ class FlinkRunnerContext(RunnerContext):
             raise RuntimeError(err_msg) from e
 
     @override
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource:
         resource = self.__agent_plan.get_resource(name, type)
-        # Bind current action's metric group to the resource
-        resource.set_metric_group(self.action_metric_group)
+        # Bind metric group to the resource
+        resource.set_metric_group(metric_group or self.action_metric_group)
         return resource
 
     @property
@@ -488,6 +492,9 @@ class FlinkRunnerContext(RunnerContext):
 
     @override
     def close(self) -> None:
+        if self.long_term_memory is not None:
+            self.long_term_memory.close()
+
         if self.__agent_plan is not None:
             try:
                 self.__agent_plan.close()
@@ -500,23 +507,13 @@ def create_flink_runner_context(
     agent_plan_json: str,
     executor: ThreadPoolExecutor,
     j_resource_adapter: Any,
+    job_identifier: str,
 ) -> FlinkRunnerContext:
     """Used to create a FlinkRunnerContext Python object in Pemja environment."""
-    return FlinkRunnerContext(
+    ctx = FlinkRunnerContext(
         j_runner_context, agent_plan_json, executor, j_resource_adapter
     )
 
-
-def flink_runner_context_switch_action_context(
-    ctx: FlinkRunnerContext,
-    job_identifier: str,
-    key: int,
-) -> None:
-    """Switch the context of the flink runner context.
-
-    The ctx is reused across keyed partitions, the context related to
-    specific key should be switched when process new action.
-    """
     backend = ctx.config.get(LongTermMemoryOptions.BACKEND)
     # use external vector store based long term memory
     if backend == LongTermMemoryBackend.EXTERNAL_VECTOR_STORE:
@@ -528,10 +525,24 @@ def flink_runner_context_switch_action_context(
                 ctx=ctx,
                 vector_store=vector_store_name,
                 job_id=job_identifier,
-                key=str(key),
             )
         )
+    return ctx
+
+
+def flink_runner_context_switch_action_context(
+    ctx: FlinkRunnerContext,
+    key: int,
+) -> None:
+    """Switch the context of the flink runner context.
+
+    The ctx is reused across keyed partitions, the context related to
+    specific key should be switched when process new action.
+    """
+    if ctx.long_term_memory is not None:
+        ctx.long_term_memory.switch_context(str(key))
+
 
 def close_flink_runner_context(
     ctx: FlinkRunnerContext,
 ) -> None:
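The optional metric_group argument added to get_resource (in the RunnerContext API above and bound here in FlinkRunnerContext) lets the caller pick which metric group a resource reports under; when it is omitted the resource keeps reporting under the current action's metric group, as before. A minimal sketch of a caller routing a chat model's token metrics to its own sub-group; the names "my-component" and "my_model" are illustrative, not from this patch:

    # Sketch only: report this model's metrics under a dedicated sub-group
    # instead of the current action's metric group ("my-component" and
    # "my_model" are illustrative names).
    custom_group = ctx.agent_metric_group.get_sub_group("my-component")
    model = ctx.get_resource(
        name="my_model",
        type=ResourceType.CHAT_MODEL,
        metric_group=custom_group,
    )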
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index 64ef4cc6..b8eaf3f0 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -119,7 +119,7 @@ class LocalRunnerContext(RunnerContext):
         self.events.append(event)
 
     @override
-    def get_resource(self, name: str, type: ResourceType) -> Resource:
+    def get_resource(self, name: str, type: ResourceType, metric_group: MetricGroup = None) -> Resource:
         return self.__agent_plan.get_resource(name, type)
 
     @property
diff --git a/python/flink_agents/runtime/memory/compaction_functions.py b/python/flink_agents/runtime/memory/compaction_functions.py
index 2bfeed12..669df4d5 100644
--- a/python/flink_agents/runtime/memory/compaction_functions.py
+++ b/python/flink_agents/runtime/memory/compaction_functions.py
@@ -17,7 +17,7 @@
 #################################################################################
 import json
 import logging
-from typing import TYPE_CHECKING, List, Type, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Type, cast
 
 from flink_agents.api.chat_message import ChatMessage, MessageRole
 from flink_agents.api.memory.long_term_memory import (
@@ -26,6 +26,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySetItem,
     SummarizationStrategy,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.prompts.prompt import Prompt
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
@@ -62,8 +63,9 @@ def summarize(
     ltm: BaseLongTermMemory,
     memory_set: MemorySet,
     ctx: RunnerContext,
+    metric_group: MetricGroup,
     ids: List[str] | None = None,
-) -> None:
+) -> Dict[str, Any]:
     """Generate summarization of the items in the memory set.
 
     Will add the summarization to memory set, and delete original items involved
@@ -73,6 +75,7 @@ def summarize(
         ltm: The long term memory the memory set belongs to.
         memory_set: The memory set to be summarized.
         ctx: The runner context used to retrieve needed resources.
+        metric_group: Metric group used to report metrics.
         ids: The ids of items to be summarized. If not provided, all items will be
             involved in summarization. Optional
     """
@@ -84,7 +87,7 @@ def summarize(
     items: List[MemorySetItem] = ltm.get(memory_set=memory_set, ids=ids)
 
     response: ChatMessage = _generate_summarization(
-        items, memory_set.item_type, strategy, ctx
+        items, memory_set.item_type, strategy, ctx, metric_group
     )
     logging.debug(f"Items to be summarized: {items}\nSummarization: {response.content}")
 
@@ -131,6 +134,8 @@ def summarize(
         },
     )
 
+    return response.extra_args
+
 
 # TODO: Currently, we feed all items to the LLM at once, which may exceed the LLM's
 # context window. We need to support batched summary generation.
@@ -139,6 +144,7 @@ def _generate_summarization(
     item_type: Type,
     strategy: SummarizationStrategy,
     ctx: RunnerContext,
+    metric_group: MetricGroup
 ) -> ChatMessage:
     """Generate summarization of the items by llm."""
     # get arguments
@@ -157,7 +163,7 @@ def _generate_summarization(
     # generate summary
     model: BaseChatModelSetup = cast(
         "BaseChatModelSetup",
-        ctx.get_resource(name=model_name, type=ResourceType.CHAT_MODEL),
+        ctx.get_resource(name=model_name, type=ResourceType.CHAT_MODEL, metric_group=metric_group),
     )
     input_variable = {}
     for msg in msgs:
@@ -167,7 +173,7 @@ def _generate_summarization(
     if isinstance(prompt, str):
         prompt: Prompt = cast(
             "Prompt",
-            ctx.get_resource(prompt, ResourceType.PROMPT),
+            ctx.get_resource(prompt, ResourceType.PROMPT, metric_group=metric_group),
         )
     prompt_messages = prompt.format_messages(
         role=MessageRole.USER, **input_variable
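summarize() now hands the chat response's extra_args, which carry the token usage produced during compaction, back to its caller instead of reporting them itself. A short sketch of what a caller does with that return value, using the field names _report_token_metrics() expects later in this patch:

    # Sketch only: buffer the token usage returned by summarize() so it can be
    # reported later from the operator thread.
    extra_args = summarize(
        ltm=ltm, memory_set=memory_set, ctx=ctx, metric_group=metric_group
    )
    # e.g. {"model_name": ..., "promptTokens": ..., "completionTokens": ...}
    if extra_args:
        metric_records.put(extra_args)  # a queue.Queue drained by _report_token_metrics()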
diff --git a/python/flink_agents/runtime/memory/internal_base_long_term_memory.py b/python/flink_agents/runtime/memory/internal_base_long_term_memory.py
new file mode 100644
index 00000000..96ff4558
--- /dev/null
+++ b/python/flink_agents/runtime/memory/internal_base_long_term_memory.py
@@ -0,0 +1,36 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from abc import ABC, abstractmethod
+
+from flink_agents.api.memory.long_term_memory import BaseLongTermMemory
+
+
+class InternalBaseLongTermMemory(BaseLongTermMemory, ABC):
+    """Internal interface extends BaseLongTermMemory for hiding some interface
+    to user.
+    """
+
+    @abstractmethod
+    def switch_context(self, key: str) -> None:
+        """Switches the context for the memory operations. This allows
+        the same memory instance to be used for different key by isolating
+        data based on the provided key.
+
+        Args:
+            key: The context key.
+        """
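InternalBaseLongTermMemory keeps switch_context() (and the keyed-partition handling behind it) out of the user-facing BaseLongTermMemory API. A minimal sketch of what a runtime-side implementation provides; the class name is made up and the remaining BaseLongTermMemory abstract methods are omitted:

    # Sketch only: the runtime-internal surface on top of BaseLongTermMemory.
    class SketchLongTermMemory(InternalBaseLongTermMemory):
        def switch_context(self, key: str) -> None:
            # isolate subsequent reads and writes under this keyed partition
            self.key = key

        def close(self) -> None:
            # flush anything still buffered (e.g. pending metrics) at job shutdown
            pass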
diff --git a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
index 9eb3c8b7..8465c2b5 100644
--- a/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
+++ b/python/flink_agents/runtime/memory/tests/test_vector_store_long_term_memory.py
@@ -33,6 +33,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySet,
     SummarizationStrategy,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import Resource, ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.integrations.chat_models.ollama_chat_model import (
@@ -118,6 +119,9 @@ def long_term_memory() -> VectorStoreLongTermMemory:  # noqa: D103
 
     mock_runner_context = create_autospec(RunnerContext, instance=True)
     mock_runner_context.get_resource = get_resource
+    mock_runner_context.agent_metric_group.get_sub_group.return_value = create_autospec(
+        MetricGroup, instance=True
+    )
 
     return VectorStoreLongTermMemory(
         ctx=mock_runner_context,
diff --git a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
index a4cf203d..d43c9b50 100644
--- a/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
+++ b/python/flink_agents/runtime/memory/vector_store_long_term_memory.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################
 import functools
+import queue
 import uuid
 from concurrent.futures import Future
 from datetime import datetime, timezone
@@ -26,7 +27,6 @@ from typing_extensions import override
 
 from flink_agents.api.chat_message import ChatMessage
 from flink_agents.api.memory.long_term_memory import (
-    BaseLongTermMemory,
     CompactionStrategy,
     CompactionStrategyType,
     DatetimeRange,
@@ -35,6 +35,7 @@ from flink_agents.api.memory.long_term_memory import (
     MemorySet,
     MemorySetItem,
 )
+from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.api.vector_stores.vector_store import (
@@ -44,10 +45,13 @@ from flink_agents.api.vector_stores.vector_store import (
     _maybe_cast_to_list,
 )
 from flink_agents.runtime.memory.compaction_functions import summarize
+from flink_agents.runtime.memory.internal_base_long_term_memory import (
+    InternalBaseLongTermMemory,
+)
 
 
 # TODO: support async execution for operations and compaction
-class VectorStoreLongTermMemory(BaseLongTermMemory):
+class VectorStoreLongTermMemory(InternalBaseLongTermMemory):
     """Long-Term Memory based on ChromaDB."""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -62,19 +66,27 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
 
     job_id: str = Field(description="Unique identifier for the job.")
 
-    key: str = Field(description="Unique identifier for the keyed partition.")
+    key: str = Field(
+        default=None, description="Unique identifier for the keyed partition."
+    )
 
     async_compaction: bool = Field(
        default=False, description="Whether to execute compact asynchronously."
     )
 
+    metric_group: MetricGroup = Field(
+        default=None, description="Metric group for reporting long-term memory metrics."
+    )
+    metric_records: queue.Queue = Field(
+        default=queue.Queue(), description="A thread safe queue for record metrics."
+    )
+
     def __init__(
         self,
         *,
         ctx: RunnerContext,
         vector_store: str,
         job_id: str,
-        key: str,
         **kwargs: Any,
     ) -> None:
         """Init method."""
@@ -82,11 +94,15 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
             ctx=ctx,
             vector_store=vector_store,
             job_id=job_id,
-            key=key,
             async_compaction=ctx.config.get(LongTermMemoryOptions.ASYNC_COMPACTION),
+            metric_group=ctx.agent_metric_group.get_sub_group("long-term-memory"),
             **kwargs,
         )
 
+    @override
+    def switch_context(self, key: str) -> None:
+        self.key = key
+
     @property
     def store(self) -> CollectionManageableVectorStore:
         """Get backend vector store.
@@ -186,15 +202,23 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
         if memory_set.size >= memory_set.capacity:
             # trigger compaction
             if self.async_compaction:
-                future = self.ctx.executor.submit(self._compact, memory_set=memory_set)
+                future = self.ctx.executor.submit(
+                    self._compact,
+                    memory_set=memory_set,
+                    metric_group=self.metric_group,
+                )
                 future.add_done_callback(
                     functools.partial(
                         self._handle_exception, self.job_id, self.key, memory_set
                     )
                 )
             else:
-                self._compact(memory_set=memory_set)
+                self._compact(
+                    memory_set=memory_set,
+                    metric_group=self.metric_group,
+                )
+            self._report_token_metrics()
 
         return ids
 
     @override
@@ -224,16 +248,42 @@ class VectorStoreLongTermMemory(BaseLongTermMemory):
 
         return self._convert_to_items(memory_set=memory_set, documents=result.documents)
 
+    @override
+    def close(self) -> None:
+        # report possible token usage metrics
+        self._report_token_metrics()
+
+    def _report_token_metrics(self) -> None:
+        """Report token usage metrics."""
+        if not self.metric_records.empty():
+            if self.metric_group is None:
+                return
+            while not self.metric_records.empty():
+                metric = self.metric_records.get()
+                if (
+                    metric.get("model_name")
+                    and metric.get("promptTokens")
+                    and metric.get("completionTokens")
+                ):
+                    model_group = self.metric_group.get_sub_group(metric["model_name"])
+                    model_group.get_counter("promptTokens").inc(metric["promptTokens"])
+                    model_group.get_counter("completionTokens").inc(
+                        metric["completionTokens"]
+                    )
+
     def _name_mangling(self, name: str) -> str:
         """Mangle memory set name to actually name in vector store."""
         return f"{self.job_id}-{self.key}-{name}"
 
-    def _compact(self, memory_set: MemorySet) -> None:
+    def _compact(self, memory_set: MemorySet, metric_group: MetricGroup) -> Any | None:
         """Compact memory set to manage storge."""
         compaction_strategy: CompactionStrategy = memory_set.compaction_strategy
         if compaction_strategy.type == CompactionStrategyType.SUMMARIZATION:
             # currently, only support summarize all the items.
-            summarize(ltm=self, memory_set=memory_set, ctx=self.ctx)
+            extra_args = summarize(
+                ltm=self, memory_set=memory_set, ctx=self.ctx, metric_group=metric_group
+            )
+            self.metric_records.put(extra_args)
         else:
             msg = f"Unknown compaction strategy: {compaction_strategy.type}"
             raise RuntimeError(msg)
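Compaction may run on the executor thread when async_compaction is enabled, so the token usage it produces is staged in the thread-safe metric_records queue and only turned into counters from the operator side, after a later compaction trigger or at close(). A sketch of one staged record and the counters _report_token_metrics() derives from it; the model name and numbers are illustrative:

    # Sketch only: one staged token-usage record, shaped like the extra_args
    # that _compact() puts on the queue (values are illustrative).
    record = {
        "model_name": "example-model",
        "promptTokens": 1280,
        "completionTokens": 96,
    }
    ltm.metric_records.put(record)
    ltm._report_token_metrics()
    # increments the promptTokens / completionTokens counters under
    # <agent metric group> -> "long-term-memory" -> "example-model"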
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index 3f9ad702..55d38187 100644
--- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -97,7 +97,8 @@ public class PythonActionExecutor {
                         runnerContext,
                         agentPlanJson,
                         pythonAsyncThreadPool,
-                        javaResourceAdapter);
+                        javaResourceAdapter,
+                        jobIdentifier);
     }
 
     /**
@@ -116,10 +117,7 @@ public class PythonActionExecutor {
         function.setInterpreter(interpreter);
 
         interpreter.invoke(
-                FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT,
-                pythonRunnerContext,
-                jobIdentifier,
-                hashOfKey);
+                FLINK_RUNNER_CONTEXT_SWITCH_ACTION_CONTEXT, pythonRunnerContext, hashOfKey);
 
         Object pythonEventObject =
                 interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());
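On the Java side the job identifier is now handed over once, when the Python runner context is created, and the per-record context switch only carries the key hash. The resulting Python-side call sequence, sketched with an illustrative loop (the function names are the ones defined in flink_runner_context.py; close_flink_runner_context is assumed to delegate to ctx.close()):

    # Sketch only: lifecycle of the reused runner context across keyed partitions.
    ctx = create_flink_runner_context(
        j_runner_context, agent_plan_json, executor, j_resource_adapter, job_identifier
    )                                                 # long-term memory built once here
    for key in keys_processed_by_the_operator:        # illustrative loop
        flink_runner_context_switch_action_context(ctx, key)  # only switches the LTM key
        # ... execute the actions for this key ...
    close_flink_runner_context(ctx)                   # FlinkRunnerContext.close() also
                                                      # flushes remaining token metrics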
