This is an automated email from the ASF dual-hosted git repository.
xintongsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 99dff18a [Feature][Integration][Java] Add built-in support for Azure
OpenAI Chat Model (#695)
99dff18a is described below
commit 99dff18a00fccfe3b56cdbd54055ba578398c15c
Author: Alan Z. <[email protected]>
AuthorDate: Wed May 27 05:46:36 2026 -0700
[Feature][Integration][Java] Add built-in support for Azure OpenAI Chat
Model (#695)
---
.../flink/agents/api/resource/ResourceName.java | 6 +
docs/content/docs/development/chat_models.md | 61 ++++-
.../test/ChatModelIntegrationAgent.java | 18 ++
.../integration/test/ChatModelIntegrationTest.java | 13 +-
integrations/chat-models/openai/pom.xml | 6 +
.../openai/AzureOpenAIChatModelConnection.java | 294 +++++++++++++++++++++
.../openai/AzureOpenAIChatModelSetup.java | 121 +++++++++
.../openai/OpenAIChatCompletionsUtils.java | 229 ++++++++++++++++
.../openai/OpenAICompletionsConnection.java | 191 +------------
.../openai/AzureOpenAIChatModelConnectionTest.java | 131 +++++++++
.../openai/AzureOpenAIChatModelSetupTest.java | 145 ++++++++++
python/flink_agents/api/resource.py | 4 +
.../chat_models/azure/azure_openai_chat_model.py | 30 ++-
.../azure/tests/test_azure_openai_chat_model.py | 39 +++
14 files changed, 1091 insertions(+), 197 deletions(-)
diff --git
a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
index d6dfa697..17cbe177 100644
--- a/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
+++ b/api/src/main/java/org/apache/flink/agents/api/resource/ResourceName.java
@@ -83,6 +83,12 @@ public final class ResourceName {
public static final String OPENAI_RESPONSES_SETUP =
"org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelSetup";
+ // Azure OpenAI
+ public static final String AZURE_OPENAI_CONNECTION =
+
"org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelConnection";
+ public static final String AZURE_OPENAI_SETUP =
+
"org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelSetup";
+
// Python Wrapper
public static final String PYTHON_WRAPPER_CONNECTION =
"org.apache.flink.agents.api.chat.model.python.PythonChatModelConnection";
diff --git a/docs/content/docs/development/chat_models.md
b/docs/content/docs/development/chat_models.md
index 5e619ac4..0fc50545 100644
--- a/docs/content/docs/development/chat_models.md
+++ b/docs/content/docs/development/chat_models.md
@@ -487,10 +487,6 @@ Model availability and specifications may change. Always
check the official Azur
Azure OpenAI provides access to OpenAI models (GPT-4, GPT-4o, etc.) through
Azure's cloud infrastructure, using the same OpenAI SDK with Azure-specific
authentication and endpoints. This offers enterprise security, compliance, and
regional availability while using familiar OpenAI APIs.
-{{< hint info >}}
-Azure OpenAI is only supported in Python currently. To use Azure OpenAI from
Java agents, see [Using Cross-Language
Providers](#using-cross-language-providers).
-{{< /hint >}}
-
{{< hint warning >}}
**Azure OpenAI vs Azure AI:** Azure OpenAI uses the OpenAI SDK to access
OpenAI models (GPT-4, etc.) hosted on Azure. If you want to use other models
like Llama, Mistral, or Phi deployed via Azure AI Studio, see [Azure
AI](#azure-ai) instead.
{{< /hint >}}
@@ -517,6 +513,19 @@ Azure OpenAI is only supported in Python currently. To use
Azure OpenAI from Jav
{{< /tab >}}
+{{< tab "Java" >}}
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `api_key` | String | Required | Azure OpenAI API key for authentication |
+| `api_version` | String | Required | Azure OpenAI REST API version (e.g., "2024-02-01"). See [API versions](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning) |
+| `azure_endpoint` | String | Required | Azure OpenAI endpoint URL (e.g., `https://{resource-name}.openai.azure.com`) — either a direct Azure resource or a proxy/gateway URL that fronts an Azure OpenAI service |
+| `timeout` | int | None | Timeout in seconds for API requests; must be greater than 0, otherwise ignored (SDK default applies) |
+| `max_retries` | int | None | Maximum number of API retry attempts; must be non-negative, otherwise ignored (SDK default applies) |
+| `azure_url_path_mode` | String | `"AUTO"` | Controls how the SDK constructs Azure OpenAI request URLs. One of `"AUTO"`, `"LEGACY"`, or `"UNIFIED"`. Custom gateways that proxy Azure OpenAI typically need `"LEGACY"` to force the `/openai/deployments/{model}` path |
+
+{{< /tab >}}
+
{{< /tabs >}}
#### AzureOpenAIChatModelSetup Parameters
@@ -539,6 +548,22 @@ Azure OpenAI is only supported in Python currently. To use
Azure OpenAI from Jav
{{< /tab >}}
+{{< tab "Java" >}}
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `connection` | String | Required | Reference to connection method name |
+| `model` | String | Required | Azure deployment name (not the underlying OpenAI model name) |
+| `model_of_azure_deployment` | String | None | The underlying model name (e.g., 'gpt-4', 'gpt-4o'). Used solely for token metrics tracking |
+| `prompt` | Prompt \| String | None | Prompt template or reference to prompt resource |
+| `tools` | List<String> | None | List of tool names available to the model |
+| `temperature` | double | None | Sampling temperature (0.0 to 2.0). Not supported by reasoning models |
+| `max_tokens` | int | None | Maximum number of tokens to generate (must be greater than 0) |
+| `logprobs` | boolean | `false` | Whether to return log probabilities of output tokens |
+| `additional_kwargs` | Map<String, Object> | `{}` | Additional Azure OpenAI API parameters (forwarded to the OpenAI request body). Keys that duplicate the typed fields above (`model`, `model_of_azure_deployment`, `temperature`, `max_tokens`, `logprobs`) are rejected |
+
+{{< /tab >}}
+
{{< /tabs >}}
#### Usage Example
@@ -574,6 +599,34 @@ class MyAgent(Agent):
```
{{< /tab >}}
+{{< tab "Java" >}}
+```java
+public class MyAgent extends Agent {
+ @ChatModelConnection
+ public static ResourceDescriptor azureOpenAIConnection() {
+ return
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_CONNECTION)
+ .addInitialArgument("api_key", "<your-api-key>")
+ .addInitialArgument("api_version", "2024-02-01")
+ .addInitialArgument("azure_endpoint",
"https://your-resource.openai.azure.com")
+ .build();
+ }
+
+ @ChatModelSetup
+ public static ResourceDescriptor azureOpenAIChatModel() {
+ return
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_SETUP)
+ .addInitialArgument("connection", "azureOpenAIConnection")
+ .addInitialArgument("model", "my-gpt4-deployment") //
Your Azure deployment name
+ .addInitialArgument("model_of_azure_deployment", "gpt-4") //
Underlying model for metrics
+ .addInitialArgument("temperature", 0.3d)
+ .addInitialArgument("max_tokens", 1000)
+ .build();
+ }
+
+ ...
+}
+```
+{{< /tab >}}
+
{{< /tabs >}}
#### Available Models
diff --git
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
index 2c56b10a..4492a8f4 100644
---
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
+++
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationAgent.java
@@ -91,6 +91,16 @@ public class ChatModelIntegrationAgent extends Agent {
ResourceName.ChatModel.OPENAI_RESPONSES_CONNECTION)
.addInitialArgument("api_key",
System.getenv().get("OPENAI_API_KEY"))
.build();
+ } else if (provider.equals("AZURE_OPENAI")) {
+ return ResourceDescriptor.Builder.newBuilder(
+ ResourceName.ChatModel.AZURE_OPENAI_CONNECTION)
+ .addInitialArgument("api_key",
System.getenv().get("AZURE_OPENAI_API_KEY"))
+ .addInitialArgument(
+ "api_version",
System.getenv().get("AZURE_OPENAI_API_VERSION"))
+ .addInitialArgument(
+ "azure_endpoint",
System.getenv().get("AZURE_OPENAI_ENDPOINT"))
+ .addInitialArgument("azure_url_path_mode", "LEGACY")
+ .build();
} else if (provider.equals("ANTHROPIC")) {
String apiKey = System.getenv().get("ANTHROPIC_API_KEY");
return ResourceDescriptor.Builder.newBuilder(
@@ -150,6 +160,14 @@ public class ChatModelIntegrationAgent extends Agent {
"tools",
List.of("calculateBMI", "convertTemperature",
"createRandomNumber"))
.build();
+ } else if (provider.equals("AZURE_OPENAI")) {
+ return
ResourceDescriptor.Builder.newBuilder(ResourceName.ChatModel.AZURE_OPENAI_SETUP)
+ .addInitialArgument("connection", "chatModelConnection")
+ .addInitialArgument("model",
System.getenv().get("AZURE_OPENAI_DEPLOYMENT"))
+ .addInitialArgument(
+ "tools",
+ List.of("calculateBMI", "convertTemperature",
"createRandomNumber"))
+ .build();
} else {
throw new RuntimeException(String.format("Unknown model provider
%s", provider));
}
diff --git
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
index 75a3d5c1..de967858 100644
---
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
+++
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
@@ -27,8 +27,6 @@ import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
@@ -41,7 +39,6 @@ import static
org.apache.flink.agents.integration.test.ChatModelIntegrationAgent
* prompts.
*/
public class ChatModelIntegrationTest extends OllamaPreparationUtils {
- private static final Logger LOG =
LoggerFactory.getLogger(ChatModelIntegrationTest.class);
private static final String API_KEY = "_API_KEY";
private static final String OLLAMA = "OLLAMA";
@@ -53,7 +50,15 @@ public class ChatModelIntegrationTest extends
OllamaPreparationUtils {
}
@ParameterizedTest()
- @ValueSource(strings = {"ANTHROPIC", "AZURE", "OLLAMA", "OPENAI",
"OPENAI_RESPONSES"})
+ @ValueSource(
+ strings = {
+ "ANTHROPIC",
+ "AZURE",
+ "AZURE_OPENAI",
+ "OLLAMA",
+ "OPENAI",
+ "OPENAI_RESPONSES"
+ })
public void testChatModeIntegration(String provider) throws Exception {
Assumptions.assumeTrue(
(OLLAMA.equals(provider) && ollamaReady)
diff --git a/integrations/chat-models/openai/pom.xml
b/integrations/chat-models/openai/pom.xml
index e1c31cb5..ba0c6ce1 100644
--- a/integrations/chat-models/openai/pom.xml
+++ b/integrations/chat-models/openai/pom.xml
@@ -43,6 +43,12 @@ under the License.
<artifactId>openai-java</artifactId>
<version>${openai.version}</version>
</dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ <version>${slf4j.version}</version>
+ </dependency>
</dependencies>
</project>
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
new file mode 100644
index 00000000..7d6b5c2c
--- /dev/null
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnection.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.azure.AzureOpenAIServiceVersion;
+import com.openai.azure.AzureUrlPathMode;
+import com.openai.azure.credential.AzureApiKeyCredential;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonValue;
+import com.openai.models.ChatModel;
+import com.openai.models.FunctionDefinition;
+import com.openai.models.FunctionParameters;
+import com.openai.models.chat.completions.ChatCompletion;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.chat.completions.ChatCompletionFunctionTool;
+import com.openai.models.chat.completions.ChatCompletionTool;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.apache.flink.agents.api.tools.Tool;
+import org.apache.flink.agents.api.tools.ToolMetadata;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Chat model integration for Azure OpenAI Service. Built on the openai-java
SDK using its built-in
+ * Azure support ({@link AzureOpenAIServiceVersion}, {@link
AzureApiKeyCredential}).
+ *
+ * <p>Required connection arguments:
+ *
+ * <ul>
+ * <li><b>api_key</b>: Azure OpenAI API key
+ * <li><b>api_version</b>: Azure OpenAI REST API version (e.g., {@code
"2024-02-01"})
+ * <li><b>azure_endpoint</b>: base URL for the Azure OpenAI deployment —
either a direct Azure
+ * resource (e.g., {@code "https://your-resource.openai.azure.com"}) or
a proxy/gateway URL
+ * that fronts an Azure OpenAI service. Custom gateway hostnames also
require setting {@code
+ * azure_url_path_mode} below.
+ * </ul>
+ *
+ * <p>Optional connection arguments:
+ *
+ * <ul>
+ * <li><b>timeout</b> (Number): seconds before an API call times out; must
be greater than 0,
+ * otherwise ignored (SDK default applies)
+ * <li><b>max_retries</b> (Number): retry attempts on failure; must be
non-negative, otherwise
+ * ignored (SDK default applies)
+ * <li><b>azure_url_path_mode</b> (String): one of {@code "AUTO"}, {@code
"LEGACY"}, or {@code
+ * "UNIFIED"} (default {@code "AUTO"}). Controls how the SDK constructs
Azure OpenAI request
+ * URLs. In {@code AUTO} mode the SDK only treats the endpoint as Azure
when its hostname
+ * matches a known suffix (e.g. {@code .openai.azure.com}); custom
gateways that proxy Azure
+ * OpenAI need {@code LEGACY} to force the {@code
/openai/deployments/{model}} path.
+ * </ul>
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @ChatModelConnection
+ * public static ResourceDescriptor azureOpenAIConnection() {
+ * return ResourceDescriptor.Builder.newBuilder(
+ * AzureOpenAIChatModelConnection.class.getName())
+ * .addInitialArgument("api_key",
System.getenv("AZURE_OPENAI_API_KEY"))
+ * .addInitialArgument("api_version", "2024-02-01")
+ * .addInitialArgument("azure_endpoint",
"https://my-resource.openai.azure.com")
+ * .build();
+ * }
+ * }</pre>
+ */
+public class AzureOpenAIChatModelConnection extends BaseChatModelConnection {
+
+ private static final ObjectMapper mapper = new ObjectMapper();
+
+ private static final Set<String> RESERVED_KWARG_KEYS =
+ Set.of("model", "model_of_azure_deployment", "temperature",
"max_tokens", "logprobs");
+
+ private final OpenAIClient client;
+
+ public AzureOpenAIChatModelConnection(
+ ResourceDescriptor descriptor, ResourceContext resourceContext) {
+ super(descriptor, resourceContext);
+
+ String apiKey = descriptor.getArgument("api_key");
+ if (apiKey == null || apiKey.isBlank()) {
+ throw new IllegalArgumentException("api_key should not be null or
empty.");
+ }
+
+ String apiVersion = descriptor.getArgument("api_version");
+ if (apiVersion == null || apiVersion.isBlank()) {
+ throw new IllegalArgumentException("api_version should not be null
or empty.");
+ }
+
+ String azureEndpoint = descriptor.getArgument("azure_endpoint");
+ if (azureEndpoint == null || azureEndpoint.isBlank()) {
+ throw new IllegalArgumentException("azure_endpoint should not be
null or empty.");
+ }
+
+ OpenAIOkHttpClient.Builder clientBuilder =
+ OpenAIOkHttpClient.builder()
+ .baseUrl(azureEndpoint)
+ .credential(AzureApiKeyCredential.create(apiKey))
+
.azureServiceVersion(AzureOpenAIServiceVersion.fromString(apiVersion));
+
+ Integer timeoutSeconds = descriptor.getArgument("timeout");
+ if (timeoutSeconds != null && timeoutSeconds > 0) {
+ clientBuilder.timeout(Duration.ofSeconds(timeoutSeconds));
+ }
+
+ Integer maxRetries = descriptor.getArgument("max_retries");
+ if (maxRetries != null && maxRetries >= 0) {
+ clientBuilder.maxRetries(maxRetries);
+ }
+
+ String azureUrlPathMode =
descriptor.getArgument("azure_url_path_mode");
+ if (azureUrlPathMode != null && !azureUrlPathMode.isBlank()) {
+ try {
+ clientBuilder.azureUrlPathMode(
+
AzureUrlPathMode.valueOf(azureUrlPathMode.trim().toUpperCase()));
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(
+ "azure_url_path_mode must be one of AUTO, LEGACY, or
UNIFIED; got: "
+ + azureUrlPathMode,
+ e);
+ }
+ }
+
+ this.client = clientBuilder.build();
+ }
+
+ @Override
+ public ChatMessage chat(
+ List<ChatMessage> messages, List<Tool> tools, Map<String, Object>
arguments) {
+ try {
+ Map<String, Object> mutableArgs =
+ arguments != null ? new HashMap<>(arguments) : new
HashMap<>();
+
+ String azureDeployment = (String) mutableArgs.remove("model");
+ if (azureDeployment == null || azureDeployment.isBlank()) {
+ throw new IllegalArgumentException("model is required for
Azure OpenAI API calls");
+ }
+ String modelOfAzureDeployment =
+ (String) mutableArgs.remove("model_of_azure_deployment");
+
+ ChatCompletionCreateParams.Builder builder =
+ ChatCompletionCreateParams.builder()
+ .model(ChatModel.of(azureDeployment))
+
.messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages));
+
+ if (tools != null && !tools.isEmpty()) {
+ builder.tools(convertTools(tools));
+ }
+
+ Object temperature = mutableArgs.remove("temperature");
+ if (temperature instanceof Number) {
+ builder.temperature(((Number) temperature).doubleValue());
+ }
+
+ Object maxTokens = mutableArgs.remove("max_tokens");
+ if (maxTokens instanceof Number) {
+ builder.maxCompletionTokens(((Number) maxTokens).longValue());
+ }
+
+ Object logprobs = mutableArgs.remove("logprobs");
+ if (Boolean.TRUE.equals(logprobs)) {
+ builder.logprobs(true);
+ }
+
+ @SuppressWarnings("unchecked")
+ Map<String, Object> additionalKwargs =
+ (Map<String, Object>)
mutableArgs.remove("additional_kwargs");
+ if (additionalKwargs != null) {
+ Set<String> collisions = new
HashSet<>(additionalKwargs.keySet());
+ collisions.retainAll(RESERVED_KWARG_KEYS);
+ if (!collisions.isEmpty()) {
+ throw new IllegalArgumentException(
+ "additional_kwargs must not contain reserved typed
fields: "
+ + collisions
+ + ". Set these via the corresponding Setup
field instead.");
+ }
+ for (Map.Entry<String, Object> entry :
additionalKwargs.entrySet()) {
+ builder.putAdditionalBodyProperty(
+ entry.getKey(), toJsonValue(entry.getValue()));
+ }
+ }
+
+ ChatCompletion completion =
client.chat().completions().create(builder.build());
+
+ ChatMessage response =
+ OpenAIChatCompletionsUtils.convertFromOpenAIMessage(
+ completion.choices().get(0).message());
+
+ if (modelOfAzureDeployment != null
+ && !modelOfAzureDeployment.isBlank()
+ && completion.usage().isPresent()) {
+ recordTokenMetrics(
+ modelOfAzureDeployment,
+ completion.usage().get().promptTokens(),
+ completion.usage().get().completionTokens());
+ }
+
+ return response;
+ } catch (IllegalArgumentException e) {
+ throw e;
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to call Azure OpenAI chat
completions API.", e);
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ this.client.close();
+ }
+
+ private List<ChatCompletionTool> convertTools(List<Tool> tools) {
+ List<ChatCompletionTool> openaiTools = new ArrayList<>(tools.size());
+ for (Tool tool : tools) {
+ ToolMetadata metadata = tool.getMetadata();
+ FunctionDefinition.Builder functionBuilder =
+ FunctionDefinition.builder()
+ .name(metadata.getName())
+ .description(metadata.getDescription());
+
+ String schema = metadata.getInputSchema();
+ if (schema != null && !schema.isBlank()) {
+ functionBuilder.parameters(parseFunctionParameters(schema));
+ }
+
+ ChatCompletionFunctionTool functionTool =
+ ChatCompletionFunctionTool.builder()
+ .function(functionBuilder.build())
+ .type(JsonValue.from("function"))
+ .build();
+
+ openaiTools.add(ChatCompletionTool.ofFunction(functionTool));
+ }
+ return openaiTools;
+ }
+
+ private FunctionParameters parseFunctionParameters(String schemaJson) {
+ try {
+ JsonNode root = mapper.readTree(schemaJson);
+ if (root == null || !root.isObject()) {
+ return FunctionParameters.builder().build();
+ }
+ FunctionParameters.Builder builder = FunctionParameters.builder();
+ root.fields()
+ .forEachRemaining(
+ entry ->
+ builder.putAdditionalProperty(
+ entry.getKey(),
+
JsonValue.fromJsonNode(entry.getValue())));
+ return builder.build();
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Failed to parse tool schema JSON.", e);
+ }
+ }
+
+ private JsonValue toJsonValue(Object value) {
+ if (value instanceof JsonValue) {
+ return (JsonValue) value;
+ }
+ if (value instanceof String
+ || value instanceof Number
+ || value instanceof Boolean
+ || value == null) {
+ return JsonValue.from(value);
+ }
+ return JsonValue.fromJsonNode(mapper.valueToTree(value));
+ }
+}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java
new file mode 100644
index 00000000..44a7c843
--- /dev/null
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetup.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Setup for Azure OpenAI Chat Completions.
+ *
+ * <p>{@code model} (inherited from {@link BaseChatModelSetup}) is the Azure
deployment name, not
+ * the underlying OpenAI model name. The underlying model name can be supplied
via {@code
+ * model_of_azure_deployment} and is used solely for token-metrics tracking.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * @ChatModelSetup
+ * public static ResourceDescriptor azureOpenAIModel() {
+ * return
ResourceDescriptor.Builder.newBuilder(AzureOpenAIChatModelSetup.class.getName())
+ * .addInitialArgument("connection", "myAzureOpenAIConnection")
+ * .addInitialArgument("model", "my-gpt4o-deployment")
+ * .addInitialArgument("model_of_azure_deployment", "gpt-4o")
+ * .addInitialArgument("temperature", 0.3d)
+ * .addInitialArgument("max_tokens", 500)
+ * .build();
+ * }
+ * }</pre>
+ */
+public class AzureOpenAIChatModelSetup extends BaseChatModelSetup {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(AzureOpenAIChatModelSetup.class);
+
+ private final String modelOfAzureDeployment;
+ private final Double temperature;
+ private final Integer maxTokens;
+ private final Boolean logprobs;
+ private final Map<String, Object> additionalKwargs;
+
+ public AzureOpenAIChatModelSetup(
+ ResourceDescriptor descriptor, ResourceContext resourceContext) {
+ super(descriptor, resourceContext);
+
+ this.modelOfAzureDeployment =
descriptor.getArgument("model_of_azure_deployment");
+ if (this.modelOfAzureDeployment == null ||
this.modelOfAzureDeployment.isBlank()) {
+ LOG.warn(
+ "model_of_azure_deployment is not set; token usage metrics
will not be recorded for this Azure OpenAI deployment '{}'.",
+ this.model);
+ }
+
+ this.temperature =
+
Optional.ofNullable(descriptor.<Number>getArgument("temperature"))
+ .map(Number::doubleValue)
+ .orElse(null);
+ if (this.temperature != null && (this.temperature < 0.0 ||
this.temperature > 2.0)) {
+ throw new IllegalArgumentException("temperature must be between
0.0 and 2.0");
+ }
+
+ this.maxTokens =
+
Optional.ofNullable(descriptor.<Number>getArgument("max_tokens"))
+ .map(Number::intValue)
+ .orElse(null);
+ if (this.maxTokens != null && this.maxTokens <= 0) {
+ throw new IllegalArgumentException("max_tokens must be greater
than 0");
+ }
+
+ this.logprobs =
+
Optional.ofNullable(descriptor.<Boolean>getArgument("logprobs")).orElse(false);
+
+ Map<String, Object> additional =
+ Optional.ofNullable(
+ descriptor.<Map<String,
Object>>getArgument("additional_kwargs"))
+ .map(HashMap::new)
+ .orElseGet(HashMap::new);
+ this.additionalKwargs = additional;
+ }
+
+ @Override
+ public Map<String, Object> getParameters() {
+ Map<String, Object> params = new HashMap<>();
+ if (model != null) {
+ params.put("model", model);
+ }
+ if (modelOfAzureDeployment != null) {
+ params.put("model_of_azure_deployment", modelOfAzureDeployment);
+ }
+ params.put("logprobs", logprobs);
+ if (temperature != null) {
+ params.put("temperature", temperature);
+ }
+ if (maxTokens != null) {
+ params.put("max_tokens", maxTokens);
+ }
+ if (additionalKwargs != null && !additionalKwargs.isEmpty()) {
+ params.put("additional_kwargs", additionalKwargs);
+ }
+ return params;
+ }
+}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java
new file mode 100644
index 00000000..c9d8c8d9
--- /dev/null
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAIChatCompletionsUtils.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.core.JsonValue;
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessage;
+import
com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
+import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
+import com.openai.models.chat.completions.ChatCompletionToolMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Static helpers for converting between Flink Agents {@link ChatMessage} and OpenAI Chat
+ * Completions API message types. Restricted to message conversion (no tool-definition conversion —
+ * that stays per-connection).
+ *
+ * <p>Used by both {@code OpenAICompletionsConnection} (OpenAI / OpenAI-compatible providers) and
+ * {@code AzureOpenAIChatModelConnection} (Azure OpenAI). Both rely on the same openai-java SDK
+ * message types.
+ *
+ * <p>On the Flink Agents side a tool call is a map with an {@code "id"}, a {@code "type"} and a
+ * {@code "function"} map holding {@code "name"} and {@code "arguments"}. Only {@code function}
+ * tool calls are converted in either direction; other types are skipped.
+ */
+final class OpenAIChatCompletionsUtils {
+
+    /** The only tool-call type converted here; also assumed when a call has no type. */
+    private static final String FUNCTION_TYPE = "function";
+
+    private static final ObjectMapper MAPPER = new ObjectMapper();
+    private static final TypeReference<Map<String, Object>> MAP_TYPE = new TypeReference<>() {};
+
+    private OpenAIChatCompletionsUtils() {}
+
+    /**
+     * Converts a list of Flink Agents chat messages to OpenAI message params, preserving order.
+     *
+     * @param messages the conversation to convert
+     * @return one {@link ChatCompletionMessageParam} per input message
+     * @throws IllegalArgumentException if any message cannot be converted (see {@link
+     *     #convertToOpenAIMessage})
+     */
+    public static List<ChatCompletionMessageParam> convertToOpenAIMessages(
+            List<ChatMessage> messages) {
+        return messages.stream()
+                .map(OpenAIChatCompletionsUtils::convertToOpenAIMessage)
+                .collect(Collectors.toList());
+    }
+
+    /**
+     * Converts a single Flink Agents chat message to an OpenAI Chat Completions message param.
+     *
+     * <p>{@code null} content is sent as the empty string, except that assistant messages omit
+     * empty content entirely. Assistant messages also carry their tool calls and any string
+     * {@code extraArgs["refusal"]}; tool messages take their tool-call id from {@code
+     * extraArgs["externalId"]}.
+     *
+     * @param message the message to convert
+     * @return the equivalent OpenAI message param
+     * @throws IllegalArgumentException if the role is unsupported, a tool message has no {@code
+     *     externalId}, or an assistant tool call has no id or no function name
+     */
+    public static ChatCompletionMessageParam convertToOpenAIMessage(ChatMessage message) {
+        MessageRole role = message.getRole();
+        String content = Optional.ofNullable(message.getContent()).orElse("");
+
+        switch (role) {
+            case SYSTEM:
+                return ChatCompletionMessageParam.ofSystem(
+                        ChatCompletionSystemMessageParam.builder().content(content).build());
+            case USER:
+                return ChatCompletionMessageParam.ofUser(
+                        ChatCompletionUserMessageParam.builder().content(content).build());
+            case ASSISTANT:
+                ChatCompletionAssistantMessageParam.Builder assistantBuilder =
+                        ChatCompletionAssistantMessageParam.builder();
+                // Only set content when non-empty (e.g. a turn that carries just tool calls).
+                if (!content.isEmpty()) {
+                    assistantBuilder.content(content);
+                }
+                List<Map<String, Object>> toolCalls = message.getToolCalls();
+                if (toolCalls != null && !toolCalls.isEmpty()) {
+                    assistantBuilder.toolCalls(convertAssistantToolCalls(toolCalls));
+                }
+                // Round-trips the refusal that convertFromOpenAIMessage stores in extraArgs.
+                Object refusal = message.getExtraArgs().get("refusal");
+                if (refusal instanceof String) {
+                    assistantBuilder.refusal((String) refusal);
+                }
+                return ChatCompletionMessageParam.ofAssistant(assistantBuilder.build());
+            case TOOL:
+                Object toolCallId = message.getExtraArgs().get("externalId");
+                if (toolCallId == null) {
+                    throw new IllegalArgumentException(
+                            "Tool message must have an externalId in extraArgs.");
+                }
+                return ChatCompletionMessageParam.ofTool(
+                        ChatCompletionToolMessageParam.builder()
+                                .content(content)
+                                .toolCallId(toolCallId.toString())
+                                .build());
+            default:
+                throw new IllegalArgumentException("Unsupported role: " + role);
+        }
+    }
+
+    /**
+     * Converts an OpenAI {@link ChatCompletionMessage} to a Flink Agents assistant {@link
+     * ChatMessage}. {@code message.refusal()} is written as {@code extraArgs["refusal"]} on the
+     * returned ChatMessage when present, preserving prior Java behavior.
+     *
+     * @param message the response message returned by the Chat Completions API
+     * @return an assistant message with the response content (empty string if absent) and any
+     *     function tool calls
+     * @throws IllegalStateException if a returned tool call has a blank id
+     * @throws RuntimeException if a tool call's arguments are not a JSON object
+     */
+    public static ChatMessage convertFromOpenAIMessage(ChatCompletionMessage message) {
+        ChatMessage response = ChatMessage.assistant(message.content().orElse(""));
+
+        message.refusal().ifPresent(refusal -> response.getExtraArgs().put("refusal", refusal));
+
+        List<ChatCompletionMessageToolCall> toolCalls = message.toolCalls().orElse(List.of());
+        if (!toolCalls.isEmpty()) {
+            response.setToolCalls(convertResponseToolCalls(toolCalls));
+        }
+        return response;
+    }
+
+    /** Converts Flink Agents tool-call maps into SDK tool calls, skipping non-function types. */
+    private static List<ChatCompletionMessageToolCall> convertAssistantToolCalls(
+            List<Map<String, Object>> toolCalls) {
+        List<ChatCompletionMessageToolCall> result = new ArrayList<>(toolCalls.size());
+        for (Map<String, Object> call : toolCalls) {
+            // An absent OR explicitly-null "type" means a function call. getOrDefault alone would
+            // turn a null value into "null" and silently drop the call.
+            Object type = call.get("type");
+            String typeName = type == null ? FUNCTION_TYPE : type.toString();
+            if (!FUNCTION_TYPE.equals(typeName)) {
+                continue;
+            }
+
+            Object idObj = call.get("id");
+            if (idObj == null) {
+                throw new IllegalArgumentException("Tool call must have an id.");
+            }
+
+            Map<String, Object> functionPayload = toMap(call.get("function"));
+            Object functionName = functionPayload.get("name");
+            if (functionName == null) {
+                // Fail with a clear message instead of the SDK builder's required-field error.
+                throw new IllegalArgumentException(
+                        "Tool call " + idObj + " must have a function name.");
+            }
+
+            ChatCompletionMessageFunctionToolCall.Function function =
+                    ChatCompletionMessageFunctionToolCall.Function.builder()
+                            .name(functionName.toString())
+                            .arguments(serializeArguments(functionPayload.get("arguments")))
+                            .build();
+
+            result.add(
+                    ChatCompletionMessageToolCall.ofFunction(
+                            ChatCompletionMessageFunctionToolCall.builder()
+                                    .id(idObj.toString())
+                                    .function(function)
+                                    .type(JsonValue.from(typeName))
+                                    .build()));
+        }
+        return result;
+    }
+
+    /**
+     * Converts SDK tool calls from a response into Flink Agents tool-call maps. Non-function calls
+     * are skipped. Arguments are parsed into a map, and the id is also stored as {@code
+     * "original_id"}.
+     */
+    private static List<Map<String, Object>> convertResponseToolCalls(
+            List<ChatCompletionMessageToolCall> toolCalls) {
+        List<Map<String, Object>> result = new ArrayList<>(toolCalls.size());
+        for (ChatCompletionMessageToolCall toolCall : toolCalls) {
+            if (!toolCall.isFunction()) {
+                continue;
+            }
+
+            ChatCompletionMessageFunctionToolCall functionToolCall = toolCall.asFunction();
+            String toolCallId = functionToolCall.id();
+            if (toolCallId == null || toolCallId.isBlank()) {
+                throw new IllegalStateException("OpenAI tool call ID is null or empty.");
+            }
+
+            ChatCompletionMessageFunctionToolCall.Function function = functionToolCall.function();
+            Map<String, Object> functionMap = new LinkedHashMap<>();
+            functionMap.put("name", function.name());
+            functionMap.put("arguments", parseArguments(function.arguments()));
+
+            Map<String, Object> callMap = new LinkedHashMap<>();
+            callMap.put("id", toolCallId);
+            callMap.put("type", FUNCTION_TYPE);
+            callMap.put("function", functionMap);
+            callMap.put("original_id", toolCallId);
+            result.add(callMap);
+        }
+        return result;
+    }
+
+    /** Parses a JSON-object argument string; blank or {@code null} input yields an empty map. */
+    private static Map<String, Object> parseArguments(String arguments) {
+        if (arguments == null || arguments.isBlank()) {
+            return Map.of();
+        }
+        try {
+            return MAPPER.readValue(arguments, MAP_TYPE);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException("Failed to parse tool arguments: " + arguments, e);
+        }
+    }
+
+    /**
+     * Serializes tool-call arguments to a JSON string. Strings pass through unchanged and {@code
+     * null} becomes {@code "{}"}.
+     */
+    private static String serializeArguments(Object arguments) {
+        if (arguments == null) {
+            return "{}";
+        }
+        if (arguments instanceof String) {
+            return (String) arguments;
+        }
+        try {
+            return MAPPER.writeValueAsString(arguments);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException("Failed to serialize tool call arguments.", e);
+        }
+    }
+
+    /**
+     * Returns a mutable copy of {@code value} as a string-keyed map. {@code null} yields an empty
+     * map; non-map values are converted via Jackson.
+     */
+    private static Map<String, Object> toMap(Object value) {
+        if (value instanceof Map) {
+            @SuppressWarnings("unchecked")
+            Map<String, Object> casted = (Map<String, Object>) value;
+            return new LinkedHashMap<>(casted);
+        }
+        if (value == null) {
+            return new LinkedHashMap<>();
+        }
+        return MAPPER.convertValue(value, MAP_TYPE);
+    }
+}
diff --git
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
index 15307307..e4947e8f 100644
---
a/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
+++
b/integrations/chat-models/openai/src/main/java/org/apache/flink/agents/integrations/chatmodels/openai/OpenAICompletionsConnection.java
@@ -18,7 +18,6 @@
package org.apache.flink.agents.integrations.chatmodels.openai;
import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.openai.client.OpenAIClient;
@@ -29,19 +28,10 @@ import com.openai.models.FunctionDefinition;
import com.openai.models.FunctionParameters;
import com.openai.models.ReasoningEffort;
import com.openai.models.chat.completions.ChatCompletion;
-import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.openai.models.chat.completions.ChatCompletionFunctionTool;
-import com.openai.models.chat.completions.ChatCompletionMessage;
-import
com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
-import com.openai.models.chat.completions.ChatCompletionMessageParam;
-import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
-import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
import com.openai.models.chat.completions.ChatCompletionTool;
-import com.openai.models.chat.completions.ChatCompletionToolMessageParam;
-import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
-import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
import org.apache.flink.agents.api.resource.ResourceContext;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
@@ -51,11 +41,8 @@ import org.apache.flink.agents.api.tools.ToolMetadata;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Collectors;
/**
* A chat model integration for the OpenAI Chat Completions service using the
official Java SDK.
@@ -91,9 +78,7 @@ import java.util.stream.Collectors;
*/
public class OpenAICompletionsConnection extends BaseChatModelConnection {
- private static final TypeReference<Map<String, Object>> MAP_TYPE = new
TypeReference<>() {};
-
- private final ObjectMapper mapper = new ObjectMapper();
+ private static final ObjectMapper mapper = new ObjectMapper();
private final OpenAIClient client;
private final String defaultModel;
@@ -140,7 +125,9 @@ public class OpenAICompletionsConnection extends
BaseChatModelConnection {
try {
ChatCompletionCreateParams params = buildRequest(messages, tools,
arguments);
ChatCompletion completion =
client.chat().completions().create(params);
- ChatMessage response = convertResponse(completion);
+ ChatMessage response =
+ OpenAIChatCompletionsUtils.convertFromOpenAIMessage(
+ completion.choices().get(0).message());
// Record token metrics
if (completion.usage().isPresent()) {
@@ -176,10 +163,7 @@ public class OpenAICompletionsConnection extends
BaseChatModelConnection {
ChatCompletionCreateParams.Builder builder =
ChatCompletionCreateParams.builder()
.model(ChatModel.of(modelName))
- .messages(
- messages.stream()
- .map(this::convertToOpenAIMessage)
- .collect(Collectors.toList()));
+
.messages(OpenAIChatCompletionsUtils.convertToOpenAIMessages(messages));
if (tools != null && !tools.isEmpty()) {
builder.tools(convertTools(tools, strictMode));
@@ -272,145 +256,6 @@ public class OpenAICompletionsConnection extends
BaseChatModelConnection {
}
}
- private ChatCompletionMessageParam convertToOpenAIMessage(ChatMessage
message) {
- MessageRole role = message.getRole();
- String content = Optional.ofNullable(message.getContent()).orElse("");
-
- switch (role) {
- case SYSTEM:
- return ChatCompletionMessageParam.ofSystem(
-
ChatCompletionSystemMessageParam.builder().content(content).build());
- case USER:
- return ChatCompletionMessageParam.ofUser(
-
ChatCompletionUserMessageParam.builder().content(content).build());
- case ASSISTANT:
- ChatCompletionAssistantMessageParam.Builder assistantBuilder =
- ChatCompletionAssistantMessageParam.builder();
- if (!content.isEmpty()) {
- assistantBuilder.content(content);
- }
- List<Map<String, Object>> toolCalls = message.getToolCalls();
- if (toolCalls != null && !toolCalls.isEmpty()) {
-
assistantBuilder.toolCalls(convertAssistantToolCalls(toolCalls));
- }
- Object refusal = message.getExtraArgs().get("refusal");
- if (refusal instanceof String) {
- assistantBuilder.refusal((String) refusal);
- }
- return
ChatCompletionMessageParam.ofAssistant(assistantBuilder.build());
- case TOOL:
- ChatCompletionToolMessageParam.Builder toolBuilder =
-
ChatCompletionToolMessageParam.builder().content(content);
- Object toolCallId = message.getExtraArgs().get("externalId");
- if (toolCallId == null) {
- throw new IllegalArgumentException(
- "Tool message must have an externalId in
extraArgs.");
- }
- toolBuilder.toolCallId(toolCallId.toString());
- return ChatCompletionMessageParam.ofTool(toolBuilder.build());
- default:
- throw new IllegalArgumentException("Unsupported role: " +
role);
- }
- }
-
- private List<ChatCompletionMessageToolCall> convertAssistantToolCalls(
- List<Map<String, Object>> toolCalls) {
- List<ChatCompletionMessageToolCall> result = new
ArrayList<>(toolCalls.size());
- for (Map<String, Object> call : toolCalls) {
- Object type = call.getOrDefault("type", "function");
- if (!"function".equals(String.valueOf(type))) {
- continue;
- }
-
- Map<String, Object> functionPayload = toMap(call.get("function"));
- ChatCompletionMessageFunctionToolCall.Function.Builder
functionBuilder =
- ChatCompletionMessageFunctionToolCall.Function.builder();
-
- Object functionName = functionPayload.get("name");
- if (functionName != null) {
- functionBuilder.name(functionName.toString());
- }
-
- Object arguments = functionPayload.get("arguments");
- functionBuilder.arguments(serializeArguments(arguments));
-
- Object idObj = call.get("id");
- if (idObj == null) {
- throw new IllegalArgumentException("Tool call must have an
id.");
- }
- String toolCallId = idObj.toString();
-
- ChatCompletionMessageFunctionToolCall.Builder toolCallBuilder =
- ChatCompletionMessageFunctionToolCall.builder()
- .id(toolCallId)
- .function(functionBuilder.build())
- .type(JsonValue.from(String.valueOf(type)));
-
-
result.add(ChatCompletionMessageToolCall.ofFunction(toolCallBuilder.build()));
- }
- return result;
- }
-
- private ChatMessage convertResponse(ChatCompletion completion) {
- List<ChatCompletion.Choice> choices = completion.choices();
- if (choices.isEmpty()) {
- throw new IllegalStateException("OpenAI response did not contain
any choices.");
- }
-
- ChatCompletionMessage message = choices.get(0).message();
- String content = message.content().orElse("");
- ChatMessage response = ChatMessage.assistant(content);
-
- message.refusal().ifPresent(refusal ->
response.getExtraArgs().put("refusal", refusal));
-
- List<ChatCompletionMessageToolCall> toolCalls =
message.toolCalls().orElse(List.of());
- if (!toolCalls.isEmpty()) {
- response.setToolCalls(convertResponseToolCalls(toolCalls));
- }
-
- return response;
- }
-
- private List<Map<String, Object>> convertResponseToolCalls(
- List<ChatCompletionMessageToolCall> toolCalls) {
- List<Map<String, Object>> result = new ArrayList<>(toolCalls.size());
- for (ChatCompletionMessageToolCall toolCall : toolCalls) {
- if (!toolCall.isFunction()) {
- continue;
- }
-
- ChatCompletionMessageFunctionToolCall functionToolCall =
toolCall.asFunction();
- Map<String, Object> callMap = new LinkedHashMap<>();
- String toolCallId = functionToolCall.id();
- if (toolCallId == null || toolCallId.isBlank()) {
- throw new IllegalStateException("OpenAI tool call ID is null
or empty.");
- }
-
- callMap.put("id", toolCallId);
- callMap.put("type", "function");
-
- ChatCompletionMessageFunctionToolCall.Function function =
functionToolCall.function();
- Map<String, Object> functionMap = new LinkedHashMap<>();
- functionMap.put("name", function.name());
- functionMap.put("arguments", parseArguments(function.arguments()));
- callMap.put("function", functionMap);
- callMap.put("original_id", toolCallId);
- result.add(callMap);
- }
- return result;
- }
-
- private Map<String, Object> parseArguments(String arguments) {
- if (arguments == null || arguments.isBlank()) {
- return Map.of();
- }
- try {
- return mapper.readValue(arguments, MAP_TYPE);
- } catch (JsonProcessingException e) {
- throw new RuntimeException("Failed to parse tool arguments: " +
arguments, e);
- }
- }
-
private JsonValue toJsonValue(Object value) {
if (value instanceof JsonValue) {
return (JsonValue) value;
@@ -424,32 +269,6 @@ public class OpenAICompletionsConnection extends
BaseChatModelConnection {
return JsonValue.fromJsonNode(mapper.valueToTree(value));
}
- private String serializeArguments(Object arguments) {
- if (arguments == null) {
- return "{}";
- }
- if (arguments instanceof String) {
- return (String) arguments;
- }
- try {
- return mapper.writeValueAsString(arguments);
- } catch (JsonProcessingException e) {
- throw new RuntimeException("Failed to serialize tool call
arguments.", e);
- }
- }
-
- private Map<String, Object> toMap(Object value) {
- if (value instanceof Map) {
- @SuppressWarnings("unchecked")
- Map<String, Object> casted = (Map<String, Object>) value;
- return new LinkedHashMap<>(casted);
- }
- if (value == null) {
- return new LinkedHashMap<>();
- }
- return mapper.convertValue(value, MAP_TYPE);
- }
-
@Override
public void close() throws Exception {
this.client.close();
diff --git
a/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java
b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java
new file mode 100644
index 00000000..60a29729
--- /dev/null
+++
b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelConnectionTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Unit tests for {@link AzureOpenAIChatModelConnection}: constructor validation and request
+ * argument validation only, with no network access. Scenarios that need a real Azure OpenAI
+ * deployment belong in the end-to-end integration tests, not here.
+ */
+class AzureOpenAIChatModelConnectionTest {
+
+    private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null);
+
+    private static final String API_KEY = "test-key";
+    private static final String API_VERSION = "2024-02-01";
+    private static final String AZURE_ENDPOINT = "https://example.openai.azure.com";
+
+    private static ResourceDescriptor.Builder connectionDescriptor() {
+        return ResourceDescriptor.Builder.newBuilder(
+                AzureOpenAIChatModelConnection.class.getName());
+    }
+
+    /** Descriptor with every required connection argument set to a dummy value. */
+    private static ResourceDescriptor validDescriptor() {
+        return connectionDescriptor()
+                .addInitialArgument("api_key", API_KEY)
+                .addInitialArgument("api_version", API_VERSION)
+                .addInitialArgument("azure_endpoint", AZURE_ENDPOINT)
+                .build();
+    }
+
+    @Test
+    @DisplayName("Constructor throws when api_key is missing")
+    void testConstructorMissingApiKey() {
+        ResourceDescriptor desc =
+                connectionDescriptor()
+                        .addInitialArgument("api_version", API_VERSION)
+                        .addInitialArgument("azure_endpoint", AZURE_ENDPOINT)
+                        .build();
+        assertThatThrownBy(() -> new AzureOpenAIChatModelConnection(desc, NOOP))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("api_key");
+    }
+
+    @Test
+    @DisplayName("Constructor throws when api_version is missing")
+    void testConstructorMissingApiVersion() {
+        ResourceDescriptor desc =
+                connectionDescriptor()
+                        .addInitialArgument("api_key", API_KEY)
+                        .addInitialArgument("azure_endpoint", AZURE_ENDPOINT)
+                        .build();
+        assertThatThrownBy(() -> new AzureOpenAIChatModelConnection(desc, NOOP))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("api_version");
+    }
+
+    @Test
+    @DisplayName("Constructor throws when azure_endpoint is missing")
+    void testConstructorMissingAzureEndpoint() {
+        ResourceDescriptor desc =
+                connectionDescriptor()
+                        .addInitialArgument("api_key", API_KEY)
+                        .addInitialArgument("api_version", API_VERSION)
+                        .build();
+        assertThatThrownBy(() -> new AzureOpenAIChatModelConnection(desc, NOOP))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("azure_endpoint");
+    }
+
+    @Test
+    @DisplayName("Constructor succeeds with all required args (no network call yet)")
+    void testConstructorAllRequiredArgs() throws Exception {
+        AzureOpenAIChatModelConnection conn =
+                new AzureOpenAIChatModelConnection(validDescriptor(), NOOP);
+        try {
+            assertThat(conn).isInstanceOf(BaseChatModelConnection.class);
+        } finally {
+            // Close the connection so the test does not leak whatever client it holds.
+            conn.close();
+        }
+    }
+
+    @Test
+    @DisplayName("chat() rejects additional_kwargs that collide with reserved typed fields")
+    void testChatRejectsReservedKeyInAdditionalKwargs() throws Exception {
+        AzureOpenAIChatModelConnection conn =
+                new AzureOpenAIChatModelConnection(validDescriptor(), NOOP);
+        try {
+            Map<String, Object> args =
+                    Map.of(
+                            "model",
+                            "my-deployment",
+                            "temperature",
+                            0.3d,
+                            "additional_kwargs",
+                            Map.of("temperature", 5.0d));
+
+            assertThatThrownBy(
+                            () ->
+                                    conn.chat(
+                                            List.of(new ChatMessage(MessageRole.USER, "hi")),
+                                            null,
+                                            args))
+                    .isInstanceOf(IllegalArgumentException.class)
+                    .hasMessageContaining("additional_kwargs")
+                    .hasMessageContaining("temperature");
+        } finally {
+            conn.close();
+        }
+    }
+}
diff --git
a/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java
b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java
new file mode 100644
index 00000000..ffcb1939
--- /dev/null
+++
b/integrations/chat-models/openai/src/test/java/org/apache/flink/agents/integrations/chatmodels/openai/AzureOpenAIChatModelSetupTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.integrations.chatmodels.openai;
+
+import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.resource.ResourceContext;
+import org.apache.flink.agents.api.resource.ResourceDescriptor;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link AzureOpenAIChatModelSetup}. */
+class AzureOpenAIChatModelSetupTest {
+
+    private static final ResourceContext NOOP = ResourceContext.fromGetResource((a, b) -> null);
+
+    private static ResourceDescriptor.Builder descriptorBuilder() {
+        return ResourceDescriptor.Builder.newBuilder(AzureOpenAIChatModelSetup.class.getName());
+    }
+
+    /** Builds a setup with {@code model="m"} plus one extra initial argument. */
+    private static AzureOpenAIChatModelSetup setupWith(String key, Object value) {
+        ResourceDescriptor desc =
+                descriptorBuilder()
+                        .addInitialArgument("model", "m")
+                        .addInitialArgument(key, value)
+                        .build();
+        return new AzureOpenAIChatModelSetup(desc, NOOP);
+    }
+
+    @Test
+    @DisplayName("getParameters includes model and default logprobs=false")
+    void testGetParametersMinimal() {
+        ResourceDescriptor desc =
+                descriptorBuilder().addInitialArgument("model", "my-deployment").build();
+        AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP);
+
+        Map<String, Object> params = setup.getParameters();
+        assertThat(params).containsEntry("model", "my-deployment");
+        assertThat(params).containsEntry("logprobs", false);
+        assertThat(params)
+                .doesNotContainKeys("temperature", "max_tokens", "model_of_azure_deployment");
+    }
+
+    @Test
+    @DisplayName("getParameters includes all explicitly-set fields")
+    void testGetParametersAllFields() {
+        ResourceDescriptor desc =
+                descriptorBuilder()
+                        .addInitialArgument("model", "my-deployment")
+                        .addInitialArgument("model_of_azure_deployment", "gpt-4o")
+                        .addInitialArgument("temperature", 0.3d)
+                        .addInitialArgument("max_tokens", 500)
+                        .addInitialArgument("logprobs", true)
+                        .build();
+        AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP);
+
+        Map<String, Object> params = setup.getParameters();
+        assertThat(params)
+                .containsEntry("model", "my-deployment")
+                .containsEntry("model_of_azure_deployment", "gpt-4o")
+                .containsEntry("temperature", 0.3d)
+                .containsEntry("max_tokens", 500)
+                .containsEntry("logprobs", true);
+    }
+
+    @Test
+    @DisplayName("getParameters nests additional_kwargs under a dedicated key")
+    void testGetParametersNestsAdditionalKwargs() {
+        AzureOpenAIChatModelSetup setup =
+                setupWith("additional_kwargs", Map.of("seed", 42, "user", "user-123"));
+
+        Map<String, Object> params = setup.getParameters();
+        assertThat(params)
+                .containsEntry("model", "m")
+                .containsEntry("additional_kwargs", Map.of("seed", 42, "user", "user-123"))
+                .doesNotContainKeys("seed", "user");
+    }
+
+    @Test
+    @DisplayName("temperature outside [0.0, 2.0] is rejected")
+    void testTemperatureValidation() {
+        assertThatThrownBy(() -> setupWith("temperature", 2.5d))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("temperature must be between 0.0 and 2.0");
+        assertThatThrownBy(() -> setupWith("temperature", -0.1d))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("temperature must be between 0.0 and 2.0");
+    }
+
+    @Test
+    @DisplayName("temperature bounds 0.0 and 2.0 are inclusive")
+    void testTemperatureBoundsInclusive() {
+        assertThat(setupWith("temperature", 0.0d).getParameters())
+                .containsEntry("temperature", 0.0d);
+        assertThat(setupWith("temperature", 2.0d).getParameters())
+                .containsEntry("temperature", 2.0d);
+    }
+
+    @Test
+    @DisplayName("max_tokens must be greater than 0")
+    void testMaxTokensValidation() {
+        assertThatThrownBy(() -> setupWith("max_tokens", 0))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("max_tokens must be greater than 0");
+        assertThatThrownBy(() -> setupWith("max_tokens", -1))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("max_tokens must be greater than 0");
+        // The smallest positive value is accepted.
+        assertThat(setupWith("max_tokens", 1).getParameters()).containsEntry("max_tokens", 1);
+    }
+
+    @Test
+    @DisplayName("Extends BaseChatModelSetup")
+    void testInheritance() {
+        ResourceDescriptor desc = descriptorBuilder().addInitialArgument("model", "m").build();
+        assertThat(new AzureOpenAIChatModelSetup(desc, NOOP))
+                .isInstanceOf(BaseChatModelSetup.class);
+    }
+
+    @Test
+    @DisplayName("model field is preserved through descriptor round-trip")
+    void testModelFieldRoundtrip() {
+        ResourceDescriptor desc =
+                descriptorBuilder().addInitialArgument("model", "test-deployment").build();
+        AzureOpenAIChatModelSetup setup = new AzureOpenAIChatModelSetup(desc, NOOP);
+        assertThat(setup.getParameters()).containsEntry("model", "test-deployment");
+    }
+}
diff --git a/python/flink_agents/api/resource.py
b/python/flink_agents/api/resource.py
index 3f20821e..c8a1e88b 100644
--- a/python/flink_agents/api/resource.py
+++ b/python/flink_agents/api/resource.py
@@ -291,6 +291,10 @@ class ResourceName:
OPENAI_RESPONSES_CONNECTION =
"org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelConnection"
OPENAI_RESPONSES_SETUP =
"org.apache.flink.agents.integrations.chatmodels.openai.OpenAIResponsesModelSetup"
+ # Azure OpenAI
+ AZURE_OPENAI_CONNECTION =
"org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelConnection"
+ AZURE_OPENAI_SETUP =
"org.apache.flink.agents.integrations.chatmodels.openai.AzureOpenAIChatModelSetup"
+
class EmbeddingModel:
"""EmbeddingModel resource names."""
diff --git
a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
index 576d3ab7..18a09212 100644
---
a/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
+++
b/python/flink_agents/integrations/chat_models/azure/azure_openai_chat_model.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
+import logging
from typing import Any, Dict, List, Sequence
from openai import NOT_GIVEN, AzureOpenAI
@@ -32,6 +33,12 @@ from flink_agents.integrations.chat_models.openai.openai_utils import (
     convert_to_openai_messages,
 )
 
+logger = logging.getLogger(__name__)
+
+_RESERVED_KWARG_KEYS = frozenset(
+    {"model", "model_of_azure_deployment", "temperature", "max_tokens", "logprobs"}
+)
+
 
 class AzureOpenAIChatModelConnection(BaseChatModelConnection):
     """The connection to the Azure OpenAI LLM.
@@ -139,6 +146,16 @@ class AzureOpenAIChatModelConnection(BaseChatModelConnection):
             msg = "model is required for Azure OpenAI API calls"
             raise ValueError(msg)
         model_of_azure_deployment = kwargs.pop("model_of_azure_deployment", None)
+        additional_kwargs = kwargs.pop("additional_kwargs", None) or {}
+
+        collisions = _RESERVED_KWARG_KEYS & additional_kwargs.keys()
+        if collisions:
+            msg = (
+                f"additional_kwargs must not contain reserved typed fields: "
+                f"{sorted(collisions)}. Set these via the corresponding "
+                f"Setup field instead."
+            )
+            raise ValueError(msg)
 
         response = self.client.chat.completions.create(
             # Azure OpenAI APIs use Azure deployment name as the model parameter
@@ -146,6 +163,7 @@ class AzureOpenAIChatModelConnection(BaseChatModelConnection):
             messages=convert_to_openai_messages(messages),
             tools=tool_specs or NOT_GIVEN,
             **kwargs,
+            **additional_kwargs,
         )
 
         extra_args = {}
@@ -235,6 +253,12 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
     ) -> None:
         """Init method."""
         additional_kwargs = additional_kwargs or {}
+        if not model_of_azure_deployment:
+            logger.warning(
+                "model_of_azure_deployment is not set; token usage metrics will "
+                "not be recorded for this Azure OpenAI deployment '%s'.",
+                model,
+            )
         super().__init__(
             model=model,
             model_of_azure_deployment=model_of_azure_deployment,
@@ -257,6 +281,6 @@ class AzureOpenAIChatModelSetup(BaseChatModelSetup):
             base_kwargs["temperature"] = self.temperature
         if self.max_tokens is not None:
             base_kwargs["max_tokens"] = self.max_tokens
-
-        all_kwargs = {**base_kwargs, **self.additional_kwargs}
-        return all_kwargs
+        if self.additional_kwargs:
+            base_kwargs["additional_kwargs"] = self.additional_kwargs
+        return base_kwargs
diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
index 79d2fb5c..ce69d42e 100644
--- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
+++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py
@@ -127,3 +127,42 @@ def test_model_field_roundtrip() -> None:
     setup = AzureOpenAIChatModelSetup(connection="conn", model="test-deployment")
     restored = AzureOpenAIChatModelSetup.model_validate(setup.model_dump())
     assert restored.model == "test-deployment"
+
+
+def test_model_kwargs_nests_additional_kwargs() -> None:
+    """`additional_kwargs` is nested under its own key, not flattened.
+
+    Flattening would allow a colliding key (e.g. `temperature`) in
+    `additional_kwargs` to silently overwrite the field-validated value.
+    """
+    setup = AzureOpenAIChatModelSetup(
+        connection="conn",
+        model="my-deployment",
+        additional_kwargs={"seed": 42, "user": "user-123"},
+    )
+    kwargs = setup.model_kwargs
+    assert kwargs["model"] == "my-deployment"
+    assert kwargs["additional_kwargs"] == {"seed": 42, "user": "user-123"}
+    assert "seed" not in kwargs
+    assert "user" not in kwargs
+
+
+def test_chat_rejects_reserved_key_in_additional_kwargs() -> None:
+    """`additional_kwargs` containing a reserved typed key must raise.
+
+    Without this check, `**kwargs, **additional_kwargs` would raise an opaque
+    TypeError, and (worse) leaves the door open for callers to bypass the
+    field-level validation on `temperature`, `max_tokens`, etc.
+    """
+    connection = AzureOpenAIChatModelConnection(
+        api_key="fake-key",
+        azure_endpoint="https://example.openai.azure.com",
+        api_version="2024-02-01",
+    )
+    with pytest.raises(ValueError, match="additional_kwargs"):
+        connection.chat(
+            messages=[ChatMessage(role=MessageRole.USER, content="hi")],
+            model="my-deployment",
+            temperature=0.3,
+            additional_kwargs={"temperature": 5.0},
+        )