This is an automated email from the ASF dual-hosted git repository.
xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new c477ce8a [Feature] [Integration][Python] Add built-in support for
Azure OpenAI Chat Model (#478)
c477ce8a is described below
commit c477ce8a2f6be2dda3e638b6124cd59fb5261731
Author: Alan Z. <[email protected]>
AuthorDate: Mon Jan 26 19:04:43 2026 -0800
[Feature] [Integration][Python] Add built-in support for Azure OpenAI Chat
Model (#478)
---
.../flink/agents/api/resource/ResourceName.java | 6 +
docs/content/docs/development/chat_models.md | 103 ++++++++
python/flink_agents/api/resource.py | 27 ++-
.../chat_model_integration_agent.py | 24 ++
.../chat_model_integration_test.py | 11 +
.../integrations/chat_models/azure/__init__.py | 17 ++
.../chat_models/azure/azure_openai_chat_model.py | 260 +++++++++++++++++++++
.../chat_models/azure/tests/__init__.py | 17 ++
.../azure/tests/test_azure_openai_chat_model.py | 108 +++++++++
9 files changed, 562 insertions(+), 11 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
index f542570c..39da35aa 100644
--- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
+++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
@@ -86,6 +86,12 @@ public final class ResourceName {
public static final String ANTHROPIC_SETUP =
"flink_agents.integrations.chat_models.anthropic.anthropic_chat_model.AnthropicChatModelSetup";
+ // Azure OpenAI
+ public static final String AZURE_OPENAI_CONNECTION =
+
"flink_agents.integrations.chat_models.azure.azure_openai_chat_model.AzureOpenAIChatModelConnection";
+ public static final String AZURE_OPENAI_SETUP =
+
"flink_agents.integrations.chat_models.azure.azure_openai_chat_model.AzureOpenAIChatModelSetup";
+
// Ollama
public static final String OLLAMA_CONNECTION =
"flink_agents.integrations.chat_models.ollama_chat_model.OllamaChatModelConnection";
diff --git a/docs/content/docs/development/chat_models.md
b/docs/content/docs/development/chat_models.md
index 1bcbcd7f..c7de39f3 100644
--- a/docs/content/docs/development/chat_models.md
+++ b/docs/content/docs/development/chat_models.md
@@ -679,6 +679,109 @@ Some popular options include:
Model availability and specifications may change. Always check the official
OpenAI documentation for the latest information before implementing in
production.
{{< /hint >}}
+### OpenAI (Azure)
+
+OpenAI (Azure) provides access to OpenAI models through Azure's cloud
infrastructure, using the same OpenAI SDK with Azure-specific authentication
and endpoints. This offers enterprise security, compliance, and regional
availability while using familiar OpenAI APIs.
+
+{{< hint info >}}
+OpenAI (Azure) is only supported in Python currently. To use OpenAI (Azure)
from Java agents, see [Using Cross-Language
Providers](#using-cross-language-providers).
+{{< /hint >}}
+
+#### Prerequisites
+
+1. Create an Azure OpenAI resource in the [Azure
Portal](https://portal.azure.com/)
+2. Deploy a model in [Azure OpenAI Studio](https://oai.azure.com/)
+3. Obtain your endpoint URL, API key, API version, and deployment name from
the Azure portal
+
+#### AzureOpenAIChatModelConnection Parameters
+
+{{< tabs "AzureOpenAIChatModelConnection Parameters" >}}
+
+{{< tab "Python" >}}
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `api_key` | str | Required | Azure OpenAI API key for authentication |
+| `api_version` | str | Required | Azure OpenAI REST API version (e.g.,
"2024-02-15-preview"). See [API
versions](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning)
|
+| `azure_endpoint` | str | Required | Azure OpenAI endpoint URL (e.g.,
`https://{resource-name}.openai.azure.com`) |
+| `timeout` | float | `60.0` | API request timeout in seconds |
+| `max_retries` | int | `3` | Maximum number of API retry attempts |
+
+{{< /tab >}}
+
+{{< /tabs >}}
+
+#### AzureOpenAIChatModelSetup Parameters
+
+{{< tabs "AzureOpenAIChatModelSetup Parameters" >}}
+
+{{< tab "Python" >}}
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `connection` | str | Required | Reference to connection method name |
+| `model` | str | Required | Azure deployment name of the OpenAI model; it is sent as the `model` parameter of each chat completions request |
+| `model_of_azure_deployment` | str | None | The underlying model name (e.g., 'gpt-4', 'gpt-35-turbo'). Token metrics are only recorded when this is set |
+| `prompt` | Prompt \| str | None | Prompt template or reference to prompt
resource |
+| `tools` | List[str] | None | List of tool names available to the model |
+| `temperature` | float | None | Sampling temperature (0.0 to 2.0). Not
supported by reasoning models |
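+| `max_tokens` | int | None | Maximum number of tokens to generate. Reasoning models (e.g. gpt-5, o-series) expect `max_completion_tokens` instead, which can be passed via `additional_kwargs` |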
+| `logprobs` | bool | `False` | Whether to return log probabilities of output
tokens |
+| `additional_kwargs` | dict | `{}` | Additional Azure OpenAI API parameters |
+
+{{< /tab >}}
+
+{{< /tabs >}}
+
+#### Usage Example
+
+{{< tabs "OpenAI (Azure) Usage Example" >}}
+
+{{< tab "Python" >}}
+```python
+class MyAgent(Agent):
+
+ @chat_model_connection
+ @staticmethod
+ def azure_openai_connection() -> ResourceDescriptor:
+ return ResourceDescriptor(
+ clazz=ResourceName.ChatModel.AZURE_OPENAI_CONNECTION,
+ api_key="<your-api-key>",
+ api_version="2024-02-15-preview",
+ azure_endpoint="https://your-resource.openai.azure.com"
+ )
+
+ @chat_model_setup
+ @staticmethod
+ def azure_openai_chat_model() -> ResourceDescriptor:
+ return ResourceDescriptor(
+ clazz=ResourceName.ChatModel.AZURE_OPENAI_SETUP,
+ connection="azure_openai_connection",
+ model="my-gpt4-deployment", # Your Azure deployment name
+ model_of_azure_deployment="gpt-4", # Underlying model for metrics
+ max_tokens=1000
+ )
+
+ ...
+```
+{{< /tab >}}
+
+{{< /tabs >}}
+
+#### Available Models
+
+OpenAI (Azure) supports OpenAI models deployed through your Azure
subscription. Visit the [Azure OpenAI Models
documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models)
for the complete and up-to-date list of available models.
+
+Some popular options include:
+- **GPT-4o** (gpt-4o)
+- **GPT-4** (gpt-4)
+- **GPT-4 Turbo** (gpt-4-turbo)
+- **GPT-3.5 Turbo** (gpt-35-turbo)
+
+{{< hint warning >}}
+Model availability depends on your Azure region and subscription. Always check
the official Azure OpenAI documentation for regional availability before
implementing in production.
+{{< /hint >}}
+
### Tongyi (DashScope)
Tongyi provides cloud-based chat models from Alibaba Cloud, offering powerful
Chinese and English language capabilities.
diff --git a/python/flink_agents/api/resource.py
b/python/flink_agents/api/resource.py
index f7f7aee8..de5b2127 100644
--- a/python/flink_agents/api/resource.py
+++ b/python/flink_agents/api/resource.py
@@ -120,14 +120,14 @@ class ResourceDescriptor(BaseModel):
arguments: Dict[str, Any]
def __init__(
- self,
- /,
- *,
- clazz: str | None = None,
- target_module: str | None = None,
- target_clazz: str | None = None,
- arguments: Dict[str, Any] | None = None,
- **kwargs: Any,
+ self,
+ /,
+ *,
+ clazz: str | None = None,
+ target_module: str | None = None,
+ target_clazz: str | None = None,
+ arguments: Dict[str, Any] | None = None,
+ **kwargs: Any,
) -> None:
"""Initialize ResourceDescriptor.
@@ -182,9 +182,9 @@ class ResourceDescriptor(BaseModel):
if not isinstance(other, ResourceDescriptor):
return False
return (
- self.target_module == other.target_module
- and self.target_clazz == other.target_clazz
- and self.arguments == other.arguments
+ self.target_module == other.target_module
+ and self.target_clazz == other.target_clazz
+ and self.arguments == other.arguments
)
def __hash__(self) -> int:
@@ -211,6 +211,7 @@ def get_resource_class(module_path: str, class_name: str)
-> Type[Resource]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
+
class ResourceName:
"""Hierarchical resource class names for pointing a resource
implementation in
ResourceDescriptor.
@@ -235,6 +236,10 @@ class ResourceName:
ANTHROPIC_CONNECTION =
"flink_agents.integrations.chat_models.anthropic.anthropic_chat_model.AnthropicChatModelConnection"
ANTHROPIC_SETUP =
"flink_agents.integrations.chat_models.anthropic.anthropic_chat_model.AnthropicChatModelSetup"
+ # Azure OpenAI
+ AZURE_OPENAI_CONNECTION =
"flink_agents.integrations.chat_models.azure.azure_openai_chat_model.AzureOpenAIChatModelConnection"
+ AZURE_OPENAI_SETUP =
"flink_agents.integrations.chat_models.azure.azure_openai_chat_model.AzureOpenAIChatModelSetup"
+
# Ollama
OLLAMA_CONNECTION =
"flink_agents.integrations.chat_models.ollama_chat_model.OllamaChatModelConnection"
OLLAMA_SETUP =
"flink_agents.integrations.chat_models.ollama_chat_model.OllamaChatModelSetup"
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py
b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py
index 427f642e..2838bbf5 100644
---
a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py
+++
b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_agent.py
@@ -45,6 +45,17 @@ class ChatModelTestAgent(Agent):
clazz=ResourceName.ChatModel.OPENAI_CONNECTION,
api_key=os.environ.get("OPENAI_API_KEY")
)
+    @chat_model_connection
+    @staticmethod
+    def azure_openai_connection() -> ResourceDescriptor:
+        """ChatModelConnection responsible for Azure OpenAI model service connection.
+
+        The API key, API version and endpoint are read from the
+        ``AZURE_OPENAI_API_KEY``, ``AZURE_OPENAI_API_VERSION`` and
+        ``AZURE_OPENAI_ENDPOINT`` environment variables respectively.
+        """
+        return ResourceDescriptor(
+            clazz=ResourceName.ChatModel.AZURE_OPENAI_CONNECTION,
+            api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
+            api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
+            azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
+        )
+
@chat_model_connection
@staticmethod
def tongyi_connection() -> ResourceDescriptor:
@@ -86,6 +97,13 @@ class ChatModelTestAgent(Agent):
model=os.environ.get("OPENAI_CHAT_MODEL", "gpt-3.5-turbo"),
tools=["add"],
)
+ elif model_provider == "AzureOpenAI":
+ return ResourceDescriptor(
+ clazz=ResourceName.ChatModel.AZURE_OPENAI_SETUP,
+ connection="azure_openai_connection",
+ model=os.environ.get("AZURE_OPENAI_CHAT_MODEL", "gpt-5"),
+ tools=["add"],
+ )
else:
err_msg = f"Unknown model_provider {model_provider}"
raise RuntimeError(err_msg)
@@ -114,6 +132,12 @@ class ChatModelTestAgent(Agent):
connection="openai_connection",
model=os.environ.get("OPENAI_CHAT_MODEL", "gpt-3.5-turbo"),
)
+ elif model_provider == "AzureOpenAI":
+ return ResourceDescriptor(
+ clazz=ResourceName.ChatModel.AZURE_OPENAI_SETUP,
+ connection="azure_openai_connection",
+ model=os.environ.get("AZURE_OPENAI_CHAT_MODEL", "gpt-5"),
+ )
else:
err_msg = f"Unknown model_provider {model_provider}"
raise RuntimeError(err_msg)
diff --git
a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py
b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py
index 93dd48fd..f53b593e 100644
---
a/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py
+++
b/python/flink_agents/e2e_tests/e2e_tests_integration/chat_model_integration_test.py
@@ -34,9 +34,14 @@ OLLAMA_MODEL = os.environ.get("OLLAMA_CHAT_MODEL",
"qwen3:1.7b")
os.environ["OLLAMA_CHAT_MODEL"] = OLLAMA_MODEL
OPENAI_MODEL = os.environ.get("OPENAI_CHAT_MODEL", "gpt-3.5-turbo")
os.environ["OPENAI_CHAT_MODEL"] = OPENAI_MODEL
+# Defaults are written back into os.environ because the agent under test reads
+# AZURE_OPENAI_CHAT_MODEL / AZURE_OPENAI_API_VERSION from the environment itself.
+AZURE_OPENAI_MODEL = os.environ.get("AZURE_OPENAI_CHAT_MODEL", "gpt-5")
+os.environ["AZURE_OPENAI_CHAT_MODEL"] = AZURE_OPENAI_MODEL
+AZURE_OPENAI_API_VERSION = os.environ.get(
+    "AZURE_OPENAI_API_VERSION", "2025-04-01-preview"
+)
+os.environ["AZURE_OPENAI_API_VERSION"] = AZURE_OPENAI_API_VERSION
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
client = pull_model(OLLAMA_MODEL)
@@ -63,6 +68,12 @@ client = pull_model(OLLAMA_MODEL)
OPENAI_API_KEY is None, reason="OpenAI api key is not set."
),
),
+ pytest.param(
+ "AzureOpenAI",
+ marks=pytest.mark.skipif(
+ AZURE_OPENAI_API_KEY is None, reason="Azure OpenAI api key is
not set."
+ ),
+ ),
],
)
def test_chat_model_integration(model_provider: str) -> None: # noqa: D103
diff --git a/python/flink_agents/integrations/chat_models/azure/__init__.py
b/python/flink_agents/integrations/chat_models/azure/__init__.py
new file mode 100644
index 00000000..e154fadd
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/azure/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git
a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
new file mode 100644
index 00000000..74a8fe2e
--- /dev/null
+++
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
@@ -0,0 +1,260 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+from typing import Any, Dict, List, Sequence
+
+from openai import NOT_GIVEN, AzureOpenAI
+from pydantic import Field, PrivateAttr
+
+from flink_agents.api.chat_message import ChatMessage
+from flink_agents.api.chat_models.chat_model import (
+ BaseChatModelConnection,
+ BaseChatModelSetup,
+)
+from flink_agents.api.tools.tool import Tool
+from flink_agents.integrations.chat_models.chat_model_utils import
to_openai_tool
+from flink_agents.integrations.chat_models.openai.openai_utils import (
+ convert_from_openai_message,
+ convert_to_openai_messages,
+)
+
+
+class AzureOpenAIChatModelConnection(BaseChatModelConnection):
+    """The connection to the Azure OpenAI LLM.
+
+    Attributes:
+    ----------
+    api_key : str | None
+        The Azure OpenAI API key.
+    api_version : str | None
+        Azure OpenAI REST API version to use.
+        See more: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
+    azure_endpoint : str | None
+        Supported Azure OpenAI endpoints. Example: https://{your-resource-name}.openai.azure.com
+    timeout : float
+        The number of seconds to wait for an API call before it times out.
+    max_retries : int
+        The number of times to retry the API call upon failure.
+
+    ``api_key``, ``api_version`` and ``azure_endpoint`` may be left as
+    ``None``; they are then handed to the ``AzureOpenAI`` client unchanged and
+    resolution is left to the openai SDK (which, per its documentation, falls
+    back to environment variables -- confirm for the pinned SDK version).
+    """
+
+    # Annotated ``str | None`` rather than ``str``: ``__init__`` forwards an
+    # explicit ``None`` whenever the caller omits a value, and pydantic (v2)
+    # rejects an explicit ``None`` for a plain ``str`` field.
+    api_key: str | None = Field(default=None, description="The Azure OpenAI API key.")
+    api_version: str | None = Field(
+        default=None,
+        description="Azure OpenAI REST API version to use.",
+    )
+    azure_endpoint: str | None = Field(
+        default=None,
+        description="Supported Azure OpenAI endpoints. Example: https://{your-resource-name}.openai.azure.com",
+    )
+    timeout: float = Field(
+        default=60.0,
+        description="The number of seconds to wait for an API call before it times out.",
+        ge=0,
+    )
+    max_retries: int = Field(
+        default=3,
+        description="The number of times to retry the API call upon failure.",
+        ge=0,
+    )
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        api_version: str | None = None,
+        azure_endpoint: str | None = None,
+        timeout: float = 60.0,
+        max_retries: int = 3,
+        **kwargs: Any,
+    ) -> None:
+        """Init method.
+
+        All settings are keyword-only; any extra ``kwargs`` are forwarded to
+        ``BaseChatModelConnection``.
+        """
+        super().__init__(
+            api_key=api_key,
+            api_version=api_version,
+            azure_endpoint=azure_endpoint,
+            timeout=timeout,
+            max_retries=max_retries,
+            **kwargs,
+        )
+
+    # Created lazily by the ``client`` property and reused afterwards.
+    _client: AzureOpenAI | None = PrivateAttr(default=None)
+
+    @property
+    def client(self) -> AzureOpenAI:
+        """Get Azure OpenAI client, constructing it on first access."""
+        if self._client is None:
+            self._client = AzureOpenAI(
+                azure_endpoint=self.azure_endpoint,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                timeout=self.timeout,
+                max_retries=self.max_retries,
+            )
+        return self._client
+
+    def chat(
+        self,
+        messages: Sequence[ChatMessage],
+        tools: List[Tool] | None = None,
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Direct communication with model service for chat conversation.
+
+        Parameters
+        ----------
+        messages : Sequence[ChatMessage]
+            Input message sequence
+        tools : List[Tool] | None
+            List of tools that can be called by the model
+        **kwargs : Any
+            Additional parameters passed to the model service (e.g., temperature,
+            max_tokens, etc.). Must contain ``model`` (the Azure deployment
+            name). May contain ``model_of_azure_deployment``, which is consumed
+            here for token metrics and is not sent to the API.
+
+        Returns:
+        -------
+        ChatMessage
+            Model response message
+
+        Raises:
+        ------
+        ValueError
+            If ``model`` is missing or empty.
+        """
+        tool_specs = None
+        if tools is not None:
+            tool_specs = [to_openai_tool(metadata=tool.metadata) for tool in tools]
+
+        # Extract model (azure_deployment) and model_of_azure_deployment from
+        # kwargs so that neither is forwarded to the API as an extra parameter.
+        azure_deployment = kwargs.pop("model", "")
+        if not azure_deployment:
+            msg = "model is required for Azure OpenAI API calls"
+            raise ValueError(msg)
+        model_of_azure_deployment = kwargs.pop("model_of_azure_deployment", None)
+
+        response = self.client.chat.completions.create(
+            # Azure OpenAI APIs use Azure deployment name as the model parameter
+            model=azure_deployment,
+            messages=convert_to_openai_messages(messages),
+            # No tools (None or empty list) -> omit the parameter entirely.
+            tools=tool_specs or NOT_GIVEN,
+            **kwargs,
+        )
+
+        extra_args = {}
+        # Record token metrics only if model_of_azure_deployment is provided:
+        # the deployment name alone does not identify the underlying model.
+        if model_of_azure_deployment and response.usage:
+            extra_args["model_name"] = model_of_azure_deployment
+            extra_args["promptTokens"] = response.usage.prompt_tokens
+            extra_args["completionTokens"] = response.usage.completion_tokens
+
+        message = response.choices[0].message
+
+        return convert_from_openai_message(message, extra_args)
+
+
+class AzureOpenAIChatModelSetup(BaseChatModelSetup):
+    """The settings for the Azure OpenAI LLM.
+
+    Attributes:
+    ----------
+    connection : str
+        Name of the referenced connection. (Inherited from BaseChatModelSetup)
+    prompt : Optional[Union[Prompt, str]]
+        Prompt template or string for the model. (Inherited from BaseChatModelSetup)
+    tools : Optional[List[str]]
+        List of available tools to use in the chat. (Inherited from BaseChatModelSetup)
+    model : str
+        Name of OpenAI model deployment on Azure. Emitted under the ``model``
+        key of ``model_kwargs``.
+    model_of_azure_deployment : Optional[str]
+        The underlying model name of the Azure deployment (e.g., 'gpt-4').
+        Used for token counting and cost calculation.
+    temperature : Optional[float]
+        What sampling temperature to use, between 0 and 2. Higher values like 0.8
+        will make the output more random, while lower values like 0.2 will make it
+        more focused and deterministic.
+        Not supported by reasoning models (e.g. gpt-5, o-series). Only included
+        in ``model_kwargs`` when set.
+    max_tokens : Optional[int]
+        The maximum number of tokens that can be generated in the chat completion.
+        The total length of input tokens and generated tokens is limited by the
+        model's context length. Only included in ``model_kwargs`` when set.
+        NOTE: per the Azure OpenAI docs, reasoning models expect
+        ``max_completion_tokens`` instead; pass that via ``additional_kwargs``.
+    logprobs : Optional[bool]
+        Whether to return log probabilities of the output tokens or not. If true,
+        returns the log probabilities of each output token returned in the content
+        of message.
+    additional_kwargs : Dict[str, Any]
+        Additional kwargs for the Azure OpenAI API. Merged last into
+        ``model_kwargs``, so its entries override the keys built from the
+        fields above.
+    """
+
+    model: str = Field(
+        description="Name of OpenAI model deployment on Azure.",
+    )
+    model_of_azure_deployment: str | None = Field(
+        default=None,
+        description="The underlying model name of the Azure deployment (e.g., 'gpt-4', "
+        "'gpt-35-turbo'). Used for token counting and cost calculation. "
+        "Required for token metrics tracking.",
+    )
+    temperature: float | None = Field(
+        default=None,
+        description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output "
+        "more random, while lower values like 0.2 will make it more focused and deterministic. "
+        "Not supported by reasoning models (e.g. gpt-5, o-series).",
+        ge=0.0,
+        le=2.0,
+    )
+    max_tokens: int | None = Field(
+        default=None,
+        description="The maximum number of tokens that can be generated in the chat completion. The total length of "
+        "input tokens and generated tokens is limited by the model's context length.",
+        gt=0,
+    )
+    logprobs: bool | None = Field(
+        description="Whether to return log probabilities of the output tokens or not. If true, returns the log "
+        "probabilities of each output token returned in the content of message.",
+        default=False,
+    )
+    additional_kwargs: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional kwargs for the Azure OpenAI API."
+    )
+
+    def __init__(
+        self,
+        *,
+        model: str,
+        model_of_azure_deployment: str | None = None,
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        logprobs: bool | None = False,
+        additional_kwargs: Dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Init method.
+
+        ``additional_kwargs=None`` is normalized to an empty dict; remaining
+        ``kwargs`` are forwarded to ``BaseChatModelSetup``.
+        """
+        additional_kwargs = additional_kwargs or {}
+        super().__init__(
+            model=model,
+            model_of_azure_deployment=model_of_azure_deployment,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=logprobs,
+            additional_kwargs=additional_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def model_kwargs(self) -> Dict[str, Any]:
+        """Return chat model settings.
+
+        ``model``, ``model_of_azure_deployment`` and ``logprobs`` are always
+        present; ``temperature`` and ``max_tokens`` only when not ``None``.
+        """
+        base_kwargs = {
+            "model": self.model,
+            "model_of_azure_deployment": self.model_of_azure_deployment,
+            "logprobs": self.logprobs,
+        }
+        # Optional sampling limits are omitted when unset so models that do not
+        # accept them (e.g. reasoning models for temperature) are not sent them.
+        if self.temperature is not None:
+            base_kwargs["temperature"] = self.temperature
+        if self.max_tokens is not None:
+            base_kwargs["max_tokens"] = self.max_tokens
+
+        # additional_kwargs wins on key collisions.
+        all_kwargs = {**base_kwargs, **self.additional_kwargs}
+        return all_kwargs
diff --git
a/python/flink_agents/integrations/chat_models/azure/tests/__init__.py
b/python/flink_agents/integrations/chat_models/azure/tests/__init__.py
new file mode 100644
index 00000000..e154fadd
--- /dev/null
+++ b/python/flink_agents/integrations/chat_models/azure/tests/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
diff --git
a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
new file mode 100644
index 00000000..d8a1e9f0
--- /dev/null
+++
b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
@@ -0,0 +1,108 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import os
+
+import pytest
+
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.resource import Resource, ResourceType
+from flink_agents.integrations.chat_models.azure.azure_openai_chat_model
import (
+ AzureOpenAIChatModelConnection,
+ AzureOpenAIChatModelSetup,
+)
+from flink_agents.plan.tools.function_tool import from_callable
+
+# Live-service settings for the tests below, taken from the environment
+# (each is None when the corresponding variable is unset).
+test_deployment = os.getenv("TEST_AZURE_DEPLOYMENT")
+api_key = os.getenv("AZURE_OPENAI_API_KEY")
+azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+api_version = os.getenv("AZURE_OPENAI_API_VERSION")
+
+
+@pytest.mark.skipif(
+    api_key is None
+    or azure_endpoint is None
+    or api_version is None
+    or test_deployment is None,
+    reason="AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION "
+    "or TEST_AZURE_DEPLOYMENT is not set",
+)
+def test_azure_openai_chat_model() -> None:  # noqa: D103
+    connection = AzureOpenAIChatModelConnection(
+        name="azure_openai",
+        api_key=api_key,
+        azure_endpoint=azure_endpoint,
+        api_version=api_version,
+    )
+
+    def get_resource(name: str, type: ResourceType) -> Resource:
+        if type == ResourceType.CHAT_MODEL_CONNECTION:
+            return connection
+        # This setup declares no tools or prompt, so any other lookup is a bug.
+        # Fail loudly instead of re-dispatching to get_resource, which would
+        # land in this same branch again and recurse forever.
+        msg = f"Unexpected resource lookup: {name!r} ({type})"
+        raise ValueError(msg)
+
+    chat_model = AzureOpenAIChatModelSetup(
+        name="azure_openai",
+        model=test_deployment,
+        connection="azure_openai",
+        get_resource=get_resource,
+    )
+    response = chat_model.chat([ChatMessage(role=MessageRole.USER, content="Hello!")])
+    assert response is not None
+    assert str(response).strip() != ""
+
+
+def add(a: int, b: int) -> int:
+    """Calculate the sum of a and b.
+
+    Parameters
+    ----------
+    a : int
+        The first operand
+    b : int
+        The second operand
+
+    Returns:
+    -------
+    int:
+        The sum of a and b
+    """
+    # The docstring above is kept verbatim on purpose: this function is
+    # registered as a tool via from_callable, which presumably derives the
+    # tool description from it -- confirm before rewording.
+    total = a + b
+    return total
+
+
+@pytest.mark.skipif(
+    api_key is None
+    or azure_endpoint is None
+    or api_version is None
+    or test_deployment is None,
+    reason="AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_VERSION "
+    "or TEST_AZURE_DEPLOYMENT is not set",
+)
+def test_azure_openai_chat_with_tools() -> None:  # noqa: D103
+    connection = AzureOpenAIChatModelConnection(
+        name="azure_openai",
+        api_key=api_key,
+        azure_endpoint=azure_endpoint,
+        api_version=api_version,
+    )
+
+    def get_resource(name: str, type: ResourceType) -> Resource:
+        if type == ResourceType.CHAT_MODEL_CONNECTION:
+            return connection
+        # The only other resource this setup references is the "add" tool.
+        return from_callable(func=add)
+
+    chat_model = AzureOpenAIChatModelSetup(
+        name="azure_openai",
+        model=test_deployment,
+        connection="azure_openai",
+        tools=["add"],
+        get_resource=get_resource,
+    )
+    response = chat_model.chat(
+        [
+            ChatMessage(
+                role=MessageRole.USER,
+                content="You MUST use the add tool to calculate: What is 377 + 688?",
+            )
+        ]
+    )
+    tool_calls = response.tool_calls
+    assert tool_calls, "expected the model to call the add tool"
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert add(**tool_call["function"]["arguments"]) == 1065