This is an automated email from the ASF dual-hosted git repository.
wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 66225cd7 [api] Specify model in ChatModelSetup (Java + Python) (#685)
66225cd7 is described below
commit 66225cd72d3aa6a34f645b99df747311160a0c24
Author: Weiqing Yang <[email protected]>
AuthorDate: Sun May 17 23:15:19 2026 -0700
[api] Specify model in ChatModelSetup (Java + Python) (#685)
---
.../chatmodels/ollama/OllamaChatModelSetup.java | 2 -
.../ollama/OllamaChatModelSetupTest.java | 55 ++++++++++++++++++++++
.../compatibility/CreateJavaAgentPlanFromJson.java | 2 +
python/flink_agents/api/agents/react_agent.py | 1 +
python/flink_agents/api/chat_models/chat_model.py | 3 ++
.../api/chat_models/tests/test_chat_model_base.py | 47 ++++++++++++++++++
.../api/chat_models/tests/test_token_metrics.py | 12 ++---
.../built_in_action_async_execution_test.py | 1 +
.../chat_models/anthropic/anthropic_chat_model.py | 10 ++--
.../anthropic/tests/test_anthropic_chat_model.py | 14 ++++++
.../chat_models/azure/azure_openai_chat_model.py | 7 +--
.../azure/tests/test_azure_openai_chat_model.py | 7 +++
.../integrations/chat_models/ollama_chat_model.py | 6 +--
.../chat_models/openai/openai_chat_model.py | 8 ++--
.../openai/tests/test_openai_chat_model.py | 14 ++++++
.../chat_models/tests/test_ollama_chat_model.py | 7 +++
.../chat_models/tests/test_tongyi_chat_model.py | 14 ++++++
.../integrations/chat_models/tongyi_chat_model.py | 6 +--
.../python_agent_plan_compatibility_test_agent.py | 2 +
.../plan/tests/resources/agent_plan.json | 3 +-
python/flink_agents/plan/tests/test_agent_plan.py | 2 +
.../flink_agents/runtime/java/java_chat_model.py | 5 +-
.../runtime/tests/test_built_in_actions.py | 1 +
.../runtime/tests/test_get_resource_in_action.py | 1 +
24 files changed, 195 insertions(+), 35 deletions(-)
diff --git
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
index 81a3f92f..a04ea321 100644
---
a/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
+++
b/integrations/chat-models/ollama/src/main/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetup.java
@@ -53,13 +53,11 @@ import java.util.Map;
*/
public class OllamaChatModelSetup extends BaseChatModelSetup {
- private final String model;
private final Object think;
private final boolean extractReasoning;
public OllamaChatModelSetup(ResourceDescriptor descriptor, ResourceContext
resourceContext) {
super(descriptor, resourceContext);
- this.model = descriptor.getArgument("model");
this.think = descriptor.getArgument("think", true);
this.extractReasoning = descriptor.getArgument("extract_reasoning",
true);
}
diff --git
a/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java
b/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java
new file mode 100644
index 00000000..7836339e
--- /dev/null
+++
b/integrations/chat-models/ollama/src/test/java/org/apache/flink/agents/integrations/chatmodels/ollama/OllamaChatModelSetupTest.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.integrations.chatmodels.ollama;
+
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link OllamaChatModelSetup}. */
+class OllamaChatModelSetupTest {
+
+ private static final ResourceContext NOOP =
ResourceContext.fromGetResource((a, b) -> null);
+
+ @Test
+ void getModel_returnsValueFromDescriptor() {
+ ResourceDescriptor desc =
+
ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
+ .addInitialArgument("connection", "dummy-connection")
+ .addInitialArgument("model", "qwen3:4b")
+ .build();
+ OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP);
+
+ assertThat(setup.getModel()).isEqualTo("qwen3:4b");
+ }
+
+ @Test
+ void getParameters_includesModelFromDescriptor() {
+ ResourceDescriptor desc =
+
ResourceDescriptor.Builder.newBuilder(OllamaChatModelSetup.class.getName())
+ .addInitialArgument("connection", "dummy-connection")
+ .addInitialArgument("model", "qwen3:4b")
+ .build();
+ OllamaChatModelSetup setup = new OllamaChatModelSetup(desc, NOOP);
+
+ assertThat(setup.getParameters()).containsEntry("model", "qwen3:4b");
+ }
+}
diff --git
a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
index 3388af22..c5589be3 100644
---
a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
+++
b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java
@@ -144,6 +144,8 @@ public class CreateJavaAgentPlanFromJson {
kwargs.put("name", "chat_model");
kwargs.put("prompt", "prompt");
kwargs.put("tools", List.of("add"));
+ kwargs.put("connection", "mock_connection");
+ kwargs.put("model", "mock-model");
ResourceDescriptor chatModelDescriptor =
new ResourceDescriptor(
"flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent",
diff --git a/python/flink_agents/api/agents/react_agent.py
b/python/flink_agents/api/agents/react_agent.py
index 740afbe7..cef651a1 100644
--- a/python/flink_agents/api/agents/react_agent.py
+++ b/python/flink_agents/api/agents/react_agent.py
@@ -91,6 +91,7 @@ class ReActAgent(Agent):
chat_model=ResourceDescriptor(
clazz=OllamaChatModelSetup,
connection="ollama_server",
+ model="qwen3:8b",
tools=["notify_shipping_manager"],
),
prompt=prompt,
diff --git a/python/flink_agents/api/chat_models/chat_model.py
b/python/flink_agents/api/chat_models/chat_model.py
index 7d7fc194..96055979 100644
--- a/python/flink_agents/api/chat_models/chat_model.py
+++ b/python/flink_agents/api/chat_models/chat_model.py
@@ -131,6 +131,8 @@ class BaseChatModelSetup(Resource):
"""Base abstract class for chat model setup.
Responsible for managing chat configurations, such as:
+ - Connection to chat model service (connection)
+ - Model name (model)
- Prompt templates (prompt)
- Available tools (tools)
- Generation parameters (temperature, max_tokens, etc.)
@@ -143,6 +145,7 @@ class BaseChatModelSetup(Resource):
"""
connection: str = Field(description="The referenced connection name.")
+ model: str = Field(description="Name of the chat model to use.")
_resolved_connection: BaseChatModelConnection | None =
PrivateAttr(default=None)
prompt: Prompt | str | None = None
tools: List[str] | List[Tool] = Field(default_factory=list)
diff --git a/python/flink_agents/api/chat_models/tests/test_chat_model_base.py
b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py
new file mode 100644
index 00000000..651061f4
--- /dev/null
+++ b/python/flink_agents/api/chat_models/tests/test_chat_model_base.py
@@ -0,0 +1,47 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from typing import Any, Dict
+
+import pytest
+from pydantic import ValidationError
+
+from flink_agents.api.chat_models.chat_model import BaseChatModelSetup
+
+
+class _MinimalChatModelSetup(BaseChatModelSetup):
+ """Minimal subclass that omits the `model` field declaration.
+
+ Used to assert the `model` field is inherited from `BaseChatModelSetup`.
+ """
+
+ @property
+ def model_kwargs(self) -> Dict[str, Any]:
+ """Return chat model settings derived from the inherited `model`
field."""
+ return {"model": self.model}
+
+
+def test_inherits_model_field_from_base() -> None:
+ """A subclass that omits `model` still exposes it via inheritance."""
+ setup = _MinimalChatModelSetup(connection="c", model="m1")
+ assert setup.model == "m1"
+
+
+def test_missing_model_raises_validation_error() -> None:
+ """Constructing without `model` must raise a Pydantic ValidationError."""
+ with pytest.raises(ValidationError):
+ _MinimalChatModelSetup(connection="c")
diff --git a/python/flink_agents/api/chat_models/tests/test_token_metrics.py
b/python/flink_agents/api/chat_models/tests/test_token_metrics.py
index 982565ab..d987d121 100644
--- a/python/flink_agents/api/chat_models/tests/test_token_metrics.py
+++ b/python/flink_agents/api/chat_models/tests/test_token_metrics.py
@@ -100,7 +100,7 @@ class TestBaseChatModelTokenMetrics:
def test_record_token_metrics_with_metric_group(self) -> None:
"""Test token metrics are recorded when metric group is set."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()
# Set the metric group
@@ -116,7 +116,7 @@ class TestBaseChatModelTokenMetrics:
def test_record_token_metrics_without_metric_group(self) -> None:
"""Test token metrics are not recorded when metric group is null."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
# Do not set metric group (should be None by default)
# Record token metrics - should not throw
@@ -125,7 +125,7 @@ class TestBaseChatModelTokenMetrics:
def test_token_metrics_hierarchy(self) -> None:
"""Test token metrics hierarchy: actionMetricGroup -> modelName ->
counters."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()
# Set the metric group
@@ -148,7 +148,7 @@ class TestBaseChatModelTokenMetrics:
def test_token_metrics_accumulation(self) -> None:
"""Test that token metrics accumulate across multiple calls."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
mock_metric_group = _MockMetricGroup()
# Set the metric group
@@ -165,12 +165,12 @@ class TestBaseChatModelTokenMetrics:
def test_resource_type(self) -> None:
"""Test resource type is CHAT_MODEL_CONNECTION."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
assert chat_model.resource_type() == ResourceType.CHAT_MODEL
def test_bound_metric_group_property(self) -> None:
"""Test bound_metric_group property."""
- chat_model = TestChatModelSetup(connection="mock")
+ chat_model = TestChatModelSetup(connection="mock", model="mock-model")
# Initially should be None
assert chat_model.metric_group is None
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py
b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py
index cd4d5484..f5fefd72 100644
---
a/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py
+++
b/python/flink_agents/e2e_tests/e2e_tests_integration/built_in_action_async_execution_test.py
@@ -72,6 +72,7 @@ class AsyncTestAgent(Agent):
return ResourceDescriptor(
clazz=f"{SlowMockChatModel.__module__}.{SlowMockChatModel.__name__}",
connection="placement",
+ model="slow-mock-model",
tools=["add"],
)
diff --git
a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
index 171fba24..c077c6c8 100644
---
a/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/anthropic/anthropic_chat_model.py
@@ -246,23 +246,19 @@ class AnthropicChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
+ model : str
+ Specifies the Anthropic model to use. Defaults to
claude-sonnet-4-20250514
+ when omitted via ``__init__``. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from
BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from
BaseChatModelSetup)
- model : str
- Specifies the Anthropic model to use. Defaults to
claude-sonnet-4-20250514.
max_tokens: int
The maximum number of tokens to generate before stopping. Defaults to
1024.
temperature : float
Amount of randomness injected into the response.
"""
- model: str = Field(
- default=DEFAULT_ANTHROPIC_MODEL,
- description="Specifies the Anthropic model to use. Defaults to "
- "claude-sonnet-4-20250514.",
- )
max_tokens: int = Field(
default=DEFAULT_MAX_TOKENS,
description="The maximum number of tokens to generate before stopping.
Defaults to 1024.",
diff --git
a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py
b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py
index 767cb6a4..247759a2 100644
---
a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py
@@ -24,6 +24,7 @@ from flink_agents.api.chat_message import ChatMessage,
MessageRole
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.resource_context import ResourceContext
from flink_agents.integrations.chat_models.anthropic.anthropic_chat_model
import (
+ DEFAULT_ANTHROPIC_MODEL,
AnthropicChatModelConnection,
AnthropicChatModelSetup,
)
@@ -103,3 +104,16 @@ def test_anthropic_chat_with_tools() -> None:
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 2
assert tool_call.get("original_id") is not None
+
+
+def test_model_field_roundtrip() -> None:
+ """Verify `model` is preserved through pydantic dump/validate
round-trip."""
+ setup = AnthropicChatModelSetup(connection="conn", model="test-model")
+ restored = AnthropicChatModelSetup.model_validate(setup.model_dump())
+ assert restored.model == "test-model"
+
+
+def test_default_model_when_omitted() -> None:
+ """Verify per-integration default applies when `model` is omitted from
__init__."""
+ setup = AnthropicChatModelSetup(connection="conn")
+ assert setup.model == DEFAULT_ANTHROPIC_MODEL
diff --git
a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
index 64da6890..576d3ab7 100644
---
a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
@@ -167,12 +167,12 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
+ model : str
+ Name of OpenAI model deployment on Azure. (Inherited from
BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from
BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from
BaseChatModelSetup)
- model : str
- Name of OpenAI model deployment on Azure.
model_of_azure_deployment : Optional[str]
The underlying model name of the Azure deployment (e.g., 'gpt-4').
Used for token counting and cost calculation.
@@ -193,9 +193,6 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
Additional kwargs for the Azure OpenAI API.
"""
- model: str = Field(
- description="Name of OpenAI model deployment on Azure.",
- )
model_of_azure_deployment: str | None = Field(
default=None,
description="The underlying model name of the Azure deployment (e.g.,
'gpt-4', "
diff --git
a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
index 95172a67..30753d75 100644
---
a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
@@ -119,3 +119,10 @@ def test_azure_openai_chat_with_tools() -> None:
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 1065
+
+
+def test_model_field_roundtrip() -> None:
+ """Verify `model` is preserved through pydantic dump/validate
round-trip."""
+ setup = AzureOpenAIChatModelSetup(connection="conn",
model="test-deployment")
+ restored = AzureOpenAIChatModelSetup.model_validate(setup.model_dump())
+ assert restored.model == "test-deployment"
diff --git a/python/flink_agents/integrations/chat_models/ollama_chat_model.py
b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
index a879dcf9..7c36ec38 100644
--- a/python/flink_agents/integrations/chat_models/ollama_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/ollama_chat_model.py
@@ -176,12 +176,12 @@ class OllamaChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
+ model : str
+ Model name to use. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from
BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from
BaseChatModelSetup)
- model : str
- Model name to use.
temperature : float
The temperature to use for sampling.
num_ctx : int
@@ -196,8 +196,6 @@ class OllamaChatModelSetup(BaseChatModelSetup):
stores it in additional_kwargs.
"""
- model: str = Field(description="Model name to use.")
-
temperature: float = Field(
default=0.75,
description="The temperature to use for sampling.",
diff --git
a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
index 95edb298..2e5fe720 100644
--- a/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/openai/openai_chat_model.py
@@ -203,12 +203,13 @@ class OpenAIChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
+ model : str
+ The OpenAI model to use. Defaults to ``DEFAULT_OPENAI_MODEL`` when
omitted via
+ ``__init__``. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from
BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from
BaseChatModelSetup)
- model : str
- The OpenAI model to use.
temperature : float
The temperature to use during generation.
max_tokens : Optional[int]
@@ -225,9 +226,6 @@ class OpenAIChatModelSetup(BaseChatModelSetup):
The effort to use for reasoning models.
"""
- model: str = Field(
- default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use."
- )
temperature: float = Field(
default=DEFAULT_TEMPERATURE,
description="The temperature to use during generation.",
diff --git
a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
index 3c893ecb..2280bf44 100644
---
a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py
@@ -24,6 +24,7 @@ from flink_agents.api.chat_message import ChatMessage,
MessageRole
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.resource_context import ResourceContext
from flink_agents.integrations.chat_models.openai.openai_chat_model import (
+ DEFAULT_OPENAI_MODEL,
OpenAIChatModelConnection,
OpenAIChatModelSetup,
)
@@ -104,3 +105,16 @@ def test_openai_chat_with_tools() -> None:
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert add(**tool_call["function"]["arguments"]) == 1065
+
+
+def test_model_field_roundtrip() -> None:
+ """Verify `model` is preserved through pydantic dump/validate
round-trip."""
+ setup = OpenAIChatModelSetup(connection="conn", model="test-model")
+ restored = OpenAIChatModelSetup.model_validate(setup.model_dump())
+ assert restored.model == "test-model"
+
+
+def test_default_model_when_omitted() -> None:
+ """Verify per-integration default applies when `model` is omitted from
__init__."""
+ setup = OpenAIChatModelSetup(connection="conn")
+ assert setup.model == DEFAULT_OPENAI_MODEL
diff --git
a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
index e56fc0fd..5503185b 100644
---
a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py
@@ -133,6 +133,13 @@ def test_ollama_chat_with_tools() -> None:
assert add(**tool_call["function"]["arguments"]) == 3
+def test_model_field_roundtrip() -> None:
+ """Verify `model` is preserved through pydantic dump/validate
round-trip."""
+ setup = OllamaChatModelSetup(connection="conn", model="test-model")
+ restored = OllamaChatModelSetup.model_validate(setup.model_dump())
+ assert restored.model == "test-model"
+
+
def test_extract_think_tags() -> None:
"""Test the static method that extracts content from <think></think>
tags."""
# Test with a think tag at the beginning (most common case)
diff --git
a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
index b6be7404..fc80ac6a 100644
---
a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py
@@ -25,6 +25,7 @@ from flink_agents.api.chat_message import ChatMessage,
MessageRole
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.api.resource_context import ResourceContext
from flink_agents.integrations.chat_models.tongyi_chat_model import (
+ DEFAULT_MODEL,
TongyiChatModelConnection,
TongyiChatModelSetup,
)
@@ -175,3 +176,16 @@ def test_tongyi_chat_with_extract_reasoning(monkeypatch:
pytest.MonkeyPatch) ->
assert "reasoning" in response.extra_args
assert "philosophical perspectives" in response.extra_args["reasoning"]
assert "Hitchhiker's Guide to the Galaxy" in
response.extra_args["reasoning"]
+
+
+def test_model_field_roundtrip() -> None:
+ """Verify `model` is preserved through pydantic dump/validate
round-trip."""
+ setup = TongyiChatModelSetup(connection="conn", model="test-model")
+ restored = TongyiChatModelSetup.model_validate(setup.model_dump())
+ assert restored.model == "test-model"
+
+
+def test_default_model_when_omitted() -> None:
+ """Verify per-integration default applies when `model` is omitted from
__init__."""
+ setup = TongyiChatModelSetup(connection="conn")
+ assert setup.model == DEFAULT_MODEL
diff --git a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
index 7d37f3fc..6587a8cb 100644
--- a/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/tongyi_chat_model.py
@@ -219,12 +219,13 @@ class TongyiChatModelSetup(BaseChatModelSetup):
----------
connection : str
Name of the referenced connection. (Inherited from BaseChatModelSetup)
+ model : str
+ Model name to use. Defaults to ``DEFAULT_MODEL`` when omitted via
+ ``__init__``. (Inherited from BaseChatModelSetup)
prompt : Optional[Union[Prompt, str]
Prompt template or string for the model. (Inherited from
BaseChatModelSetup)
tools : Optional[List[str]]
List of available tools to use in the chat. (Inherited from
BaseChatModelSetup)
- model : str
- Model name to use.
temperature : float
The temperature to use for sampling.
additional_kwargs : Dict[str, Any]
@@ -234,7 +235,6 @@ class TongyiChatModelSetup(BaseChatModelSetup):
in additional_kwargs.
"""
- model: str = Field(default=DEFAULT_MODEL, description="Model name to use.")
temperature: float = Field(
default=0.7,
description="The temperature to use for sampling.",
diff --git
a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py
b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py
index 122f8ad0..5080cefb 100644
---
a/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py
+++
b/python/flink_agents/plan/tests/compatibility/python_agent_plan_compatibility_test_agent.py
@@ -70,6 +70,8 @@ class PythonAgentPlanCompatibilityTestAgent(Agent):
name="chat_model",
prompt="prompt",
tools=["add"],
+ connection="mock_connection",
+ model="mock-model",
)
@tool
diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json
b/python/flink_agents/plan/tests/resources/agent_plan.json
index ad0a5875..9f9a3f41 100644
--- a/python/flink_agents/plan/tests/resources/agent_plan.json
+++ b/python/flink_agents/plan/tests/resources/agent_plan.json
@@ -95,7 +95,8 @@
"arguments": {
"host": "8.8.8.8",
"desc": "mock resource just for testing.",
- "connection": "mock"
+ "connection": "mock",
+ "model": "mock-model"
}
},
"__resource_provider_type__": "PythonResourceProvider"
diff --git a/python/flink_agents/plan/tests/test_agent_plan.py
b/python/flink_agents/plan/tests/test_agent_plan.py
index a58289ea..001b1d4d 100644
--- a/python/flink_agents/plan/tests/test_agent_plan.py
+++ b/python/flink_agents/plan/tests/test_agent_plan.py
@@ -205,6 +205,7 @@ class MyAgent(Agent):
host="8.8.8.8",
desc="mock resource just for testing.",
connection="mock",
+ model="mock-model",
)
@embedding_model_connection
@@ -298,6 +299,7 @@ def test_add_action_and_resource_to_agent() -> None:
host="8.8.8.8",
desc="mock resource just for testing.",
connection="mock",
+ model="mock-model",
),
)
diff --git a/python/flink_agents/runtime/java/java_chat_model.py
b/python/flink_agents/runtime/java/java_chat_model.py
index 28ca408d..7160d531 100644
--- a/python/flink_agents/runtime/java/java_chat_model.py
+++ b/python/flink_agents/runtime/java/java_chat_model.py
@@ -106,9 +106,10 @@ class JavaChatModelSetupImpl(JavaChatModelSetup):
j_resource_adapter: The Java resource adapter for method invocation
**kwargs: Additional keyword arguments
"""
- # connection is a required parameter for BaseChatModelSetup
+ # connection and model are required parameters for BaseChatModelSetup
connection = kwargs.pop("connection", "")
- super().__init__(connection=connection, **kwargs)
+ model = kwargs.pop("model", "")
+ super().__init__(connection=connection, model=model, **kwargs)
self._j_resource = j_resource
self._j_resource_adapter = j_resource_adapter
diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py
b/python/flink_agents/runtime/tests/test_built_in_actions.py
index 09178dd3..05a91881 100644
--- a/python/flink_agents/runtime/tests/test_built_in_actions.py
+++ b/python/flink_agents/runtime/tests/test_built_in_actions.py
@@ -144,6 +144,7 @@ class MyAgent(Agent):
return ResourceDescriptor(
clazz=f"{MockChatModel.__module__}.{MockChatModel.__name__}",
connection="mock_connection",
+ model="mock-model",
prompt="prompt",
tools=["add"],
)
diff --git a/python/flink_agents/runtime/tests/test_get_resource_in_action.py
b/python/flink_agents/runtime/tests/test_get_resource_in_action.py
index 1953645d..f1f1e8e1 100644
--- a/python/flink_agents/runtime/tests/test_get_resource_in_action.py
+++ b/python/flink_agents/runtime/tests/test_get_resource_in_action.py
@@ -54,6 +54,7 @@ class MyAgent(Agent):
host="8.8.8.8",
desc="mock chat model just for testing.",
connection="mock",
+ model="mock-model",
)
@tool