This is an automated email from the ASF dual-hosted git repository.

xintongsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new e98af528 [api][plan][integrations] Record built-in chat token metrics outside the async call boundary (#712)
e98af528 is described below

commit e98af52851a580459da0f52c1123b942779c9182
Author: Weiqing Yang <[email protected]>
AuthorDate: Sun May 31 01:16:27 2026 -0700

    [api][plan][integrations] Record built-in chat token metrics outside the async call boundary (#712)
---
 .../api/chat/model/BaseChatModelConnection.java    | 19 -----
 .../agents/api/chat/model/BaseChatModelSetup.java  | 20 ++++-
 .../flink/agents/api/context/RunnerContext.java    |  8 ++
 ...ava => BaseChatModelSetupTokenMetricsTest.java} | 76 +++++++------------
 .../anthropic/AnthropicChatModelConnection.java    |  7 +-
 .../azureai/AzureAIChatModelConnection.java        |  9 ++-
 .../bedrock/BedrockChatModelConnection.java        | 10 ++-
 .../ollama/OllamaChatModelConnection.java          |  7 +-
 .../openai/AzureOpenAIChatModelConnection.java     |  9 ++-
 .../openai/OpenAICompletionsConnection.java        | 11 +--
 .../openai/OpenAIResponsesModelConnection.java     |  8 +-
 .../flink/agents/plan/actions/ChatModelAction.java | 18 +++++
 .../agents/plan/actions/ChatModelActionTest.java   | 85 ++++++++++++++++++++++
 13 files changed, 190 insertions(+), 97 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
index a6dccc44..7ce69b6d 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnection.java
@@ -19,7 +19,6 @@
 package org.apache.flink.agents.api.chat.model;
 
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceContext;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -56,22 +55,4 @@ public abstract class BaseChatModelConnection extends 
Resource {
      */
     public abstract ChatMessage chat(
             List<ChatMessage> messages, List<Tool> tools, Map<String, Object> 
arguments);
-
-    /**
-     * Record token usage metrics for the given model.
-     *
-     * @param modelName the name of the model used
-     * @param promptTokens the number of prompt tokens
-     * @param completionTokens the number of completion tokens
-     */
-    protected void recordTokenMetrics(String modelName, long promptTokens, 
long completionTokens) {
-        FlinkAgentsMetricGroup metricGroup = getMetricGroup();
-        if (metricGroup == null) {
-            return;
-        }
-
-        FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
-        modelGroup.getCounter("promptTokens").inc(promptTokens);
-        modelGroup.getCounter("completionTokens").inc(completionTokens);
-    }
 }
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
index 858cd069..3a9c7b2d 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetup.java
@@ -20,6 +20,7 @@ package org.apache.flink.agents.api.chat.model;
 
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.prompt.Prompt;
 import org.apache.flink.agents.api.resource.Resource;
 import org.apache.flink.agents.api.resource.ResourceContext;
@@ -107,6 +108,23 @@ public abstract class BaseChatModelSetup extends Resource {
 
     public abstract Map<String, Object> getParameters();
 
+    /**
+     * Record token usage metrics for the given model on this setup's bound 
metric group.
+     *
+     * @param modelName the name of the model used
+     * @param promptTokens the number of prompt tokens
+     * @param completionTokens the number of completion tokens
+     */
+    public void recordTokenMetrics(String modelName, long promptTokens, long 
completionTokens) {
+        FlinkAgentsMetricGroup metricGroup = getMetricGroup();
+        if (metricGroup == null) {
+            return;
+        }
+        FlinkAgentsMetricGroup modelGroup = metricGroup.getSubGroup(modelName);
+        modelGroup.getCounter("promptTokens").inc(promptTokens);
+        modelGroup.getCounter("completionTokens").inc(completionTokens);
+    }
+
     public ChatMessage chat(List<ChatMessage> messages) {
         return this.chat(messages, Collections.emptyMap(), 
Collections.emptyMap());
     }
@@ -118,8 +136,6 @@ public abstract class BaseChatModelSetup extends Resource {
         Preconditions.checkNotNull(
                 connection,
                 "Connection is not initialized. Ensure open() is called before 
chat().");
-        // Pass metric group to connection for token usage tracking
-        connection.setMetricGroup(getMetricGroup());
 
         // Format input messages if set prompt.
         if (this.prompt != null) {
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java 
b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 06cd2b38..c3e5d19b 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -67,6 +67,10 @@ public interface RunnerContext {
     /**
      * Gets the metric group for Flink Agents.
      *
+     * <p>The returned group must only be accessed from the operator/mailbox 
(action) thread, not
+     * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} 
callable, which runs on
+     * a separate thread pool.
+     *
      * @return the metric group shared across all actions.
      */
     FlinkAgentsMetricGroup getAgentMetricGroup();
@@ -74,6 +78,10 @@ public interface RunnerContext {
     /**
      * Gets the individual metric group dedicated for each action.
      *
+     * <p>The returned group must only be accessed from the operator/mailbox 
(action) thread, not
+     * from inside a {@link #durableExecute} or {@link #durableExecuteAsync} 
callable, which runs on
+     * a separate thread pool.
+     *
      * @return the individual metric group specific to the current action.
      */
     FlinkAgentsMetricGroup getActionMetricGroup();
diff --git 
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
 
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
similarity index 59%
rename from 
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
rename to 
api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
index 43654944..cde9f683 100644
--- 
a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelConnectionTokenMetricsTest.java
+++ 
b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelSetupTokenMetricsTest.java
@@ -18,71 +18,60 @@
 
 package org.apache.flink.agents.api.chat.model;
 
-import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.metrics.FlinkAgentsMetricGroup;
 import org.apache.flink.agents.api.resource.ResourceContext;
 import org.apache.flink.agents.api.resource.ResourceDescriptor;
 import org.apache.flink.agents.api.resource.ResourceType;
-import org.apache.flink.agents.api.tools.Tool;
 import org.apache.flink.metrics.Counter;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.DisplayName;
 import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 
-import static org.junit.jupiter.api.Assertions.*;
-import static org.mockito.Mockito.*;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
 
-/** Test cases for BaseChatModelConnection token metrics functionality. */
-class BaseChatModelConnectionTokenMetricsTest {
+/** Test cases for BaseChatModelSetup token metrics functionality. */
+class BaseChatModelSetupTokenMetricsTest {
 
-    private TestChatModelConnection connection;
+    private TestChatModelSetup setup;
     private FlinkAgentsMetricGroup mockMetricGroup;
     private FlinkAgentsMetricGroup mockModelGroup;
     private Counter mockPromptTokensCounter;
     private Counter mockCompletionTokensCounter;
 
-    /** Test implementation of BaseChatModelConnection for testing purposes. */
-    private static class TestChatModelConnection extends 
BaseChatModelConnection {
+    /** Test implementation of BaseChatModelSetup for testing purposes. */
+    private static class TestChatModelSetup extends BaseChatModelSetup {
 
-        public TestChatModelConnection(
-                ResourceDescriptor descriptor, ResourceContext 
resourceContext) {
+        public TestChatModelSetup(ResourceDescriptor descriptor, 
ResourceContext resourceContext) {
             super(descriptor, resourceContext);
         }
 
         @Override
-        public ChatMessage chat(
-                List<ChatMessage> messages, List<Tool> tools, Map<String, 
Object> arguments) {
-            // Simple test implementation
-            return new ChatMessage(MessageRole.ASSISTANT, "Test response");
-        }
-
-        // Expose protected method for testing
-        public void testRecordTokenMetrics(
-                String modelName, long promptTokens, long completionTokens) {
-            recordTokenMetrics(modelName, promptTokens, completionTokens);
+        public Map<String, Object> getParameters() {
+            return Collections.emptyMap();
         }
     }
 
     @BeforeEach
     void setUp() {
-        connection =
-                new TestChatModelConnection(
+        setup =
+                new TestChatModelSetup(
                         new ResourceDescriptor(
-                                TestChatModelConnection.class.getName(), 
Collections.emptyMap()),
+                                TestChatModelSetup.class.getName(), 
Collections.emptyMap()),
                         null);
 
-        // Create mock objects
         mockMetricGroup = mock(FlinkAgentsMetricGroup.class);
         mockModelGroup = mock(FlinkAgentsMetricGroup.class);
         mockPromptTokensCounter = mock(Counter.class);
         mockCompletionTokensCounter = mock(Counter.class);
 
-        // Set up mock behavior
         when(mockMetricGroup.getSubGroup("gpt-4")).thenReturn(mockModelGroup);
         
when(mockModelGroup.getCounter("promptTokens")).thenReturn(mockPromptTokensCounter);
         
when(mockModelGroup.getCounter("completionTokens")).thenReturn(mockCompletionTokensCounter);
@@ -91,13 +80,10 @@ class BaseChatModelConnectionTokenMetricsTest {
     @Test
     @DisplayName("Test token metrics are recorded when metric group is set")
     void testRecordTokenMetricsWithMetricGroup() {
-        // Set the metric group
-        connection.setMetricGroup(mockMetricGroup);
+        setup.setMetricGroup(mockMetricGroup);
 
-        // Record token metrics
-        connection.testRecordTokenMetrics("gpt-4", 100, 50);
+        setup.recordTokenMetrics("gpt-4", 100, 50);
 
-        // Verify the metrics were recorded
         verify(mockMetricGroup).getSubGroup("gpt-4");
         verify(mockModelGroup).getCounter("promptTokens");
         verify(mockModelGroup).getCounter("completionTokens");
@@ -108,22 +94,16 @@ class BaseChatModelConnectionTokenMetricsTest {
     @Test
     @DisplayName("Test token metrics are not recorded when metric group is 
null")
     void testRecordTokenMetricsWithoutMetricGroup() {
-        // Do not set metric group (should be null by default)
+        assertDoesNotThrow(() -> setup.recordTokenMetrics("gpt-4", 100, 50));
 
-        // Record token metrics - should not throw
-        assertDoesNotThrow(() -> connection.testRecordTokenMetrics("gpt-4", 
100, 50));
-
-        // No metrics should be recorded
         verifyNoInteractions(mockMetricGroup);
     }
 
     @Test
-    @DisplayName("Test token metrics hierarchy: actionMetricGroup -> modelName 
-> counters")
+    @DisplayName("Test token metrics hierarchy: metricGroup -> modelName -> 
counters")
     void testTokenMetricsHierarchy() {
-        // Set the metric group
-        connection.setMetricGroup(mockMetricGroup);
+        setup.setMetricGroup(mockMetricGroup);
 
-        // Record token metrics for different models
         FlinkAgentsMetricGroup mockGpt35Group = 
mock(FlinkAgentsMetricGroup.class);
         Counter mockGpt35PromptCounter = mock(Counter.class);
         Counter mockGpt35CompletionCounter = mock(Counter.class);
@@ -132,13 +112,9 @@ class BaseChatModelConnectionTokenMetricsTest {
         
when(mockGpt35Group.getCounter("promptTokens")).thenReturn(mockGpt35PromptCounter);
         
when(mockGpt35Group.getCounter("completionTokens")).thenReturn(mockGpt35CompletionCounter);
 
-        // Record for gpt-4
-        connection.testRecordTokenMetrics("gpt-4", 100, 50);
-
-        // Record for gpt-3.5-turbo
-        connection.testRecordTokenMetrics("gpt-3.5-turbo", 200, 100);
+        setup.recordTokenMetrics("gpt-4", 100, 50);
+        setup.recordTokenMetrics("gpt-3.5-turbo", 200, 100);
 
-        // Verify each model has its own counters
         verify(mockMetricGroup).getSubGroup("gpt-4");
         verify(mockMetricGroup).getSubGroup("gpt-3.5-turbo");
         verify(mockPromptTokensCounter).inc(100);
@@ -148,8 +124,8 @@ class BaseChatModelConnectionTokenMetricsTest {
     }
 
     @Test
-    @DisplayName("Test resource type is CHAT_MODEL_CONNECTION")
+    @DisplayName("Test resource type is CHAT_MODEL")
     void testResourceType() {
-        assertEquals(ResourceType.CHAT_MODEL_CONNECTION, 
connection.getResourceType());
+        assertEquals(ResourceType.CHAT_MODEL, setup.getResourceType());
     }
 }
diff --git 
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
 
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
index 248b464e..93691d3f 100644
--- 
a/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
+++ 
b/integrations/chat-models/anthropic/src/main/java/org/apache/flink/agents/integrations/chatmodels/anthropic/AnthropicChatModelConnection.java
@@ -133,7 +133,7 @@ public class AnthropicChatModelConnection extends 
BaseChatModelConnection {
             Message response = client.messages().create(params);
             ChatMessage result = convertResponse(response, jsonPrefillApplied);
 
-            // Record token metrics
+            // Stash token usage
             String modelName = null;
             if (arguments != null && arguments.get("model") != null) {
                 modelName = arguments.get("model").toString();
@@ -142,8 +142,9 @@ public class AnthropicChatModelConnection extends 
BaseChatModelConnection {
                 modelName = this.defaultModel;
             }
             if (modelName != null && !modelName.isBlank()) {
-                recordTokenMetrics(
-                        modelName, response.usage().inputTokens(), 
response.usage().outputTokens());
+                result.getExtraArgs().put("model_name", modelName);
+                result.getExtraArgs().put("promptTokens", 
response.usage().inputTokens());
+                result.getExtraArgs().put("completionTokens", 
response.usage().outputTokens());
             }
 
             return result;
diff --git 
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
 
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
index 3051ecf4..318b5457 100644
--- 
a/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
+++ 
b/integrations/chat-models/azureai/src/main/java/org/apache/flink/agents/integrations/chatmodels/azureai/AzureAIChatModelConnection.java
@@ -187,12 +187,15 @@ public class AzureAIChatModelConnection extends 
BaseChatModelConnection {
                 chatMessage.setToolCalls(convertedToolCalls);
             }
 
-            // Record token metrics if model name is available
+            // Stash token usage if model name is available
             if (modelName != null && !modelName.isBlank()) {
                 CompletionsUsage usage = completions.getUsage();
                 if (usage != null) {
-                    recordTokenMetrics(
-                            modelName, usage.getPromptTokens(), 
usage.getCompletionTokens());
+                    chatMessage.getExtraArgs().put("model_name", modelName);
+                    chatMessage.getExtraArgs().put("promptTokens", (long) 
usage.getPromptTokens());
+                    chatMessage
+                            .getExtraArgs()
+                            .put("completionTokens", (long) 
usage.getCompletionTokens());
                 }
             }
 
diff --git 
a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
 
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
index 8327795a..58d23508 100644
--- 
a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
+++ 
b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java
@@ -178,12 +178,14 @@ public class BedrockChatModelConnection extends 
BaseChatModelConnection {
         ConverseResponse response =
                 retryExecutor.execute(() -> client.converse(request), 
"BedrockConverse");
 
+        ChatMessage result = convertResponse(response);
         if (response.usage() != null) {
-            recordTokenMetrics(
-                    modelId, response.usage().inputTokens(), 
response.usage().outputTokens());
+            result.getExtraArgs().put("model_name", modelId);
+            result.getExtraArgs().put("promptTokens", 
response.usage().inputTokens().longValue());
+            result.getExtraArgs()
+                    .put("completionTokens", 
response.usage().outputTokens().longValue());
         }
-
-        return convertResponse(response);
+        return result;
     }
 
     private static boolean isRetryable(Exception e) {
diff --git 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
index 2cda1ea4..4c617455 100644
--- 
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
+++ 
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelConnection.java
@@ -224,13 +224,14 @@ public class OllamaChatModelConnection extends 
BaseChatModelConnection {
                 chatMessage.setToolCalls(toolCalls);
             }
 
-            // Record token metrics if model name is available
+            // Stash token usage if model name is available
             if (modelName != null && !modelName.isBlank()) {
                 Integer promptTokens = ollamaChatResponse.getPromptEvalCount();
                 Integer completionTokens = ollamaChatResponse.getEvalCount();
                 if (promptTokens != null && completionTokens != null) {
-                    recordTokenMetrics(
-                            modelName, promptTokens.longValue(), 
completionTokens.longValue());
+                    extraArgs.put("model_name", modelName);
+                    extraArgs.put("promptTokens", promptTokens.longValue());
+                    extraArgs.put("completionTokens", 
completionTokens.longValue());
                 }
             }
 
diff --git 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
index 7d6b5c2c..6567bd2b 100644
--- 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
+++ 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
@@ -216,10 +216,11 @@ public class AzureOpenAIChatModelConnection extends 
BaseChatModelConnection {
             if (modelOfAzureDeployment != null
                     && !modelOfAzureDeployment.isBlank()
                     && completion.usage().isPresent()) {
-                recordTokenMetrics(
-                        modelOfAzureDeployment,
-                        completion.usage().get().promptTokens(),
-                        completion.usage().get().completionTokens());
+                response.getExtraArgs().put("model_name", 
modelOfAzureDeployment);
+                response.getExtraArgs()
+                        .put("promptTokens", 
completion.usage().get().promptTokens());
+                response.getExtraArgs()
+                        .put("completionTokens", 
completion.usage().get().completionTokens());
             }
 
             return response;
diff --git 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
index e4947e8f..2a0b78fe 100644
--- 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
+++ 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
@@ -129,17 +129,18 @@ public class OpenAICompletionsConnection extends 
BaseChatModelConnection {
                     OpenAIChatCompletionsUtils.convertFromOpenAIMessage(
                             completion.choices().get(0).message());
 
-            // Record token metrics
+            // Stash token usage
             if (completion.usage().isPresent()) {
                 String modelName = arguments != null ? (String) 
arguments.get("model") : null;
                 if (modelName == null || modelName.isBlank()) {
                     modelName = this.defaultModel;
                 }
                 if (modelName != null && !modelName.isBlank()) {
-                    recordTokenMetrics(
-                            modelName,
-                            completion.usage().get().promptTokens(),
-                            completion.usage().get().completionTokens());
+                    response.getExtraArgs().put("model_name", modelName);
+                    response.getExtraArgs()
+                            .put("promptTokens", 
completion.usage().get().promptTokens());
+                    response.getExtraArgs()
+                            .put("completionTokens", 
completion.usage().get().completionTokens());
                 }
             }
 
diff --git 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
index 9b0d143e..00b5f9b6 100644
--- 
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
+++ 
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIResponsesModelConnection.java
@@ -140,10 +140,10 @@ public class OpenAIResponsesModelConnection extends 
BaseChatModelConnection {
                     modelName = this.defaultModel;
                 }
                 if (modelName != null && !modelName.isBlank()) {
-                    recordTokenMetrics(
-                            modelName,
-                            response.usage().get().inputTokens(),
-                            response.usage().get().outputTokens());
+                    result.getExtraArgs().put("model_name", modelName);
+                    result.getExtraArgs().put("promptTokens", 
response.usage().get().inputTokens());
+                    result.getExtraArgs()
+                            .put("completionTokens", 
response.usage().get().outputTokens());
                 }
             }
 
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index 5805eceb..504b4fb9 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -182,6 +182,23 @@ public class ChatModelAction {
         }
     }
 
+    static void recordChatTokenMetrics(BaseChatModelSetup chatModel, 
ChatMessage response) {
+        Map<String, Object> extraArgs = response.getExtraArgs();
+        Object modelName = extraArgs.get("model_name");
+        Object promptTokens = extraArgs.get("promptTokens");
+        Object completionTokens = extraArgs.get("completionTokens");
+        if (modelName != null
+                && !modelName.toString().isEmpty()
+                && promptTokens instanceof Number
+                && completionTokens instanceof Number) {
+            long prompt = ((Number) promptTokens).longValue();
+            long completion = ((Number) completionTokens).longValue();
+            if (prompt > 0 && completion > 0) {
+                chatModel.recordTokenMetrics(modelName.toString(), prompt, 
completion);
+            }
+        }
+    }
+
     private static void handleToolCalls(
             ChatMessage response,
             UUID initialRequestId,
@@ -355,6 +372,7 @@ public class ChatModelAction {
                         chatAsync
                                 ? ctx.durableExecuteAsync(callable)
                                 : ctx.durableExecute(callable);
+                recordChatTokenMetrics(chatModel, response);
                 // only generate structured output for final response.
                 if (outputSchema != null && response.getToolCalls().isEmpty()) 
{
                     response = generateStructuredOutput(response, 
outputSchema);
diff --git 
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
index d7f11785..85c263a6 100644
--- 
a/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
+++ 
b/plan/src/test/java/org/apache/flink/agents/plan/actions/ChatModelActionTest.java
@@ -17,13 +17,98 @@
  */
 package org.apache.flink.agents.plan.actions;
 
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
 import org.junit.jupiter.api.Test;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
 
 /** Tests for {@link ChatModelAction}. */
 class ChatModelActionTest {
 
+    private static ChatMessage responseWith(Map<String, Object> extraArgs) {
+        return new ChatMessage(MessageRole.ASSISTANT, "response", extraArgs);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsRecordsWhenAllKeysPresent() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", 100L);
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup).recordTokenMetrics("m", 100L, 50L);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsHandlesIntegerTokenValues() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", 100);
+        extraArgs.put("completionTokens", 50);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup).recordTokenMetrics("m", 100L, 50L);
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsWhenTokenValueNonNumeric() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("promptTokens", "100");
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsWhenKeyMissing() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+        Map<String, Object> extraArgs = new HashMap<>();
+        extraArgs.put("model_name", "m");
+        extraArgs.put("completionTokens", 50L);
+
+        ChatModelAction.recordChatTokenMetrics(setup, responseWith(extraArgs));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+
+    @Test
+    void testRecordChatTokenMetricsSkipsZeroTokensOrEmptyModel() {
+        BaseChatModelSetup setup = mock(BaseChatModelSetup.class);
+
+        Map<String, Object> zeroPrompt = new HashMap<>();
+        zeroPrompt.put("model_name", "m");
+        zeroPrompt.put("promptTokens", 0L);
+        zeroPrompt.put("completionTokens", 50L);
+        ChatModelAction.recordChatTokenMetrics(setup, 
responseWith(zeroPrompt));
+
+        Map<String, Object> emptyModel = new HashMap<>();
+        emptyModel.put("model_name", "");
+        emptyModel.put("promptTokens", 100L);
+        emptyModel.put("completionTokens", 50L);
+        ChatModelAction.recordChatTokenMetrics(setup, 
responseWith(emptyModel));
+
+        verify(setup, never()).recordTokenMetrics(anyString(), anyLong(), 
anyLong());
+    }
+
     @Test
     void testCleanLlmResponseWithJsonBlock() {
         String input = "```json\n{\"key\": \"value\"}\n```";

Reply via email to