This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new ef04bb14902 [FLINK-39823][table] Fix incorrect type inference for typed PTF table args
ef04bb14902 is described below
commit ef04bb14902657205e3f4131c6c16f4fa4855b89
Author: Timo Walther <[email protected]>
AuthorDate: Fri Jun 5 18:23:13 2026 +0200
[FLINK-39823][table] Fix incorrect type inference for typed PTF table args
This closes #28323.
---
.../table/types/inference/TypeInferenceUtil.java | 64 +++++++++++++++++---
.../codegen/ProcessTableRunnerGenerator.scala | 2 +-
.../api/QueryOperationSqlSerializationTest.java | 1 +
.../table/api/QueryOperationTestPrograms.java | 27 +++++++++
.../plan/stream/sql/ProcessTableFunctionTest.java | 70 ++++++++++++++++++----
.../plan/stream/sql/ProcessTableFunctionTest.xml | 39 ++++++++++++
6 files changed, 185 insertions(+), 18 deletions(-)
diff --git
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
index 3cffbe80e19..d1d7c15d743 100644
---
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
+++
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
@@ -97,14 +97,27 @@ public final class TypeInferenceUtil {
/** Casts the call's argument if necessary. */
public static CallContext castArguments(
TypeInference typeInference, CallContext callContext, @Nullable
DataType outputType) {
- return castArguments(typeInference, callContext, outputType, true);
+ return castArgumentsInternal(typeInference, callContext, outputType,
true, false);
}
- private static CallContext castArguments(
+ /**
+ * Casts the call's argument if necessary. Enables casting of table
arguments during code
+ * generation.
+ */
+ public static CallContext castArguments(
TypeInference typeInference,
CallContext callContext,
@Nullable DataType outputType,
- boolean throwOnInferInputFailure) {
+ boolean castTableArgs) {
+ return castArgumentsInternal(typeInference, callContext, outputType,
true, castTableArgs);
+ }
+
+ private static CallContext castArgumentsInternal(
+ TypeInference typeInference,
+ CallContext callContext,
+ @Nullable DataType outputType,
+ boolean throwOnInferInputFailure,
+ boolean castTableArgs) {
final List<DataType> actualTypes = callContext.getArgumentDataTypes();
typeInference
@@ -120,7 +133,12 @@ public final class TypeInferenceUtil {
});
final CastCallContext castCallContext =
- inferInputTypes(typeInference, callContext, outputType,
throwOnInferInputFailure);
+ inferInputTypes(
+ typeInference,
+ callContext,
+ outputType,
+ throwOnInferInputFailure,
+ castTableArgs);
// final check if the call is valid after casting
final List<DataType> expectedTypes =
castCallContext.getArgumentDataTypes();
@@ -311,7 +329,7 @@ public final class TypeInferenceUtil {
// We might not be able to infer the input types at this
moment, if the surrounding
// function does not provide an explicit input type strategy.
final CallContext adaptedContext =
- castArguments(typeInference, callContext, null, false);
+ castArgumentsInternal(typeInference, callContext,
null, false, false);
return typeInference
.getInputTypeStrategy()
.inferInputTypes(adaptedContext, false)
@@ -456,7 +474,8 @@ public final class TypeInferenceUtil {
TypeInference typeInference,
CallContext callContext,
@Nullable DataType outputType,
- boolean throwOnFailure) {
+ boolean throwOnFailure,
+ boolean castTableArgs) {
final CastCallContext castCallContext = new
CastCallContext(callContext, outputType);
@@ -481,7 +500,38 @@ public final class TypeInferenceUtil {
}
return null;
}
- return semantics.dataType();
+ final DataType expectedType =
+
expectedArg.getDataType().orElse(null);
+ final DataType actualType =
+
castCallContext.getArgumentDataTypes().get(pos);
+ if (expectedType == null) {
+ return actualType;
+ }
+ if (!supportsImplicitCast(
+
actualType.getLogicalType(),
+
expectedType.getLogicalType())) {
+ if (throwOnFailure) {
+ throw new
ValidationException(
+ String.format(
+ "Invalid
argument value. Argument '%s' expects a typed table, "
+ +
"but the provided table is incompatible and cannot "
+ +
"be cast to the target type.",
+
expectedArg.getName()));
+ }
+ return null;
+ }
+ // During planning, it is not
possible to cast table
+ // arguments.
+ // This can only happen during
code generation when a
+ // table is evaluated per row and
casts become "scalar
+ // casts".
+ // It also ensures that
pass-through columns (including
+ // the PARTITION BY column names)
are preserved during
+ // planning.
+ if (castTableArgs) {
+ return expectedType;
+ }
+ return actualType;
} else if
(expectedArg.is(StaticArgumentTrait.MODEL)) {
final ModelSemantics semantics =
callContext.getModelSemantics(pos).orElse(null);
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
index ea9f5ca4c77..deca0f188e2 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
@@ -96,7 +96,7 @@ object ProcessTableRunnerGenerator {
)
val functionTerm = ctx.addReusableFunction(udf)
val inference = udf.getTypeInference(dataTypeFactory)
- val castCallContext = TypeInferenceUtil.castArguments(inference,
callContext, null)
+ val castCallContext = TypeInferenceUtil.castArguments(inference,
callContext, null, true)
// Enrich argument types with conversion class
val enrichedArgumentDataTypes =
toScala(castCallContext.getArgumentDataTypes)
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
index 96e53484d20..d7a512327d6 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
@@ -70,6 +70,7 @@ public class QueryOperationSqlSerializationTest implements
TableTestProgramRunne
QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
QueryOperationTestPrograms.PTF_ROW_SEMANTIC_TABLE,
QueryOperationTestPrograms.PTF_SET_SEMANTIC_TABLE,
+ QueryOperationTestPrograms.PTF_TYPED_SET_SEMANTIC_TABLE,
QueryOperationTestPrograms.PTF_ORDER_BY,
QueryOperationTestPrograms.INLINE_FUNCTION_SERIALIZATION,
QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
index 71e049a84ff..1c59e7d1400 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
@@ -29,6 +29,7 @@ import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctio
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedSendingFunction;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RowSemanticTableFunction;
import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.SetSemanticTableFunction;
+import
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedSetSemanticTableFunction;
import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions;
import
org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions;
import org.apache.flink.table.test.program.SinkTestStep;
@@ -1128,6 +1129,32 @@ public class QueryOperationTestPrograms {
"sink")
.build();
+ static final TableTestProgram PTF_TYPED_SET_SEMANTIC_TABLE =
+ TableTestProgram.of("ptf-typed-set-semantic-table", "verifies SQL
serialization")
+ .setupSql(BASIC_VALUES)
+ .setupTableSink(
+ SinkTestStep.newBuilder("sink")
+ .addSchema(KEYED_BASE_SINK_SCHEMA)
+ .consumedValues(
+ "+I[Bob, {User(s='Bob', i=12),
1}]",
+ "+I[Alice, {User(s='Alice', i=42),
1}]")
+ .build())
+ .runSql(
+ "SELECT `$$T_FUNC`.`name`, `$$T_FUNC`.`out` FROM
TABLE(\n"
+ + " inlineFunction$00(\n"
+ + " (\n"
+ + " SELECT `$$T_SOURCE`.`name`,
`$$T_SOURCE`.`score` FROM `default_catalog`.`default_database`.`t` $$T_SOURCE\n"
+ + " ) PARTITION BY (`name`), 1,
DEFAULT, 'TypedSetSemanticTableFunction')\n"
+ + ") $$T_FUNC")
+ .runTableApi(
+ env ->
+ env.fromCall(
+
TypedSetSemanticTableFunction.class,
+
env.from("t").partitionBy($("name")).asArgument("u"),
+ lit(1).asArgument("i")),
+ "sink")
+ .build();
+
public static final TableTestProgram ML_PREDICT_MODEL_API =
TableTestProgram.of("ml-predict-model-api", "ml-predict using
model API")
.setupTableSource(SIMPLE_FEATURES_SOURCE)
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
index 17a303b93f8..d68141be11e 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
@@ -21,7 +21,9 @@ package org.apache.flink.table.planner.plan.stream.sql;
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ExplainDetail;
+import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.TableFunction;
@@ -58,8 +60,11 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
+import javax.annotation.Nullable;
+
import java.util.EnumSet;
import java.util.Optional;
+import java.util.function.Function;
import java.util.stream.Stream;
import static java.util.Collections.singletonList;
@@ -69,6 +74,8 @@ import static
org.apache.flink.table.annotation.ArgumentTrait.PASS_COLUMNS_THROU
import static
org.apache.flink.table.annotation.ArgumentTrait.ROW_SEMANTIC_TABLE;
import static
org.apache.flink.table.annotation.ArgumentTrait.SET_SEMANTIC_TABLE;
import static org.apache.flink.table.annotation.ArgumentTrait.SUPPORT_UPDATES;
+import static org.apache.flink.table.api.Expressions.$;
+import static org.apache.flink.table.api.Expressions.row;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** Tests for the type inference and planning part of {@link
ProcessTableFunction}. */
@@ -126,11 +133,27 @@ public class ProcessTableFunctionTest extends
TableTestBase {
@Test
void testTypedRowSemanticTableIgnoringColumnNames() {
util.addTemporarySystemFunction("f",
TypedRowSemanticTableFunction.class);
- // function expects <STRING name, INT score>
- // but table is <STRING name, INT different>
+ // Function expects <STRING s, INT i> but table is <STRING name, INT
different>
util.verifyRelPlan("SELECT * FROM f(u => TABLE t_name_diff, i => 1)");
}
+ @Test
+ void testTypedSetSemanticTableForwardingOriginalColumnName() {
+ util.addTemporarySystemFunction("f",
TypedSetSemanticTableFunction.class);
+ // Function expects <STRING s, INT i> but table is <STRING name, INT
different>.
+ // The partition key is a valid pass-through column and its name
should be preserved.
+ util.verifyRelPlan("SELECT * FROM f(u => TABLE t_name_diff PARTITION
BY name, i => 1)");
+ }
+
+ @Test
+ void testTypedSetSemanticTableForwardingOriginalColumnNameTableApi() {
+ util.addTemporarySystemFunction("f",
TypedSetSemanticTableFunction.class);
+ // Function expects <STRING s, INT i> but table is <STRING name, INT
different>.
+ // The partition key is a valid pass-through column and its name
should be preserved.
+ util.verifyRelPlan(
+
util.tableEnv().from("t_name_diff").partitionBy($("name")).process("f", 1));
+ }
+
@Test
void testDifferentPartitionKey() {
util.addTemporarySystemFunction("f", SetSemanticTableFunction.class);
@@ -300,11 +323,14 @@ public class ProcessTableFunctionTest extends
TableTestBase {
() -> {
if (spec.selectSql != null) {
util.verifyExecPlan(spec.selectSql);
- } else {
+ } else if (spec.insertIntoSql != null) {
util.verifyExecPlan(
util.tableEnv()
.createStatementSet()
.addInsertSql(spec.insertIntoSql));
+ } else {
+ assert spec.selectTableApi != null;
+
util.verifyExecPlan(spec.selectTableApi.apply(util.tableEnv()));
}
})
.satisfies(anyCauseMatches(spec.errorMessage));
@@ -527,7 +553,18 @@ public class ProcessTableFunctionTest extends
TableTestBase {
+ "in20 => TABLE t PARTITION BY score, "
+ "in21 => TABLE t PARTITION BY score"
+ ")",
- "Unsupported table argument count. Currently, the
number of input tables is limited to 20."));
+ "Unsupported table argument count. Currently, the
number of input tables is limited to 20."),
+ ErrorSpec.ofTableApi(
+ "invalid scalar args",
+ TypedRowSemanticTableFunction.class,
+ env -> env.from("t").process("f", 1.99),
+ "Invalid argument type at position 1. Data type INT
expected but DOUBLE NOT NULL passed."),
+ ErrorSpec.ofTableApi(
+ "invalid table args in Table API",
+ TypedRowSemanticTableFunction.class,
+ env -> env.fromValues(row("Bob", 1.99)).process("f",
1),
+ "Invalid argument value. Argument 'u' expects a typed
table, but the provided "
+ + "table is incompatible and cannot be cast to
the target type."));
}
/** Testing function. */
@@ -622,19 +659,22 @@ public class ProcessTableFunctionTest extends
TableTestBase {
private static class ErrorSpec {
private final String description;
private final Class<? extends UserDefinedFunction> functionClass;
- private final String selectSql;
- private final String insertIntoSql;
+ private final @Nullable String selectSql;
+ private final @Nullable String insertIntoSql;
+ private final @Nullable Function<TableEnvironment, Table>
selectTableApi;
private final String errorMessage;
private ErrorSpec(
String description,
Class<? extends UserDefinedFunction> functionClass,
- String selectSql,
- String insertIntoSql,
+ @Nullable String selectSql,
+ @Nullable String insertIntoSql,
+ @Nullable Function<TableEnvironment, Table> selectTableApi,
String errorMessage) {
this.description = description;
this.functionClass = functionClass;
this.selectSql = selectSql;
+ this.selectTableApi = selectTableApi;
this.insertIntoSql = insertIntoSql;
this.errorMessage = errorMessage;
}
@@ -644,7 +684,7 @@ public class ProcessTableFunctionTest extends TableTestBase
{
Class<? extends UserDefinedFunction> functionClass,
String selectSql,
String errorMessage) {
- return new ErrorSpec(description, functionClass, selectSql, null,
errorMessage);
+ return new ErrorSpec(description, functionClass, selectSql, null,
null, errorMessage);
}
static ErrorSpec ofInsertInto(
@@ -652,7 +692,17 @@ public class ProcessTableFunctionTest extends
TableTestBase {
Class<? extends UserDefinedFunction> functionClass,
String insertIntoSql,
String errorMessage) {
- return new ErrorSpec(description, functionClass, null,
insertIntoSql, errorMessage);
+ return new ErrorSpec(
+ description, functionClass, null, insertIntoSql, null,
errorMessage);
+ }
+
+ static ErrorSpec ofTableApi(
+ String description,
+ Class<? extends UserDefinedFunction> functionClass,
+ Function<TableEnvironment, Table> selectTableApi,
+ String errorMessage) {
+ return new ErrorSpec(
+ description, functionClass, null, null, selectTableApi,
errorMessage);
}
@Override
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
index f0f340f4eb7..36c40cfc695 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
@@ -291,6 +291,45 @@ LogicalProject(out=[$0])
ProcessTableFunction(invocation=[f(TABLE(#0), 1, DEFAULT(), DEFAULT())],
uid=[null], select=[out], rowType=[RecordType(VARCHAR(2147483647) out)])
+- Calc(select=['Bob' AS name, 12 AS different])
+- Values(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testTypedSetSemanticTableForwardingOriginalColumnName">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM f(u => TABLE t_name_diff PARTITION BY name, i =>
1)]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(name=[$0], out=[$1])
++- LogicalTableFunctionScan(invocation=[f(TABLE(#0) PARTITION BY($0), 1,
DEFAULT(), DEFAULT())], rowType=[RecordType(CHAR(3) name, VARCHAR(2147483647)
out)])
+ +- LogicalProject(name=[$0], different=[$1])
+ +- LogicalProject(name=[_UTF-16LE'Bob'], different=[12])
+ +- LogicalValues(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+ProcessTableFunction(invocation=[f(TABLE(#0) PARTITION BY($0), 1, DEFAULT(),
DEFAULT())], uid=[f], select=[name,out], rowType=[RecordType(CHAR(3) name,
VARCHAR(2147483647) out)])
++- Exchange(distribution=[hash[name]])
+ +- Calc(select=['Bob' AS name, 12 AS different])
+ +- Values(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase
name="testTypedSetSemanticTableForwardingOriginalColumnNameTableApi">
+ <Resource name="ast">
+ <![CDATA[
+LogicalTableFunctionScan(invocation=[f(TABLE(#0) PARTITION BY($0), 1,
DEFAULT(), _UTF-16LE'f')], rowType=[RecordType(CHAR(3) name,
VARCHAR(2147483647) out)])
++- LogicalProject(name=[_UTF-16LE'Bob'], different=[12])
+ +- LogicalValues(tuples=[[{ 0 }]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+ProcessTableFunction(invocation=[f(TABLE(#0) PARTITION BY($0), 1, DEFAULT(),
_UTF-16LE'f')], uid=[f], select=[name,out], rowType=[RecordType(CHAR(3) name,
VARCHAR(2147483647) out)])
++- Exchange(distribution=[hash[name]])
+ +- Calc(select=['Bob' AS name, 12 AS different])
+ +- Values(tuples=[[{ 0 }]])
]]>
</Resource>
</TestCase>