This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new db1cf974479 [FLINK-38857][model] Introduce Triton model function
db1cf974479 is described below
commit db1cf97447948b8e74f94a92be509750c60783e5
Author: averyzhang <[email protected]>
AuthorDate: Mon Jan 5 10:59:19 2026 +0800
[FLINK-38857][model] Introduce Triton model function
This closes #27385.
---
flink-docs/pom.xml | 6 +
.../docs/util/ConfigurationOptionLocator.java | 4 +-
.../ConfigOptionsDocsCompletenessITCase.java | 4 +-
flink-models/flink-model-triton/pom.xml | 176 ++++++++
.../model/triton/AbstractTritonModelFunction.java | 325 +++++++++++++++
.../apache/flink/model/triton/TritonDataType.java | 88 ++++
.../model/triton/TritonInferenceModelFunction.java | 456 +++++++++++++++++++++
.../model/triton/TritonModelProviderFactory.java | 94 +++++
.../apache/flink/model/triton/TritonOptions.java | 168 ++++++++
.../flink/model/triton/TritonTypeMapper.java | 317 ++++++++++++++
.../org/apache/flink/model/triton/TritonUtils.java | 157 +++++++
.../triton/exception/TritonClientException.java | 70 ++++
.../model/triton/exception/TritonException.java | 101 +++++
.../triton/exception/TritonNetworkException.java | 59 +++
.../triton/exception/TritonSchemaException.java | 88 ++++
.../triton/exception/TritonServerException.java | 71 ++++
.../src/main/resources/META-INF/NOTICE | 19 +
.../org.apache.flink.table.factories.Factory | 16 +
.../triton/TritonModelProviderFactoryTest.java | 57 +++
.../flink/model/triton/TritonTypeMapperTest.java | 181 ++++++++
flink-models/pom.xml | 1 +
tools/ci/stage.sh | 1 +
22 files changed, 2457 insertions(+), 2 deletions(-)
diff --git a/flink-docs/pom.xml b/flink-docs/pom.xml
index 06f9e572a8a..420da58a19a 100644
--- a/flink-docs/pom.xml
+++ b/flink-docs/pom.xml
@@ -194,6 +194,12 @@ under the License.
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-model-triton</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-sql-gateway</artifactId>
diff --git
a/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java
b/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java
index 0b755272110..95dae776d7d 100644
---
a/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java
+++
b/flink-docs/src/main/java/org/apache/flink/docs/util/ConfigurationOptionLocator.java
@@ -90,7 +90,9 @@ public class ConfigurationOptionLocator {
"flink-external-resources/flink-external-resource-gpu",
"org.apache.flink.externalresource.gpu"),
new OptionsClassLocation(
- "flink-models/flink-model-openai",
"org.apache.flink.model.openai")
+ "flink-models/flink-model-openai",
"org.apache.flink.model.openai"),
+ new OptionsClassLocation(
+ "flink-models/flink-model-triton",
"org.apache.flink.model.triton")
};
private static final Set<String> EXCLUSIONS =
diff --git
a/flink-docs/src/test/java/org/apache/flink/docs/configuration/ConfigOptionsDocsCompletenessITCase.java
b/flink-docs/src/test/java/org/apache/flink/docs/configuration/ConfigOptionsDocsCompletenessITCase.java
index f0060271a02..79831e1e0a8 100644
---
a/flink-docs/src/test/java/org/apache/flink/docs/configuration/ConfigOptionsDocsCompletenessITCase.java
+++
b/flink-docs/src/test/java/org/apache/flink/docs/configuration/ConfigOptionsDocsCompletenessITCase.java
@@ -64,7 +64,9 @@ class ConfigOptionsDocsCompletenessITCase {
new HashSet<>(
Arrays.asList(
"org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions",
-
"org.apache.flink.table.api.config.VectorSearchRuntimeConfigOptions"));
+
"org.apache.flink.table.api.config.VectorSearchRuntimeConfigOptions",
+ "org.apache.flink.model.openai.OpenAIOptions",
+ "org.apache.flink.model.triton.TritonOptions"));
@Test
void testCompleteness() throws Exception {
diff --git a/flink-models/flink-model-triton/pom.xml
b/flink-models/flink-model-triton/pom.xml
new file mode 100644
index 00000000000..580ddfb2f4c
--- /dev/null
+++ b/flink-models/flink-model-triton/pom.xml
@@ -0,0 +1,176 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/maven-v4_0_0.xsd">
+
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-models</artifactId>
+ <version>2.3-SNAPSHOT</version>
+ </parent>
+
+ <artifactId>flink-model-triton</artifactId>
+ <name>Flink : Models : Triton</name>
+
+ <properties>
+ <okhttp.version>4.12.0</okhttp.version>
+ <jackson.version>2.15.2</jackson.version>
+ <test.gson.version>2.11.0</test.gson.version>
+ </properties>
+
+ <dependencies>
+ <!-- HTTP Client for Triton REST API -->
+ <dependency>
+ <groupId>com.squareup.okhttp3</groupId>
+ <artifactId>okhttp</artifactId>
+ <version>${okhttp.version}</version>
+ <optional>${flink.markBundledAsOptional}</optional>
+ </dependency>
+
+ <!-- JSON processing -->
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ <version>${jackson.version}</version>
+ <optional>${flink.markBundledAsOptional}</optional>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${jackson.version}</version>
+ <optional>${flink.markBundledAsOptional}</optional>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ <version>${jackson.version}</version>
+ <optional>${flink.markBundledAsOptional}</optional>
+ </dependency>
+
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-core</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-api-java</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-common</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- test dependencies -->
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+
<artifactId>flink-table-planner_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.squareup.okhttp3</groupId>
+ <artifactId>mockwebserver</artifactId>
+ <version>${okhttp.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-api-java-bridge</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-clients</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.code.gson</groupId>
+ <artifactId>gson</artifactId>
+ <version>${test.gson.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>shade-flink</id>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <artifactSet>
+ <includes>
+
<include>*:*</include>
+ </includes>
+ <excludes>
+
<exclude>com.google.code.findbugs:jsr305</exclude>
+ </excludes>
+ </artifactSet>
+ <relocations
combine.children="append">
+ <relocation>
+
<pattern>com.fasterxml.jackson</pattern>
+
<shadedPattern>org.apache.flink.model.triton.com.fasterxml.jackson</shadedPattern>
+ </relocation>
+ <relocation>
+
<pattern>com.squareup</pattern>
+
<shadedPattern>org.apache.flink.model.triton.com.squareup</shadedPattern>
+ </relocation>
+ </relocations>
+ <filters>
+ <filter>
+
<artifact>*</artifact>
+
<excludes>
+
<exclude>okhttp3/internal/publicsuffix/NOTICE</exclude>
+
</excludes>
+ </filter>
+ </filters>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java
new file mode 100644
index 00000000000..f0e07837090
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/AbstractTritonModelFunction.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.table.catalog.Column;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import okhttp3.OkHttpClient;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Abstract parent class for {@link AsyncPredictFunction}s for Triton
Inference Server API.
+ *
+ * <p>This implementation uses REST-based HTTP communication with Triton
Inference Server. Each
+ * Flink record triggers a separate HTTP request (no Flink-side batching).
Triton's server-side
+ * dynamic batching can aggregate concurrent requests.
+ *
+ * <p><b>HTTP Client Lifecycle:</b> A shared HTTP client pool is maintained
per JVM with reference
+ * counting. Multiple function instances with identical timeout settings share
the same client
+ * instance to avoid resource exhaustion in high-parallelism scenarios.
+ *
+ * <p><b>Current Limitations (v1):</b>
+ *
+ * <ul>
+ * <li>Only single input column and single output column are supported
+ * <li>REST API only; gRPC may be introduced in future versions
+ * </ul>
+ *
+ * <p><b>Future Roadmap:</b> Support for multi-input/multi-output models using
ROW or MAP types, and
+ * native gRPC protocol for improved performance.
+ */
+public abstract class AbstractTritonModelFunction extends AsyncPredictFunction
{
+ private static final Logger LOG =
LoggerFactory.getLogger(AbstractTritonModelFunction.class);
+
+ protected transient OkHttpClient httpClient;
+
+ private final String endpoint;
+ private final String modelName;
+ private final String modelVersion;
+ private final Duration timeout;
+ private final boolean flattenBatchDim;
+ private final Integer priority;
+
+ /**
+ * Sequence ID used by Triton to correlate multiple inference requests
that belong to the same
+ * stateful sequence (e.g. RNN or streaming models).
+ *
+ * <p>See Triton Inference Server sequence batching documentation:
+ *
https://github.com/triton-inference-server/server/blob/main/docs/sequence_batcher.md
+ */
+ private final String sequenceId;
+
+ private final boolean sequenceStart;
+ private final boolean sequenceEnd;
+ private final String compression;
+ private final String authToken;
+ private final Map<String, String> customHeaders;
+
+ public AbstractTritonModelFunction(
+ ModelProviderFactory.Context factoryContext, ReadableConfig
config) {
+ this.endpoint = config.get(TritonOptions.ENDPOINT);
+ this.modelName = config.get(TritonOptions.MODEL_NAME);
+ this.modelVersion = config.get(TritonOptions.MODEL_VERSION);
+ this.timeout = config.get(TritonOptions.TIMEOUT);
+ this.flattenBatchDim = config.get(TritonOptions.FLATTEN_BATCH_DIM);
+ this.priority = config.get(TritonOptions.PRIORITY);
+ this.sequenceId = config.get(TritonOptions.SEQUENCE_ID);
+ this.sequenceStart = config.get(TritonOptions.SEQUENCE_START);
+ this.sequenceEnd = config.get(TritonOptions.SEQUENCE_END);
+ this.compression = config.get(TritonOptions.COMPRESSION);
+ this.authToken = config.get(TritonOptions.AUTH_TOKEN);
+ this.customHeaders = config.get(TritonOptions.CUSTOM_HEADERS);
+
+ // Validate input schema - support multiple types
+
validateInputSchema(factoryContext.getCatalogModel().getResolvedInputSchema());
+ }
+
+ @Override
+ public void open(FunctionContext context) throws Exception {
+ super.open(context);
+ LOG.debug("Creating Triton HTTP client.");
+ this.httpClient = TritonUtils.createHttpClient(timeout.toMillis());
+ }
+
+ @Override
+ public void close() throws Exception {
+ super.close();
+ if (this.httpClient != null) {
+ LOG.debug("Releasing Triton HTTP client.");
+ TritonUtils.releaseHttpClient(this.httpClient);
+ httpClient = null;
+ }
+ }
+
+ /**
+ * Validates the input schema. Subclasses can override for custom
validation.
+ *
+ * @param schema The input schema to validate
+ */
+ protected void validateInputSchema(ResolvedSchema schema) {
+ validateSingleColumnSchema(schema, null, "input");
+ }
+
+ /**
+ * Validates that the schema has exactly one physical column, optionally
checking the type.
+ *
+ * <p><b>Version 1 Limitation:</b> Only single input/single output models
are supported. For
+ * models requiring multiple tensors, consider these workarounds:
+ *
+ * <ul>
+ * <li>Flatten inputs into a JSON STRING and parse server-side
+ * <li>Use ARRAY<T> to pack multiple values
+ * <li>Wait for future ROW<...> support (planned for v2)
+ * </ul>
+ *
+ * @param schema The schema to validate
+ * @param expectedType The expected type, or null to skip type checking
+ * @param inputOrOutput Description of whether this is input or output
schema
+ */
+ protected void validateSingleColumnSchema(
+ ResolvedSchema schema, LogicalType expectedType, String
inputOrOutput) {
+ List<Column> columns = schema.getColumns();
+ if (columns.size() != 1) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Model should have exactly one %s column, but
actually has %s columns: %s. "
+ + "Current version only supports single
input/output. "
+ + "For multi-tensor models, consider using
JSON STRING encoding or ARRAY<T> packing.",
+ inputOrOutput,
+ columns.size(),
+
columns.stream().map(Column::getName).collect(Collectors.toList())));
+ }
+
+ Column column = columns.get(0);
+ if (!column.isPhysical()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s column %s should be a physical column, but is
a %s.",
+ inputOrOutput, column.getName(),
column.getClass()));
+ }
+
+ if (expectedType != null &&
!expectedType.equals(column.getDataType().getLogicalType())) {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s column %s should be %s, but is a %s.",
+ inputOrOutput,
+ column.getName(),
+ expectedType,
+ column.getDataType().getLogicalType()));
+ }
+
+ // Validate that the type is supported by Triton
+ try {
+
TritonTypeMapper.toTritonDataType(column.getDataType().getLogicalType());
+ } catch (IllegalArgumentException e) {
+ String suggestedType =
getSuggestedTypeForTriton(column.getDataType().getLogicalType());
+ throw new IllegalArgumentException(
+ String.format(
+ "%s column %s has unsupported type %s for Triton.
%s%s",
+ inputOrOutput,
+ column.getName(),
+ column.getDataType().getLogicalType(),
+ e.getMessage(),
+ suggestedType.isEmpty() ? "" : "\nSuggestion: " +
suggestedType));
+ }
+
+ // Enhanced validation for type compatibility
+ validateTritonTypeCompatibility(
+ column.getDataType().getLogicalType(), column.getName(),
inputOrOutput);
+ }
+
+ /**
+ * Validates Triton type compatibility with enhanced checks.
+ *
+ * <p>This method performs additional validation beyond basic type support:
+ *
+ * <ul>
+ * <li>Checks for nested arrays (multi-dimensional tensors not supported
in v1)
+ * <li>Warns about STRING to BYTES mapping
+ * <li>Provides structured error messages with troubleshooting hints
+ * </ul>
+ *
+ * @param type The logical type to validate
+ * @param columnName The name of the column
+ * @param inputOrOutput Description of whether this is input or output
+ */
+ private void validateTritonTypeCompatibility(
+ LogicalType type, String columnName, String inputOrOutput) {
+
+ // Check for nested arrays (multi-dimensional tensors)
+ if (type instanceof ArrayType) {
+ ArrayType arrayType = (ArrayType) type;
+ LogicalType elementType = arrayType.getElementType();
+
+ // Reject nested arrays
+ if (elementType instanceof ArrayType) {
+ throw new IllegalArgumentException(
+ String.format(
+ "%s column '%s' has nested array type: %s\n"
+ + "Multi-dimensional tensors
(ARRAY<ARRAY<T>>) are not supported in v1.\n"
+ + "=== Supported Types ===\n"
+ + " • Scalars: INT, BIGINT, FLOAT,
DOUBLE, BOOLEAN, STRING\n"
+ + " • 1-D Arrays: ARRAY<INT>,
ARRAY<FLOAT>, ARRAY<DOUBLE>, etc.\n"
+ + "=== Workarounds ===\n"
+ + " • Flatten to 1-D array:
ARRAY<FLOAT> with size = rows * cols\n"
+ + " • Use JSON STRING encoding for
complex structures\n"
+ + " • Wait for v2+ which will support
ROW<...> types",
+ inputOrOutput, columnName, type));
+ }
+ }
+
+ // Log info about STRING to BYTES mapping
+ if (type instanceof VarCharType) {
+ LOG.info(
+ "{} column '{}' uses STRING type, which will be mapped to
Triton BYTES dtype. "
+ + "Ensure your Triton model expects string/text
inputs.",
+ inputOrOutput,
+ columnName);
+ }
+ }
+
+ /** Provides user-friendly type suggestions for unsupported types. */
+ private String getSuggestedTypeForTriton(LogicalType unsupportedType) {
+ String typeName = unsupportedType.getTypeRoot().name();
+
+ if (typeName.contains("ARRAY") && unsupportedType instanceof
ArrayType) {
+ ArrayType arrayType = (ArrayType) unsupportedType;
+ if (arrayType.getElementType() instanceof ArrayType) {
+ return "Flatten nested array to 1-D: ARRAY<FLOAT> instead of
ARRAY<ARRAY<FLOAT>>";
+ }
+ }
+
+ if (typeName.contains("MAP")) {
+ return "Use ARRAY<T> instead of MAP, or serialize to JSON STRING";
+ } else if (typeName.contains("ROW") || typeName.contains("STRUCT")) {
+ return "Flatten ROW into single column, use ARRAY<T> packing, or
serialize to JSON STRING";
+ } else if (typeName.contains("TIME") || typeName.contains("DATE")) {
+ return "Convert timestamp/date to BIGINT (epoch milliseconds) or
STRING (ISO-8601)";
+ } else if (typeName.contains("DECIMAL")) {
+ return "Use DOUBLE for numeric precision or STRING for exact
decimal representation";
+ } else if (typeName.contains("BINARY") ||
typeName.contains("VARBINARY")) {
+ return "Consider using STRING (VARCHAR) type, which maps to Triton
BYTES";
+ }
+
+ return "";
+ }
+
+ // Getters for configuration values
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected String getModelName() {
+ return modelName;
+ }
+
+ protected String getModelVersion() {
+ return modelVersion;
+ }
+
+ protected Duration getTimeout() {
+ return timeout;
+ }
+
+ protected boolean isFlattenBatchDim() {
+ return flattenBatchDim;
+ }
+
+ protected Integer getPriority() {
+ return priority;
+ }
+
+ protected String getSequenceId() {
+ return sequenceId;
+ }
+
+ protected boolean isSequenceStart() {
+ return sequenceStart;
+ }
+
+ protected boolean isSequenceEnd() {
+ return sequenceEnd;
+ }
+
+ protected String getCompression() {
+ return compression;
+ }
+
+ protected String getAuthToken() {
+ return authToken;
+ }
+
+ protected Map<String, String> getCustomHeaders() {
+ return customHeaders;
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java
new file mode 100644
index 00000000000..451b400f9ff
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonDataType.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
/**
 * Enumeration of data types supported by Triton Inference Server.
 *
 * <p>Each constant carries the exact dtype name used on the wire by the Triton inference
 * protocol. Reference:
 * https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_model_configuration.md
 */
public enum TritonDataType {
    /** Boolean type. */
    BOOL("BOOL"),
    /** 8-bit unsigned integer. */
    UINT8("UINT8"),
    /** 16-bit unsigned integer. */
    UINT16("UINT16"),
    /** 32-bit unsigned integer. */
    UINT32("UINT32"),
    /** 64-bit unsigned integer. */
    UINT64("UINT64"),
    /** 8-bit signed integer. */
    INT8("INT8"),
    /** 16-bit signed integer. */
    INT16("INT16"),
    /** 32-bit signed integer. */
    INT32("INT32"),
    /** 64-bit signed integer. */
    INT64("INT64"),
    /** 16-bit floating point (half precision). */
    FP16("FP16"),
    /** 32-bit floating point (single precision). */
    FP32("FP32"),
    /** 64-bit floating point (double precision). */
    FP64("FP64"),
    /** String/text data. */
    BYTES("BYTES");

    /** Name of this dtype as it appears in Triton protocol messages. */
    private final String wireName;

    TritonDataType(String wireName) {
        this.wireName = wireName;
    }

    /** Returns the Triton protocol name for this data type. */
    public String getTritonName() {
        return wireName;
    }

    /**
     * Resolves a Triton protocol name to its matching {@code TritonDataType}.
     *
     * @param tritonName the dtype name as used by the Triton protocol
     * @return the matching enum constant
     * @throws IllegalArgumentException if no constant uses the given name
     */
    public static TritonDataType fromTritonName(String tritonName) {
        for (TritonDataType candidate : TritonDataType.values()) {
            if (candidate.wireName.equals(tritonName)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Unknown Triton data type: " + tritonName);
    }
}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java
new file mode 100644
index 00000000000..9be17a2f4a6
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonInferenceModelFunction.java
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.model.triton.exception.TritonClientException;
+import org.apache.flink.model.triton.exception.TritonNetworkException;
+import org.apache.flink.model.triton.exception.TritonSchemaException;
+import org.apache.flink.model.triton.exception.TritonServerException;
+import org.apache.flink.table.catalog.Column;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryStringData;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import okhttp3.Call;
+import okhttp3.Callback;
+import okhttp3.MediaType;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.zip.GZIPOutputStream;
+
+/**
+ * {@link AsyncPredictFunction} for Triton Inference Server generic inference
task.
+ *
+ * <p><b>Request Model (v1):</b> This implementation processes records
one-by-one. Each {@link
+ * #asyncPredict(RowData)} call triggers one HTTP request to Triton server.
There is no Flink-side
+ * mini-batch aggregation in the current version.
+ *
+ * <p><b>Batch Efficiency:</b> Inference throughput benefits from:
+ *
+ * <ul>
+ * <li><b>Triton Dynamic Batching</b>: Configure {@code dynamic_batching} in
model's {@code
+ * config.pbtxt} to aggregate concurrent requests server-side
+ * <li><b>Flink Parallelism</b>: High parallelism naturally creates
concurrent requests that
+ * Triton can batch together
+ * <li><b>AsyncDataStream Capacity</b>: Buffer size controls concurrent
in-flight requests,
+ * increasing opportunities for server-side batching
+ * </ul>
+ *
+ * <p><b>Future Roadmap (v2+):</b> Flink-side mini-batch aggregation will be
added to reduce HTTP
+ * overhead (configurable via {@code batch-size} and {@code batch-timeout}
options).
+ *
+ * @see <a
+ *
href="https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher">Triton
+ * Dynamic Batching Documentation</a>
+ */
+public class TritonInferenceModelFunction extends AbstractTritonModelFunction {
+ private static final long serialVersionUID = 1L;
+ private static final Logger LOG =
LoggerFactory.getLogger(TritonInferenceModelFunction.class);
+
+ private static final MediaType JSON_MEDIA_TYPE =
+ MediaType.get("application/json; charset=utf-8");
+ private static final ObjectMapper objectMapper = new ObjectMapper();
+
+ /** Reusable buffer for gzip compression to avoid repeated allocations. */
+ private final ByteArrayOutputStream compressionBuffer = new
ByteArrayOutputStream(1024);
+
+ private final LogicalType inputType;
+ private final LogicalType outputType;
+ private final String inputName;
+ private final String outputName;
+
+ public TritonInferenceModelFunction(
+ ModelProviderFactory.Context factoryContext, ReadableConfig
config) {
+ super(factoryContext, config);
+
+ // Validate and store input/output types
+ validateSingleColumnSchema(
+ factoryContext.getCatalogModel().getResolvedOutputSchema(),
+ null, // Allow any supported type
+ "output");
+
+ // Get input and output column information
+ Column inputColumn =
+
factoryContext.getCatalogModel().getResolvedInputSchema().getColumns().get(0);
+ Column outputColumn =
+
factoryContext.getCatalogModel().getResolvedOutputSchema().getColumns().get(0);
+
+ this.inputType = inputColumn.getDataType().getLogicalType();
+ this.outputType = outputColumn.getDataType().getLogicalType();
+ this.inputName = inputColumn.getName();
+ this.outputName = outputColumn.getName();
+ }
+
+ @Override
+ public CompletableFuture<Collection<RowData>> asyncPredict(RowData
rowData) {
+ CompletableFuture<Collection<RowData>> future = new
CompletableFuture<>();
+
+ try {
+ String requestBody = buildInferenceRequest(rowData);
+ String url =
+ TritonUtils.buildInferenceUrl(getEndpoint(),
getModelName(), getModelVersion());
+
+ Request.Builder requestBuilder = new Request.Builder().url(url);
+
+ // Handle compression and request body
+ if (getCompression() != null) {
+ if ("gzip".equalsIgnoreCase(getCompression())) {
+ // Compress request body with gzip using reusable buffer
+ compressionBuffer.reset();
+ try (GZIPOutputStream gzos = new
GZIPOutputStream(compressionBuffer)) {
+
gzos.write(requestBody.getBytes(StandardCharsets.UTF_8));
+ }
+ byte[] compressedData = compressionBuffer.toByteArray();
+
+ requestBuilder.addHeader("Content-Encoding", "gzip");
+ requestBuilder.post(RequestBody.create(compressedData,
JSON_MEDIA_TYPE));
+ } else {
+ throw new IllegalArgumentException(
+ String.format(
+ "Unsupported compression algorithm: '%s'.
Currently only 'gzip' is supported.",
+ getCompression()));
+ }
+ } else {
+ requestBuilder.post(RequestBody.create(requestBody,
JSON_MEDIA_TYPE));
+ }
+
+ // Add authentication header if provided
+ if (getAuthToken() != null) {
+ requestBuilder.addHeader("Authorization", "Bearer " +
getAuthToken());
+ }
+
+ // Add custom headers if provided
+ if (getCustomHeaders() != null && !getCustomHeaders().isEmpty()) {
+ getCustomHeaders().forEach((key, value) ->
requestBuilder.addHeader(key, value));
+ }
+
+ Request request = requestBuilder.build();
+
+ httpClient
+ .newCall(request)
+ .enqueue(
+ new Callback() {
+ @Override
+ public void onFailure(Call call, IOException
e) {
+ LOG.error(
+ "Triton inference request failed
due to network error",
+ e);
+
+ // Wrap IOException in
TritonNetworkException
+ TritonNetworkException networkException =
+ new TritonNetworkException(
+ String.format(
+ "Failed to connect
to Triton server at %s: %s. "
+ + "This
may indicate network connectivity issues, DNS resolution failure, or server
unavailability.",
+ url,
e.getMessage()),
+ e);
+
+
future.completeExceptionally(networkException);
+ }
+
+ @Override
+ public void onResponse(Call call, Response
response)
+ throws IOException {
+ try {
+ if (!response.isSuccessful()) {
+ handleErrorResponse(response,
future);
+ return;
+ }
+
+ String responseBody =
response.body().string();
+ Collection<RowData> result =
+
parseInferenceResponse(responseBody);
+ future.complete(result);
+ } catch (JsonProcessingException e) {
+ LOG.error("Failed to parse Triton
inference response", e);
+ future.completeExceptionally(
+ new TritonClientException(
+ "Failed to parse
Triton response JSON: "
+ +
e.getMessage()
+ + ". This may
indicate an incompatible response format.",
+ 400));
+ } catch (Exception e) {
+ LOG.error("Failed to process Triton
inference response", e);
+ future.completeExceptionally(e);
+ } finally {
+ response.close();
+ }
+ }
+ });
+
+ } catch (Exception e) {
+ LOG.error("Failed to build Triton inference request", e);
+ future.completeExceptionally(e);
+ }
+
+ return future;
+ }
+
    /**
     * Handles HTTP error responses and creates appropriate typed exceptions.
     *
     * <p>Status-code mapping: 4xx completes the future with a {@link TritonClientException}
     * (or a {@link TritonSchemaException} when the 400 body mentions a shape/dimension
     * mismatch), 5xx completes it with a {@link TritonServerException}, and any other code
     * falls back to {@link TritonClientException}. The exception message always carries the
     * request configuration plus code-specific troubleshooting hints.
     *
     * @param response The HTTP response with error status
     * @param future The future to complete exceptionally
     * @throws IOException If reading response body fails
     */
    private void handleErrorResponse(
            Response response, CompletableFuture<Collection<RowData>> future) throws IOException {

        String errorBody =
                response.body() != null ? response.body().string() : "No error details provided";
        int statusCode = response.code();

        // Build detailed error message with context
        StringBuilder errorMsg = new StringBuilder();
        errorMsg.append(
                String.format("Triton inference failed with HTTP %d: %s\n", statusCode, errorBody));
        errorMsg.append("\n=== Request Configuration ===\n");
        errorMsg.append(
                String.format(" Model: %s (version: %s)\n", getModelName(), getModelVersion()));
        errorMsg.append(String.format(" Endpoint: %s\n", getEndpoint()));
        errorMsg.append(String.format(" Input column: %s\n", inputName));
        errorMsg.append(String.format(" Input Flink type: %s\n", inputType));
        errorMsg.append(
                String.format(
                        " Input Triton dtype: %s\n",
                        TritonTypeMapper.toTritonDataType(inputType).getTritonName()));

        // Check if this is a shape mismatch error.
        // NOTE(review): toLowerCase() uses the JVM default locale — consider Locale.ROOT for
        // locale-stable matching.
        boolean isShapeMismatch =
                errorBody.toLowerCase().contains("shape")
                        || errorBody.toLowerCase().contains("dimension");

        if (statusCode >= 400 && statusCode < 500) {
            // Client error - user configuration issue
            errorMsg.append("\n=== Troubleshooting (Client Error) ===\n");

            if (statusCode == 400) {
                errorMsg.append(" • Verify input shape matches model's config.pbtxt\n");
                errorMsg.append(" • For scalar: use INT/FLOAT/DOUBLE/STRING\n");
                errorMsg.append(" • For 1-D tensor: use ARRAY<type>\n");
                errorMsg.append(
                        " • Try flatten-batch-dim=true if model expects [N] but gets [1,N]\n");

                if (isShapeMismatch) {
                    // Create schema exception for shape mismatches; returns early so the
                    // generic client exception below is not also raised.
                    future.completeExceptionally(
                            new TritonSchemaException(
                                    errorMsg.toString(),
                                    "See Triton model config.pbtxt",
                                    String.format("Flink type: %s", inputType)));
                    return;
                }
            } else if (statusCode == 404) {
                errorMsg.append(" • Verify model-name: ").append(getModelName()).append("\n");
                errorMsg.append(" • Verify model-version: ")
                        .append(getModelVersion())
                        .append("\n");
                errorMsg.append(" • Check model is loaded: GET ")
                        .append(getEndpoint())
                        .append("\n");
            } else if (statusCode == 401 || statusCode == 403) {
                errorMsg.append(" • Check auth-token configuration\n");
                errorMsg.append(" • Verify server authentication requirements\n");
            }

            future.completeExceptionally(
                    new TritonClientException(errorMsg.toString(), statusCode));

        } else if (statusCode >= 500 && statusCode < 600) {
            // Server error - Triton service issue
            errorMsg.append("\n=== Troubleshooting (Server Error) ===\n");

            if (statusCode == 500) {
                errorMsg.append(" • Check Triton server logs for inference crash details\n");
                errorMsg.append(" • Model may have run out of memory\n");
                errorMsg.append(" • Input data may trigger model bug\n");
            } else if (statusCode == 503) {
                errorMsg.append(" • Server is overloaded or unavailable\n");
                errorMsg.append(" • This error is retryable with backoff\n");
                errorMsg.append(" • Consider scaling Triton server resources\n");
            } else if (statusCode == 504) {
                errorMsg.append(" • Inference exceeded gateway timeout\n");
                errorMsg.append(" • This error is retryable\n");
                errorMsg.append(" • Consider increasing timeout configuration\n");
            }

            future.completeExceptionally(
                    new TritonServerException(errorMsg.toString(), statusCode));

        } else {
            // Unexpected status code (neither 4xx nor 5xx) — treat as a client-side problem.
            errorMsg.append("\n=== Unexpected Status Code ===\n");
            errorMsg.append(" • This status code is not standard for Triton\n");
            errorMsg.append(" • Check if proxy/load balancer is involved\n");

            future.completeExceptionally(
                    new TritonClientException(errorMsg.toString(), statusCode));
        }
    }
+
+ private String buildInferenceRequest(RowData rowData) throws
JsonProcessingException {
+ ObjectNode requestNode = objectMapper.createObjectNode();
+
+ // Add request ID if sequence ID is provided
+ if (getSequenceId() != null) {
+ requestNode.put("id", getSequenceId());
+ }
+
+ // Add parameters
+ ObjectNode parametersNode = objectMapper.createObjectNode();
+ if (getPriority() != null) {
+ parametersNode.put("priority", getPriority());
+ }
+ if (isSequenceStart()) {
+ parametersNode.put("sequence_start", true);
+ }
+ if (isSequenceEnd()) {
+ parametersNode.put("sequence_end", true);
+ }
+ if (parametersNode.size() > 0) {
+ requestNode.set("parameters", parametersNode);
+ }
+
+ // Add inputs
+ ArrayNode inputsArray = objectMapper.createArrayNode();
+ ObjectNode inputNode = objectMapper.createObjectNode();
+ inputNode.put("name", inputName.toUpperCase());
+
+ // Map Flink type to Triton type
+ TritonDataType tritonType =
TritonTypeMapper.toTritonDataType(inputType);
+ inputNode.put("datatype", tritonType.getTritonName());
+
+ // Serialize input data first to get actual size
+ ArrayNode dataArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(rowData, 0, inputType,
dataArray);
+
+ // Calculate and add shape based on actual data
+ int[] shape = TritonTypeMapper.calculateShape(inputType, 1, rowData,
0);
+
+ // Apply flatten-batch-dim if configured
+ if (isFlattenBatchDim() && shape.length > 1 && shape[0] == 1) {
+ // Remove the batch dimension: [1, N] -> [N]
+ int[] flattenedShape = new int[shape.length - 1];
+ System.arraycopy(shape, 1, flattenedShape, 0,
flattenedShape.length);
+ shape = flattenedShape;
+ }
+
+ ArrayNode shapeArray = objectMapper.createArrayNode();
+ for (int dim : shape) {
+ shapeArray.add(dim);
+ }
+ inputNode.set("shape", shapeArray);
+ inputNode.set("data", dataArray);
+
+ inputsArray.add(inputNode);
+ requestNode.set("inputs", inputsArray);
+
+ // Add outputs (request all outputs)
+ ArrayNode outputsArray = objectMapper.createArrayNode();
+ ObjectNode outputNode = objectMapper.createObjectNode();
+ outputNode.put("name", outputName.toUpperCase());
+ outputsArray.add(outputNode);
+ requestNode.set("outputs", outputsArray);
+
+ String requestJson = objectMapper.writeValueAsString(requestNode);
+
+ // Log the request for debugging
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(
+ "Triton inference request - Model: {}, Version: {}, Input:
{}, Shape: {}",
+ getModelName(),
+ getModelVersion(),
+ inputName,
+ java.util.Arrays.toString(shape));
+ LOG.debug("Request body: {}", requestJson);
+ }
+
+ return requestJson;
+ }
+
+ private Collection<RowData> parseInferenceResponse(String responseBody)
+ throws JsonProcessingException {
+ JsonNode responseNode = objectMapper.readTree(responseBody);
+ List<RowData> results = new ArrayList<>();
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Triton response body: {}", responseBody);
+ }
+
+ JsonNode outputsNode = responseNode.get("outputs");
+ if (outputsNode != null && outputsNode.isArray()) {
+ for (JsonNode outputNode : outputsNode) {
+ JsonNode dataNode = outputNode.get("data");
+
+ if (dataNode != null && dataNode.isArray()) {
+ if (dataNode.size() > 0) {
+ // Check if output is array type or scalar
+ // If outputType is scalar but dataNode is array,
extract first element
+ JsonNode nodeToDeserialize = dataNode;
+ if (!(outputType instanceof ArrayType) &&
dataNode.isArray()) {
+ // Scalar type - extract first element from array
+ nodeToDeserialize = dataNode.get(0);
+ }
+
+ Object deserializedData =
+
TritonTypeMapper.deserializeFromJson(nodeToDeserialize, outputType);
+
+ results.add(GenericRowData.of(deserializedData));
+ }
+ }
+ }
+ } else {
+ LOG.warn("No outputs found in Triton response");
+ }
+
+ // If no outputs found, return default value based on type
+ if (results.isEmpty()) {
+ Object defaultValue;
+ if (outputType instanceof VarCharType) {
+ defaultValue = BinaryStringData.EMPTY_UTF8;
+ } else {
+ defaultValue = null;
+ }
+ results.add(GenericRowData.of(defaultValue));
+ }
+
+ return results;
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java
new file mode 100644
index 00000000000..516ec899b76
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonModelProviderFactory.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.factories.ModelProviderFactory;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
+import org.apache.flink.table.ml.ModelProvider;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/** {@link ModelProviderFactory} for Triton Inference Server model functions.
*/
+public class TritonModelProviderFactory implements ModelProviderFactory {
+ public static final String IDENTIFIER = "triton";
+
+ @Override
+ public ModelProvider createModelProvider(ModelProviderFactory.Context
context) {
+ FactoryUtil.ModelProviderFactoryHelper helper =
+ FactoryUtil.createModelProviderFactoryHelper(this, context);
+ helper.validate();
+
+ // For now, we create a generic inference function
+ // In the future, this could be extended to support different model
types
+ AsyncPredictFunction function =
+ new TritonInferenceModelFunction(context, helper.getOptions());
+ return new Provider(function);
+ }
+
+ @Override
+ public String factoryIdentifier() {
+ return IDENTIFIER;
+ }
+
+ @Override
+ public Set<ConfigOption<?>> requiredOptions() {
+ Set<ConfigOption<?>> set = new HashSet<>();
+ set.add(TritonOptions.ENDPOINT);
+ set.add(TritonOptions.MODEL_NAME);
+ return set;
+ }
+
+ @Override
+ public Set<ConfigOption<?>> optionalOptions() {
+ Set<ConfigOption<?>> set = new HashSet<>();
+ set.add(TritonOptions.MODEL_VERSION);
+ set.add(TritonOptions.TIMEOUT);
+ set.add(TritonOptions.FLATTEN_BATCH_DIM);
+ set.add(TritonOptions.PRIORITY);
+ set.add(TritonOptions.SEQUENCE_ID);
+ set.add(TritonOptions.SEQUENCE_START);
+ set.add(TritonOptions.SEQUENCE_END);
+ set.add(TritonOptions.COMPRESSION);
+ set.add(TritonOptions.AUTH_TOKEN);
+ set.add(TritonOptions.CUSTOM_HEADERS);
+ return set;
+ }
+
+ /** {@link ModelProvider} for Triton model functions. */
+ public static class Provider implements AsyncPredictRuntimeProvider {
+ private final AsyncPredictFunction function;
+
+ public Provider(AsyncPredictFunction function) {
+ this.function = function;
+ }
+
+ @Override
+ public AsyncPredictFunction createAsyncPredictFunction(Context
context) {
+ return function;
+ }
+
+ @Override
+ public ModelProvider copy() {
+ return new Provider(function);
+ }
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonOptions.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonOptions.java
new file mode 100644
index 00000000000..f5b920776e4
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonOptions.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.annotation.docs.Documentation;
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.description.Description;
+import org.apache.flink.configuration.description.LinkElement;
+
+import java.time.Duration;
+import java.util.Map;
+
+import static org.apache.flink.configuration.description.TextElement.code;
+
/**
 * Configuration options for Triton Inference Server model functions.
 *
 * <p>Documentation for these options will be added in a separate PR.
 */
@Documentation.ExcludeFromDocumentation(
        "Documentation for Triton options will be added in a separate PR")
public class TritonOptions {

    private TritonOptions() {
        // Utility class with static options only
    }

    /** Full URL of the Triton server endpoint (required). */
    public static final ConfigOption<String> ENDPOINT =
            ConfigOptions.key("endpoint")
                    .stringType()
                    .noDefaultValue()
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Full URL of the Triton Inference Server endpoint, e.g., %s. "
                                                    + "Both HTTP and HTTPS are supported; HTTPS is recommended for production.",
                                            code("https://triton-server:8000/v2/models"))
                                    .build());

    /** Name of the model to invoke (required). */
    public static final ConfigOption<String> MODEL_NAME =
            ConfigOptions.key("model-name")
                    .stringType()
                    .noDefaultValue()
                    .withDescription("Name of the model to invoke on Triton server.");

    /** Model version to use; defaults to 'latest'. */
    public static final ConfigOption<String> MODEL_VERSION =
            ConfigOptions.key("model-version")
                    .stringType()
                    .defaultValue("latest")
                    .withDescription("Version of the model to use. Defaults to 'latest'.");

    /** Per-request HTTP timeout, independent of Flink's async operator timeout. */
    public static final ConfigOption<Duration> TIMEOUT =
            ConfigOptions.key("timeout")
                    .durationType()
                    .defaultValue(Duration.ofSeconds(30))
                    .withDescription(
                            "HTTP request timeout (connect + read + write). "
                                    + "This applies per individual request and is separate from Flink's async timeout. "
                                    + "Defaults to 30 seconds.");

    /** Whether to drop a leading batch dimension of size 1 from array input shapes. */
    public static final ConfigOption<Boolean> FLATTEN_BATCH_DIM =
            ConfigOptions.key("flatten-batch-dim")
                    .booleanType()
                    .defaultValue(false)
                    .withDescription(
                            "Whether to flatten the batch dimension for array inputs. "
                                    + "When true, shape [1,N] becomes [N]. Defaults to false.");

    /** Optional request priority sent in the inference request parameters. */
    public static final ConfigOption<Integer> PRIORITY =
            ConfigOptions.key("priority")
                    .intType()
                    .noDefaultValue()
                    .withDescription(
                            "Request priority level (0-255). Higher values indicate higher priority.");

    /** Sequence correlation ID for stateful (e.g. RNN/LSTM) models. */
    public static final ConfigOption<String> SEQUENCE_ID =
            ConfigOptions.key("sequence-id")
                    .stringType()
                    .noDefaultValue()
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Sequence ID for stateful models. A sequence represents a series of "
                                                    + "inference requests that must be routed to the same model instance "
                                                    + "to maintain state across requests (e.g., for RNN/LSTM models). "
                                                    + "See %s for more details.",
                                            LinkElement.link(
                                                    "https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/architecture.html#stateful-models",
                                                    "Triton Stateful Models"))
                                    .build());

    /** Marks the request as the first of a sequence for stateful models. */
    public static final ConfigOption<Boolean> SEQUENCE_START =
            ConfigOptions.key("sequence-start")
                    .booleanType()
                    .defaultValue(false)
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Whether this request marks the start of a new sequence for stateful models. "
                                                    + "When true, Triton will initialize the model's state before processing this request. "
                                                    + "See %s for more details.",
                                            LinkElement.link(
                                                    "https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/architecture.html#stateful-models",
                                                    "Triton Stateful Models"))
                                    .build());

    /** Marks the request as the last of a sequence for stateful models. */
    public static final ConfigOption<Boolean> SEQUENCE_END =
            ConfigOptions.key("sequence-end")
                    .booleanType()
                    .defaultValue(false)
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Whether this request marks the end of a sequence for stateful models. "
                                                    + "When true, Triton will release the model's state after processing this request. "
                                                    + "See %s for more details.",
                                            LinkElement.link(
                                                    "https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/architecture.html#stateful-models",
                                                    "Triton Stateful Models"))
                                    .build());

    /** Request-body compression algorithm; only gzip is currently supported. */
    public static final ConfigOption<String> COMPRESSION =
            ConfigOptions.key("compression")
                    .stringType()
                    .noDefaultValue()
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Compression algorithm for request body. Currently only %s is supported. "
                                                    + "When enabled, the request body will be compressed to reduce network bandwidth.",
                                            code("gzip"))
                                    .build());

    /** Bearer token added as an Authorization header for secured servers. */
    public static final ConfigOption<String> AUTH_TOKEN =
            ConfigOptions.key("auth-token")
                    .stringType()
                    .noDefaultValue()
                    .withDescription("Authentication token for secured Triton servers.");

    /** Additional HTTP headers attached to every inference request. */
    public static final ConfigOption<Map<String, String>> CUSTOM_HEADERS =
            ConfigOptions.key("custom-headers")
                    .mapType()
                    .noDefaultValue()
                    .withDescription(
                            Description.builder()
                                    .text(
                                            "Custom HTTP headers as key-value pairs. "
                                                    + "Example: %s",
                                            code("'X-Custom-Header:value,X-Another:value2'"))
                                    .build());
}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java
new file mode 100644
index 00000000000..49eab681f94
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonTypeMapper.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryStringData;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.BooleanType;
+import org.apache.flink.table.types.logical.DoubleType;
+import org.apache.flink.table.types.logical.FloatType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.SmallIntType;
+import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+
+/** Utility class for mapping between Flink logical types and Triton data
types. */
+public class TritonTypeMapper {
+
+ /**
+ * Maps a Flink LogicalType to the corresponding Triton data type.
+ *
+ * @param logicalType The Flink logical type
+ * @return The corresponding Triton data type
+ * @throws IllegalArgumentException if the type is not supported
+ */
+ public static TritonDataType toTritonDataType(LogicalType logicalType) {
+ if (logicalType instanceof BooleanType) {
+ return TritonDataType.BOOL;
+ } else if (logicalType instanceof TinyIntType) {
+ return TritonDataType.INT8;
+ } else if (logicalType instanceof SmallIntType) {
+ return TritonDataType.INT16;
+ } else if (logicalType instanceof IntType) {
+ return TritonDataType.INT32;
+ } else if (logicalType instanceof BigIntType) {
+ return TritonDataType.INT64;
+ } else if (logicalType instanceof FloatType) {
+ return TritonDataType.FP32;
+ } else if (logicalType instanceof DoubleType) {
+ return TritonDataType.FP64;
+ } else if (logicalType instanceof VarCharType) {
+ return TritonDataType.BYTES;
+ } else if (logicalType instanceof ArrayType) {
+ // For arrays, we map the element type
+ ArrayType arrayType = (ArrayType) logicalType;
+ return toTritonDataType(arrayType.getElementType());
+ } else {
+ throw new IllegalArgumentException("Unsupported Flink type for
Triton: " + logicalType);
+ }
+ }
+
+ /**
+ * Serializes Flink RowData field value to JSON array for Triton request.
+ *
+ * @param rowData The row data
+ * @param fieldIndex The field index
+ * @param logicalType The logical type of the field
+ * @param dataArray The JSON array to add data to
+ */
+ public static void serializeToJsonArray(
+ RowData rowData, int fieldIndex, LogicalType logicalType,
ArrayNode dataArray) {
+ if (rowData.isNullAt(fieldIndex)) {
+ dataArray.addNull();
+ return;
+ }
+
+ if (logicalType instanceof BooleanType) {
+ dataArray.add(rowData.getBoolean(fieldIndex));
+ } else if (logicalType instanceof TinyIntType) {
+ dataArray.add(rowData.getByte(fieldIndex));
+ } else if (logicalType instanceof SmallIntType) {
+ dataArray.add(rowData.getShort(fieldIndex));
+ } else if (logicalType instanceof IntType) {
+ dataArray.add(rowData.getInt(fieldIndex));
+ } else if (logicalType instanceof BigIntType) {
+ dataArray.add(rowData.getLong(fieldIndex));
+ } else if (logicalType instanceof FloatType) {
+ dataArray.add(rowData.getFloat(fieldIndex));
+ } else if (logicalType instanceof DoubleType) {
+ dataArray.add(rowData.getDouble(fieldIndex));
+ } else if (logicalType instanceof VarCharType) {
+ dataArray.add(rowData.getString(fieldIndex).toString());
+ } else if (logicalType instanceof ArrayType) {
+ ArrayType arrayType = (ArrayType) logicalType;
+ ArrayData arrayData = rowData.getArray(fieldIndex);
+ serializeArrayToJsonArray(arrayData, arrayType.getElementType(),
dataArray);
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported Flink type for serialization: " +
logicalType);
+ }
+ }
+
+ /**
+ * Serializes Flink ArrayData to JSON array (flattened).
+ *
+ * @param arrayData The array data
+ * @param elementType The element type
+ * @param targetArray The JSON array to add data to
+ */
+ private static void serializeArrayToJsonArray(
+ ArrayData arrayData, LogicalType elementType, ArrayNode
targetArray) {
+ int size = arrayData.size();
+ for (int i = 0; i < size; i++) {
+ if (arrayData.isNullAt(i)) {
+ targetArray.addNull();
+ continue;
+ }
+
+ if (elementType instanceof BooleanType) {
+ targetArray.add(arrayData.getBoolean(i));
+ } else if (elementType instanceof TinyIntType) {
+ targetArray.add(arrayData.getByte(i));
+ } else if (elementType instanceof SmallIntType) {
+ targetArray.add(arrayData.getShort(i));
+ } else if (elementType instanceof IntType) {
+ targetArray.add(arrayData.getInt(i));
+ } else if (elementType instanceof BigIntType) {
+ targetArray.add(arrayData.getLong(i));
+ } else if (elementType instanceof FloatType) {
+ targetArray.add(arrayData.getFloat(i));
+ } else if (elementType instanceof DoubleType) {
+ targetArray.add(arrayData.getDouble(i));
+ } else if (elementType instanceof VarCharType) {
+ targetArray.add(arrayData.getString(i).toString());
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported array element type: " + elementType);
+ }
+ }
+ }
+
+ /**
+ * Deserializes JSON data to Flink object based on logical type.
+ *
+ * @param dataNode The JSON node containing the data
+ * @param logicalType The target logical type
+ * @return The deserialized object
+ */
+ public static Object deserializeFromJson(JsonNode dataNode, LogicalType
logicalType) {
+ if (dataNode == null || dataNode.isNull()) {
+ return null;
+ }
+
+ if (logicalType instanceof BooleanType) {
+ return dataNode.asBoolean();
+ } else if (logicalType instanceof TinyIntType) {
+ return (byte) dataNode.asInt();
+ } else if (logicalType instanceof SmallIntType) {
+ return (short) dataNode.asInt();
+ } else if (logicalType instanceof IntType) {
+ return dataNode.asInt();
+ } else if (logicalType instanceof BigIntType) {
+ return dataNode.asLong();
+ } else if (logicalType instanceof FloatType) {
+ // Use floatValue() to properly handle the conversion
+ if (dataNode.isNumber()) {
+ return dataNode.floatValue();
+ } else {
+ return (float) dataNode.asDouble();
+ }
+ } else if (logicalType instanceof DoubleType) {
+ return dataNode.asDouble();
+ } else if (logicalType instanceof VarCharType) {
+ return BinaryStringData.fromString(dataNode.asText());
+ } else if (logicalType instanceof ArrayType) {
+ ArrayType arrayType = (ArrayType) logicalType;
+ return deserializeArrayFromJson(dataNode,
arrayType.getElementType());
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported Flink type for deserialization: " +
logicalType);
+ }
+ }
+
+ /**
+ * Deserializes JSON array to Flink ArrayData.
+ *
+ * @param dataNode The JSON array node
+ * @param elementType The element type
+ * @return The deserialized ArrayData
+ */
+ private static ArrayData deserializeArrayFromJson(JsonNode dataNode,
LogicalType elementType) {
+ if (!dataNode.isArray()) {
+ throw new IllegalArgumentException(
+ "Expected JSON array but got: " + dataNode.getNodeType());
+ }
+
+ int size = dataNode.size();
+
+ // Handle different element types with appropriate array types
+ if (elementType instanceof BooleanType) {
+ boolean[] array = new boolean[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = element.asBoolean();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof TinyIntType) {
+ byte[] array = new byte[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = (byte) element.asInt();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof SmallIntType) {
+ short[] array = new short[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = (short) element.asInt();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof IntType) {
+ int[] array = new int[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = element.asInt();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof BigIntType) {
+ long[] array = new long[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = element.asLong();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof FloatType) {
+ float[] array = new float[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = element.isNumber() ? element.floatValue() :
(float) element.asDouble();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof DoubleType) {
+ double[] array = new double[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = element.asDouble();
+ }
+ return new GenericArrayData(array);
+ } else if (elementType instanceof VarCharType) {
+ BinaryStringData[] array = new BinaryStringData[size];
+ int i = 0;
+ for (JsonNode element : dataNode) {
+ array[i++] = BinaryStringData.fromString(element.asText());
+ }
+ return new GenericArrayData(array);
+ } else {
+ throw new IllegalArgumentException("Unsupported array element
type: " + elementType);
+ }
+ }
+
+ /**
+ * Calculates the shape dimensions for the input data.
+ *
+ * @param logicalType The logical type
+ * @param batchSize The batch size
+ * @return Array of dimensions
+ */
+ public static int[] calculateShape(LogicalType logicalType, int batchSize)
{
+ if (logicalType instanceof ArrayType) {
+ // For arrays, we need to know the array size at runtime
+ // Return shape with batch size and -1 for dynamic dimension
+ return new int[] {batchSize, -1};
+ } else {
+ // For scalar types, shape is just the batch size
+ return new int[] {batchSize};
+ }
+ }
+
+ /**
+ * Calculates the shape dimensions for the input data based on actual row
data.
+ *
+ * @param logicalType The logical type
+ * @param batchSize The batch size
+ * @param rowData The actual row data
+ * @param fieldIndex The field index in the row
+ * @return Array of dimensions
+ */
+ public static int[] calculateShape(
+ LogicalType logicalType, int batchSize, RowData rowData, int
fieldIndex) {
+ if (logicalType instanceof ArrayType) {
+ // For arrays, calculate actual size from the data
+ if (rowData.isNullAt(fieldIndex)) {
+ // Null array - return shape [batchSize, 0]
+ return new int[] {batchSize, 0};
+ }
+ ArrayData arrayData = rowData.getArray(fieldIndex);
+ int arraySize = arrayData.size();
+ return new int[] {batchSize, arraySize};
+ } else {
+ // For scalar types, shape is just the batch size
+ return new int[] {batchSize};
+ }
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java
new file mode 100644
index 00000000000..f94aaf9d6ce
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/TritonUtils.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import okhttp3.OkHttpClient;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Utility class for Triton Inference Server HTTP client management.
+ *
+ * <p>This class implements a reference-counted singleton pattern for
OkHttpClient instances.
+ * Multiple function instances sharing the same timeout configuration will
reuse the same client,
+ * reducing resource consumption in high-parallelism scenarios.
+ *
+ * <p><b>Resource Management:</b>
+ *
+ * <ul>
+ * <li>Clients are cached by timeout key
+ * <li>Reference count tracks active users
+ * <li>Client is closed when reference count reaches zero
+ * <li>Thread-safe via synchronized blocks
+ * </ul>
+ *
+ * <p><b>URL Construction:</b> The {@link #buildInferenceUrl} method
normalizes endpoint URLs to
+ * conform to Triton's REST API specification: {@code
/v2/models/{name}/versions/{version}/infer}
+ */
+public class TritonUtils {
+ private static final Logger LOG =
LoggerFactory.getLogger(TritonUtils.class);
+
+ private static final Object LOCK = new Object();
+
+ private static final Map<Long, ClientValue> cache = new HashMap<>();
+
+ /**
+ * Creates or retrieves a cached HTTP client with the specified
configuration.
+ *
+ * <p>This method implements reference-counted client pooling. Clients
with identical timeout
+ * settings are shared across multiple callers.
+ *
+ * @param timeoutMs Timeout in milliseconds for connect, read, and write
operations
+ * @return A shared or new OkHttpClient instance
+ */
+ public static OkHttpClient createHttpClient(long timeoutMs) {
+ synchronized (LOCK) {
+ ClientValue value = cache.get(timeoutMs);
+ if (value != null) {
+ LOG.debug("Returning an existing Triton HTTP client.");
+ value.referenceCount.incrementAndGet();
+ return value.client;
+ }
+
+ LOG.debug("Building a new Triton HTTP client.");
+ OkHttpClient client =
+ new OkHttpClient.Builder()
+ .connectTimeout(timeoutMs, TimeUnit.MILLISECONDS)
+ .readTimeout(timeoutMs, TimeUnit.MILLISECONDS)
+ .writeTimeout(timeoutMs, TimeUnit.MILLISECONDS)
+ .retryOnConnectionFailure(true)
+ .build();
+
+ cache.put(timeoutMs, new ClientValue(client));
+ return client;
+ }
+ }
+
+ /**
+ * Releases a reference to an HTTP client. When the reference count
reaches zero, the client is
+ * closed and removed from the cache.
+ *
+ * @param client The client to release
+ */
+ public static void releaseHttpClient(OkHttpClient client) {
+ synchronized (LOCK) {
+ Long keyToRemove = null;
+ ClientValue valueToRemove = null;
+
+ for (Map.Entry<Long, ClientValue> entry : cache.entrySet()) {
+ if (entry.getValue().client == client) {
+ keyToRemove = entry.getKey();
+ valueToRemove = entry.getValue();
+ break;
+ }
+ }
+
+ if (valueToRemove != null) {
+ int count = valueToRemove.referenceCount.decrementAndGet();
+ if (count == 0) {
+ LOG.debug("Closing the Triton HTTP client.");
+ cache.remove(keyToRemove);
+ // OkHttpClient doesn't need explicit closing, but we can
clean up resources
+ client.dispatcher().executorService().shutdown();
+ client.connectionPool().evictAll();
+ }
+ }
+ }
+ }
+
+ /**
+ * Builds the inference URL for a specific model and version.
+ *
+ * <p>This method normalizes various endpoint formats to the standard
Triton REST API path:
+ *
+ * <pre>
+ * Input: http://localhost:8000 →
http://localhost:8000/v2/models/mymodel/versions/1/infer
+ * Input: http://localhost:8000/v2 →
http://localhost:8000/v2/models/mymodel/versions/1/infer
+ * Input: http://localhost:8000/v2/models →
http://localhost:8000/v2/models/mymodel/versions/1/infer
+ * </pre>
+ *
+ * @param endpoint The base URL or partial URL of the Triton server
+ * @param modelName The name of the model
+ * @param modelVersion The version of the model (e.g., "1", "latest")
+ * @return The complete inference endpoint URL
+ */
+ public static String buildInferenceUrl(String endpoint, String modelName,
String modelVersion) {
+ String baseUrl = endpoint.replaceAll("/*$", "");
+ if (!baseUrl.endsWith("/v2/models")) {
+ if (baseUrl.endsWith("/v2")) {
+ baseUrl += "/models";
+ } else {
+ baseUrl += "/v2/models";
+ }
+ }
+ return String.format("%s/%s/versions/%s/infer", baseUrl, modelName,
modelVersion);
+ }
+
+ private static class ClientValue {
+ private final OkHttpClient client;
+ private final AtomicInteger referenceCount;
+
+ private ClientValue(OkHttpClient client) {
+ this.client = client;
+ this.referenceCount = new AtomicInteger(1);
+ }
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java
new file mode 100644
index 00000000000..702c4b1e3d9
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonClientException.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton.exception;
+
+/**
+ * Exception for client-side errors (HTTP 4xx status codes).
+ *
+ * <p>Indicates user configuration or input data issues that should be fixed
by the user. These
+ * errors are NOT retryable as they require configuration changes.
+ *
+ * <p><b>Common Scenarios:</b>
+ *
+ * <ul>
+ * <li>400 Bad Request: Invalid input shape or data format
+ * <li>404 Not Found: Model name or version doesn't exist
+ * <li>401 Unauthorized: Invalid authentication token
+ * </ul>
+ */
+public class TritonClientException extends TritonException {
+ private static final long serialVersionUID = 1L;
+
+ private final int httpStatus;
+
+ /**
+ * Creates a new client exception.
+ *
+ * @param message The detailed error message
+ * @param httpStatus The HTTP status code (4xx)
+ */
+ public TritonClientException(String message, int httpStatus) {
+ super(String.format("[HTTP %d] %s", httpStatus, message));
+ this.httpStatus = httpStatus;
+ }
+
+ /**
+ * Returns the HTTP status code.
+ *
+ * @return The HTTP status code (4xx)
+ */
+ public int getHttpStatus() {
+ return httpStatus;
+ }
+
+ @Override
+ public boolean isRetryable() {
+ // Client errors require configuration fixes, not retries
+ return false;
+ }
+
+ @Override
+ public ErrorCategory getCategory() {
+ return ErrorCategory.CLIENT_ERROR;
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java
new file mode 100644
index 00000000000..73870113b01
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonException.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton.exception;
+
/**
 * Base exception for all Triton Inference Server integration errors.
 *
 * <p>This exception hierarchy provides typed error handling for different failure scenarios:
 *
 * <ul>
 *   <li>{@link TritonClientException}: HTTP 4xx errors (user configuration issues)
 *   <li>{@link TritonServerException}: HTTP 5xx errors (server-side issues)
 *   <li>{@link TritonNetworkException}: Network/connection failures
 *   <li>{@link TritonSchemaException}: Shape/type mismatch errors
 * </ul>
 */
public class TritonException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /** Classification of a failure, intended for logging, metrics, and alerting. */
    public enum ErrorCategory {
        /** Client-side errors (4xx): bad configuration, invalid input, etc. */
        CLIENT_ERROR,

        /** Server-side errors (5xx): inference failure, service unavailable, etc. */
        SERVER_ERROR,

        /** Network errors: connection timeout, DNS failure, etc. */
        NETWORK_ERROR,

        /** Schema/type errors: shape mismatch, incompatible types, etc. */
        SCHEMA_ERROR,

        /** Unknown or unclassified errors. */
        UNKNOWN
    }

    /**
     * Creates a new Triton exception with the specified message.
     *
     * @param message The detailed error message
     */
    public TritonException(String message) {
        super(message);
    }

    /**
     * Creates a new Triton exception with the specified message and cause.
     *
     * @param message The detailed error message
     * @param cause The underlying cause
     */
    public TritonException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Returns true if this error is retryable with exponential backoff.
     *
     * <p>The base implementation is conservative and returns false; subclasses override this for
     * transient conditions (e.g. 503 Service Unavailable, network timeouts).
     *
     * @return true if the operation can be retried
     */
    public boolean isRetryable() {
        return false;
    }

    /**
     * Returns the error category for logging, monitoring, and alerting purposes.
     *
     * <p>Typical uses:
     *
     * <ul>
     *   <li>Route errors to appropriate handling logic
     *   <li>Aggregate metrics by error type
     *   <li>Configure different retry strategies
     * </ul>
     *
     * @return The error category; {@link ErrorCategory#UNKNOWN} unless a subclass overrides it
     */
    public ErrorCategory getCategory() {
        return ErrorCategory.UNKNOWN;
    }
}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java
new file mode 100644
index 00000000000..af2747afbda
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonNetworkException.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton.exception;
+
+/**
+ * Exception for network-level errors (connection failures, timeouts).
+ *
+ * <p>Indicates transient network issues such as DNS resolution failures,
connection timeouts, or
+ * socket errors. These errors are typically retryable with exponential
backoff.
+ *
+ * <p><b>Common Scenarios:</b>
+ *
+ * <ul>
+ * <li>Connection refused: Server not reachable
+ * <li>Connection timeout: Network latency or firewall issues
+ * <li>DNS resolution failure: Hostname cannot be resolved
+ * <li>Socket timeout: Long-running request exceeded timeout
+ * </ul>
+ */
+public class TritonNetworkException extends TritonException {
+ private static final long serialVersionUID = 1L;
+
+ /**
+ * Creates a new network exception.
+ *
+ * @param message The detailed error message
+ * @param cause The underlying IOException or network error
+ */
+ public TritonNetworkException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ @Override
+ public boolean isRetryable() {
+ // Network errors are typically transient and retryable
+ return true;
+ }
+
+ @Override
+ public ErrorCategory getCategory() {
+ return ErrorCategory.NETWORK_ERROR;
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java
new file mode 100644
index 00000000000..dba00ec09ad
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonSchemaException.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton.exception;
+
+/**
+ * Exception for schema/shape/type mismatch errors.
+ *
+ * <p>Indicates that the Flink data type or tensor shape does not match what
the Triton model
+ * expects. These errors are NOT retryable as they require schema or
configuration fixes.
+ *
+ * <p><b>Common Scenarios:</b>
+ *
+ * <ul>
+ * <li>Shape mismatch: Sent [1,224,224,3] but model expects [1,3,224,224]
+ * <li>Type mismatch: Sent FP32 but model expects INT8
+ * <li>Dimension error: Sent scalar but model expects array
+ * </ul>
+ *
+ * <p>This exception includes detailed information about both expected and
actual schemas to help
+ * users diagnose and fix the issue.
+ */
+public class TritonSchemaException extends TritonException {
+ private static final long serialVersionUID = 1L;
+
+ private final String expectedSchema;
+ private final String actualSchema;
+
+ /**
+ * Creates a new schema exception.
+ *
+ * @param message The detailed error message
+ * @param expectedSchema The schema/shape expected by Triton model
+ * @param actualSchema The schema/shape actually sent
+ */
+ public TritonSchemaException(String message, String expectedSchema, String
actualSchema) {
+ super(
+ String.format(
+ "%s\n=== Expected Schema ===\n%s\n=== Actual Schema
===\n%s",
+ message, expectedSchema, actualSchema));
+ this.expectedSchema = expectedSchema;
+ this.actualSchema = actualSchema;
+ }
+
+ /**
+ * Returns the schema/shape expected by the Triton model.
+ *
+ * @return The expected schema description
+ */
+ public String getExpectedSchema() {
+ return expectedSchema;
+ }
+
+ /**
+ * Returns the schema/shape that was actually sent.
+ *
+ * @return The actual schema description
+ */
+ public String getActualSchema() {
+ return actualSchema;
+ }
+
+ @Override
+ public boolean isRetryable() {
+ // Schema errors require configuration fixes, not retries
+ return false;
+ }
+
+ @Override
+ public ErrorCategory getCategory() {
+ return ErrorCategory.SCHEMA_ERROR;
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java
new file mode 100644
index 00000000000..a2724b9cfbe
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/java/org/apache/flink/model/triton/exception/TritonServerException.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton.exception;
+
+/**
+ * Exception for server-side errors (HTTP 5xx status codes).
+ *
+ * <p>Indicates Triton Inference Server issues such as inference crashes, out
of memory, or service
+ * overload. Some server errors are retryable (e.g., 503 Service Unavailable,
504 Gateway Timeout).
+ *
+ * <p><b>Common Scenarios:</b>
+ *
+ * <ul>
+ * <li>500 Internal Server Error: Model inference crash (NOT retryable)
+ * <li>503 Service Unavailable: Server overloaded (retryable)
+ * <li>504 Gateway Timeout: Inference took too long (retryable)
+ * </ul>
+ */
+public class TritonServerException extends TritonException {
+ private static final long serialVersionUID = 1L;
+
+ private final int httpStatus;
+
+ /**
+ * Creates a new server exception.
+ *
+ * @param message The detailed error message
+ * @param httpStatus The HTTP status code (5xx)
+ */
+ public TritonServerException(String message, int httpStatus) {
+ super(String.format("[HTTP %d] %s", httpStatus, message));
+ this.httpStatus = httpStatus;
+ }
+
+ /**
+ * Returns the HTTP status code.
+ *
+ * @return The HTTP status code (5xx)
+ */
+ public int getHttpStatus() {
+ return httpStatus;
+ }
+
+ @Override
+ public boolean isRetryable() {
+ // 503 Service Unavailable and 504 Gateway Timeout are retryable
+ // 500 Internal Server Error typically requires investigation
+ return httpStatus == 503 || httpStatus == 504;
+ }
+
+ @Override
+ public ErrorCategory getCategory() {
+ return ErrorCategory.SERVER_ERROR;
+ }
+}
diff --git a/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE
b/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE
new file mode 100644
index 00000000000..aa39530aa65
--- /dev/null
+++ b/flink-models/flink-model-triton/src/main/resources/META-INF/NOTICE
@@ -0,0 +1,19 @@
+flink-model-triton
+Copyright 2014-2025 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
+
+This project bundles the following dependencies under the Apache Software
License 2.0 (http://www.apache.org/licenses/LICENSE-2.0.txt)
+
+- com.fasterxml.jackson.core:jackson-annotations:2.15.2
+- com.fasterxml.jackson.core:jackson-core:2.15.2
+- com.fasterxml.jackson.core:jackson-databind:2.15.2
+- com.squareup.okhttp3:okhttp:4.12.0
+- com.squareup.okio:okio:3.6.0
+- com.squareup.okio:okio-jvm:3.6.0
+- org.jetbrains:annotations:13.0
+- org.jetbrains.kotlin:kotlin-stdlib:1.8.21
+- org.jetbrains.kotlin:kotlin-stdlib-jdk7:1.8.21
+- org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.8.21
+- org.jetbrains.kotlin:kotlin-stdlib-common:1.9.10
\ No newline at end of file
diff --git
a/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory
b/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory
new file mode 100644
index 00000000000..abf2ba6d518
--- /dev/null
+++
b/flink-models/flink-model-triton/src/main/resources/META-INF/services/org.apache.flink.table.factories.Factory
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+org.apache.flink.model.triton.TritonModelProviderFactory
diff --git
a/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java
b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java
new file mode 100644
index 00000000000..f2f465b0aaa
--- /dev/null
+++
b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonModelProviderFactoryTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/** Test for {@link TritonModelProviderFactory}. */
+public class TritonModelProviderFactoryTest {
+
+ @Test
+ public void testFactoryIdentifier() {
+ TritonModelProviderFactory factory = new TritonModelProviderFactory();
+ assertEquals("triton", factory.factoryIdentifier());
+ }
+
+ @Test
+ public void testRequiredOptions() {
+ TritonModelProviderFactory factory = new TritonModelProviderFactory();
+ assertEquals(2, factory.requiredOptions().size());
+ assertTrue(factory.requiredOptions().contains(TritonOptions.ENDPOINT));
+
assertTrue(factory.requiredOptions().contains(TritonOptions.MODEL_NAME));
+ }
+
+ @Test
+ public void testOptionalOptions() {
+ TritonModelProviderFactory factory = new TritonModelProviderFactory();
+ assertEquals(10, factory.optionalOptions().size());
+
assertTrue(factory.optionalOptions().contains(TritonOptions.MODEL_VERSION));
+ assertTrue(factory.optionalOptions().contains(TritonOptions.TIMEOUT));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.FLATTEN_BATCH_DIM));
+ assertTrue(factory.optionalOptions().contains(TritonOptions.PRIORITY));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.SEQUENCE_ID));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.SEQUENCE_START));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.SEQUENCE_END));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.COMPRESSION));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.AUTH_TOKEN));
+
assertTrue(factory.optionalOptions().contains(TritonOptions.CUSTOM_HEADERS));
+ }
+}
diff --git
a/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java
b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java
new file mode 100644
index 00000000000..df0c033a4cb
--- /dev/null
+++
b/flink-models/flink-model-triton/src/test/java/org/apache/flink/model/triton/TritonTypeMapperTest.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.model.triton;
+
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.binary.BinaryStringData;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.BigIntType;
+import org.apache.flink.table.types.logical.BooleanType;
+import org.apache.flink.table.types.logical.DoubleType;
+import org.apache.flink.table.types.logical.FloatType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.SmallIntType;
+import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ArrayNode;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/** Test for {@link TritonTypeMapper}. */
+public class TritonTypeMapperTest {
+
+ private static final ObjectMapper objectMapper = new ObjectMapper();
+
+ @Test
+ public void testToTritonDataType() {
+ assertEquals(TritonDataType.BOOL,
TritonTypeMapper.toTritonDataType(new BooleanType()));
+ assertEquals(TritonDataType.INT8,
TritonTypeMapper.toTritonDataType(new TinyIntType()));
+ assertEquals(TritonDataType.INT16,
TritonTypeMapper.toTritonDataType(new SmallIntType()));
+ assertEquals(TritonDataType.INT32,
TritonTypeMapper.toTritonDataType(new IntType()));
+ assertEquals(TritonDataType.INT64,
TritonTypeMapper.toTritonDataType(new BigIntType()));
+ assertEquals(TritonDataType.FP32,
TritonTypeMapper.toTritonDataType(new FloatType()));
+ assertEquals(TritonDataType.FP64,
TritonTypeMapper.toTritonDataType(new DoubleType()));
+ assertEquals(
+ TritonDataType.BYTES,
+ TritonTypeMapper.toTritonDataType(new
VarCharType(VarCharType.MAX_LENGTH)));
+ }
+
+ @Test
+ public void testToTritonDataTypeForArray() {
+ // For arrays, returns the element type's Triton type
+ assertEquals(
+ TritonDataType.FP32,
+ TritonTypeMapper.toTritonDataType(new ArrayType(new
FloatType())));
+ assertEquals(
+ TritonDataType.INT32,
+ TritonTypeMapper.toTritonDataType(new ArrayType(new
IntType())));
+ }
+
+ @Test
+ public void testSerializeScalarTypes() {
+ // Test boolean
+ RowData boolRow = GenericRowData.of(true);
+ ArrayNode boolArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(boolRow, 0, new BooleanType(),
boolArray);
+ assertEquals(1, boolArray.size());
+ assertEquals(true, boolArray.get(0).asBoolean());
+
+ // Test int
+ RowData intRow = GenericRowData.of(42);
+ ArrayNode intArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(intRow, 0, new IntType(),
intArray);
+ assertEquals(1, intArray.size());
+ assertEquals(42, intArray.get(0).asInt());
+
+ // Test float
+ RowData floatRow = GenericRowData.of(3.14f);
+ ArrayNode floatArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(floatRow, 0, new FloatType(),
floatArray);
+ assertEquals(1, floatArray.size());
+ assertEquals(3.14f, floatArray.get(0).floatValue(), 0.001f);
+
+ // Test string
+ RowData stringRow =
GenericRowData.of(BinaryStringData.fromString("hello"));
+ ArrayNode stringArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(
+ stringRow, 0, new VarCharType(VarCharType.MAX_LENGTH),
stringArray);
+ assertEquals(1, stringArray.size());
+ assertEquals("hello", stringArray.get(0).asText());
+ }
+
+ @Test
+ public void testSerializeArrayType() {
+ Float[] floatArray = new Float[] {1.0f, 2.0f, 3.0f};
+ ArrayData arrayData = new GenericArrayData(floatArray);
+ RowData rowData = GenericRowData.of(arrayData);
+
+ ArrayNode jsonArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(
+ rowData, 0, new ArrayType(new FloatType()), jsonArray);
+
+ // Array should be flattened
+ assertEquals(3, jsonArray.size());
+ assertEquals(1.0f, jsonArray.get(0).floatValue(), 0.001f);
+ assertEquals(2.0f, jsonArray.get(1).floatValue(), 0.001f);
+ assertEquals(3.0f, jsonArray.get(2).floatValue(), 0.001f);
+ }
+
+ @Test
+ public void testCalculateShape() {
+ // Scalar type
+ int[] scalarShape = TritonTypeMapper.calculateShape(new IntType(), 1);
+ assertArrayEquals(new int[] {1}, scalarShape);
+
+ // Array type
+ int[] arrayShape = TritonTypeMapper.calculateShape(new ArrayType(new
FloatType()), 1);
+ assertArrayEquals(new int[] {1, -1}, arrayShape);
+
+ // Batch size > 1
+ int[] batchShape = TritonTypeMapper.calculateShape(new IntType(), 4);
+ assertArrayEquals(new int[] {4}, batchShape);
+ }
+
+ @Test
+ public void testDeserializeScalarTypes() {
+ // Test int
+ assertEquals(
+ 42,
+
TritonTypeMapper.deserializeFromJson(objectMapper.valueToTree(42), new
IntType()));
+
+ // Test float
+ Object floatResult =
+ TritonTypeMapper.deserializeFromJson(
+ objectMapper.valueToTree(3.14f), new FloatType());
+ assertEquals(3.14f, (Float) floatResult, 0.001f);
+
+ // Test string
+ Object stringResult =
+ TritonTypeMapper.deserializeFromJson(
+ objectMapper.valueToTree("hello"), new
VarCharType(VarCharType.MAX_LENGTH));
+ assertEquals("hello", stringResult.toString());
+ }
+
+ @Test
+ public void testDeserializeArrayType() {
+ float[] floatArray = new float[] {1.0f, 2.0f, 3.0f};
+ Object result =
+ TritonTypeMapper.deserializeFromJson(
+ objectMapper.valueToTree(floatArray), new
ArrayType(new FloatType()));
+
+ ArrayData arrayData = (ArrayData) result;
+ assertEquals(3, arrayData.size());
+ assertEquals(1.0f, arrayData.getFloat(0), 0.001f);
+ assertEquals(2.0f, arrayData.getFloat(1), 0.001f);
+ assertEquals(3.0f, arrayData.getFloat(2), 0.001f);
+ }
+
+ @Test
+ public void testSerializeNull() {
+ GenericRowData nullRow = new GenericRowData(1);
+ nullRow.setField(0, null);
+
+ ArrayNode jsonArray = objectMapper.createArrayNode();
+ TritonTypeMapper.serializeToJsonArray(nullRow, 0, new IntType(),
jsonArray);
+
+ assertEquals(1, jsonArray.size());
+ assertEquals(true, jsonArray.get(0).isNull());
+ }
+}
diff --git a/flink-models/pom.xml b/flink-models/pom.xml
index e7ae8ce866b..e8b5c5316b9 100644
--- a/flink-models/pom.xml
+++ b/flink-models/pom.xml
@@ -35,6 +35,7 @@ under the License.
<modules>
<module>flink-model-openai</module>
+ <module>flink-model-triton</module>
</modules>
<dependencies>
diff --git a/tools/ci/stage.sh b/tools/ci/stage.sh
index d473056fb7f..9c1129bc029 100755
--- a/tools/ci/stage.sh
+++ b/tools/ci/stage.sh
@@ -121,6 +121,7 @@ flink-metrics/flink-metrics-slf4j,\
flink-metrics/flink-metrics-otel,\
flink-connectors/flink-connector-base,\
flink-models/flink-model-openai,\
+flink-models/flink-model-triton,\
"
MODULES_TESTS="\