This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new ef04bb14902 [FLINK-39823][table] Fix incorrect type inference for typed PTF table args
ef04bb14902 is described below

commit ef04bb14902657205e3f4131c6c16f4fa4855b89
Author: Timo Walther <[email protected]>
AuthorDate: Fri Jun 5 18:23:13 2026 +0200

    [FLINK-39823][table] Fix incorrect type inference for typed PTF table args
    
    This closes #28323.
---
 .../table/types/inference/TypeInferenceUtil.java   | 64 +++++++++++++++++---
 .../codegen/ProcessTableRunnerGenerator.scala      |  2 +-
 .../api/QueryOperationSqlSerializationTest.java    |  1 +
 .../table/api/QueryOperationTestPrograms.java      | 27 +++++++++
 .../plan/stream/sql/ProcessTableFunctionTest.java  | 70 ++++++++++++++++++----
 .../plan/stream/sql/ProcessTableFunctionTest.xml   | 39 ++++++++++++
 6 files changed, 185 insertions(+), 18 deletions(-)

diff --git 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
index 3cffbe80e19..d1d7c15d743 100644
--- 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
+++ 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java
@@ -97,14 +97,27 @@ public final class TypeInferenceUtil {
     /** Casts the call's argument if necessary. */
     public static CallContext castArguments(
             TypeInference typeInference, CallContext callContext, @Nullable 
DataType outputType) {
-        return castArguments(typeInference, callContext, outputType, true);
+        return castArgumentsInternal(typeInference, callContext, outputType, 
true, false);
     }
 
-    private static CallContext castArguments(
+    /**
+     * Casts the call's argument if necessary. Enables casting of table 
arguments during code
+     * generation.
+     */
+    public static CallContext castArguments(
             TypeInference typeInference,
             CallContext callContext,
             @Nullable DataType outputType,
-            boolean throwOnInferInputFailure) {
+            boolean castTableArgs) {
+        return castArgumentsInternal(typeInference, callContext, outputType, 
true, castTableArgs);
+    }
+
+    private static CallContext castArgumentsInternal(
+            TypeInference typeInference,
+            CallContext callContext,
+            @Nullable DataType outputType,
+            boolean throwOnInferInputFailure,
+            boolean castTableArgs) {
         final List<DataType> actualTypes = callContext.getArgumentDataTypes();
 
         typeInference
@@ -120,7 +133,12 @@ public final class TypeInferenceUtil {
                         });
 
         final CastCallContext castCallContext =
-                inferInputTypes(typeInference, callContext, outputType, 
throwOnInferInputFailure);
+                inferInputTypes(
+                        typeInference,
+                        callContext,
+                        outputType,
+                        throwOnInferInputFailure,
+                        castTableArgs);
 
         // final check if the call is valid after casting
         final List<DataType> expectedTypes = 
castCallContext.getArgumentDataTypes();
@@ -311,7 +329,7 @@ public final class TypeInferenceUtil {
                 // We might not be able to infer the input types at this 
moment, if the surrounding
                 // function does not provide an explicit input type strategy.
                 final CallContext adaptedContext =
-                        castArguments(typeInference, callContext, null, false);
+                        castArgumentsInternal(typeInference, callContext, 
null, false, false);
                 return typeInference
                         .getInputTypeStrategy()
                         .inferInputTypes(adaptedContext, false)
@@ -456,7 +474,8 @@ public final class TypeInferenceUtil {
             TypeInference typeInference,
             CallContext callContext,
             @Nullable DataType outputType,
-            boolean throwOnFailure) {
+            boolean throwOnFailure,
+            boolean castTableArgs) {
 
         final CastCallContext castCallContext = new 
CastCallContext(callContext, outputType);
 
@@ -481,7 +500,38 @@ public final class TypeInferenceUtil {
                                                 }
                                                 return null;
                                             }
-                                            return semantics.dataType();
+                                            final DataType expectedType =
+                                                    
expectedArg.getDataType().orElse(null);
+                                            final DataType actualType =
+                                                    
castCallContext.getArgumentDataTypes().get(pos);
+                                            if (expectedType == null) {
+                                                return actualType;
+                                            }
+                                            if (!supportsImplicitCast(
+                                                    
actualType.getLogicalType(),
+                                                    
expectedType.getLogicalType())) {
+                                                if (throwOnFailure) {
+                                                    throw new 
ValidationException(
+                                                            String.format(
+                                                                    "Invalid 
argument value. Argument '%s' expects a typed table, "
+                                                                            + 
"but the provided table is incompatible and cannot "
+                                                                            + 
"be cast to the target type.",
+                                                                    
expectedArg.getName()));
+                                                }
+                                                return null;
+                                            }
+                                            // During planning, it is not 
possible to cast table
+                                            // arguments.
+                                            // This can only happen during 
code generation when a
+                                            // table is evaluated per row and 
casts become "scalar
+                                            // casts".
+                                            // It also ensures that 
pass-through columns (including
+                                            // the PARTITION BY column names) 
are preserved during
+                                            // planning.
+                                            if (castTableArgs) {
+                                                return expectedType;
+                                            }
+                                            return actualType;
                                         } else if 
(expectedArg.is(StaticArgumentTrait.MODEL)) {
                                             final ModelSemantics semantics =
                                                     
callContext.getModelSemantics(pos).orElse(null);
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
index ea9f5ca4c77..deca0f188e2 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProcessTableRunnerGenerator.scala
@@ -96,7 +96,7 @@ object ProcessTableRunnerGenerator {
     )
     val functionTerm = ctx.addReusableFunction(udf)
     val inference = udf.getTypeInference(dataTypeFactory)
-    val castCallContext = TypeInferenceUtil.castArguments(inference, 
callContext, null)
+    val castCallContext = TypeInferenceUtil.castArguments(inference, 
callContext, null, true)
 
     // Enrich argument types with conversion class
     val enrichedArgumentDataTypes = 
toScala(castCallContext.getArgumentDataTypes)
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
index 96e53484d20..d7a512327d6 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationSqlSerializationTest.java
@@ -70,6 +70,7 @@ public class QueryOperationSqlSerializationTest implements 
TableTestProgramRunne
                 QueryOperationTestPrograms.ACCESSING_NESTED_COLUMN,
                 QueryOperationTestPrograms.PTF_ROW_SEMANTIC_TABLE,
                 QueryOperationTestPrograms.PTF_SET_SEMANTIC_TABLE,
+                QueryOperationTestPrograms.PTF_TYPED_SET_SEMANTIC_TABLE,
                 QueryOperationTestPrograms.PTF_ORDER_BY,
                 QueryOperationTestPrograms.INLINE_FUNCTION_SERIALIZATION,
                 QueryOperationTestPrograms.ML_PREDICT_MODEL_API,
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
index 71e049a84ff..1c59e7d1400 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/api/QueryOperationTestPrograms.java
@@ -29,6 +29,7 @@ import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctio
 import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.ChainedSendingFunction;
 import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.RowSemanticTableFunction;
 import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.SetSemanticTableFunction;
+import 
org.apache.flink.table.planner.plan.nodes.exec.stream.ProcessTableFunctionTestUtils.TypedSetSemanticTableFunction;
 import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions;
 import 
org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions;
 import org.apache.flink.table.test.program.SinkTestStep;
@@ -1128,6 +1129,32 @@ public class QueryOperationTestPrograms {
                             "sink")
                     .build();
 
+    static final TableTestProgram PTF_TYPED_SET_SEMANTIC_TABLE =
+            TableTestProgram.of("ptf-typed-set-semantic-table", "verifies SQL 
serialization")
+                    .setupSql(BASIC_VALUES)
+                    .setupTableSink(
+                            SinkTestStep.newBuilder("sink")
+                                    .addSchema(KEYED_BASE_SINK_SCHEMA)
+                                    .consumedValues(
+                                            "+I[Bob, {User(s='Bob', i=12), 
1}]",
+                                            "+I[Alice, {User(s='Alice', i=42), 
1}]")
+                                    .build())
+                    .runSql(
+                            "SELECT `$$T_FUNC`.`name`, `$$T_FUNC`.`out` FROM 
TABLE(\n"
+                                    + "    inlineFunction$00(\n"
+                                    + "        (\n"
+                                    + "            SELECT `$$T_SOURCE`.`name`, 
`$$T_SOURCE`.`score` FROM `default_catalog`.`default_database`.`t` $$T_SOURCE\n"
+                                    + "        ) PARTITION BY (`name`), 1, 
DEFAULT, 'TypedSetSemanticTableFunction')\n"
+                                    + ") $$T_FUNC")
+                    .runTableApi(
+                            env ->
+                                    env.fromCall(
+                                            
TypedSetSemanticTableFunction.class,
+                                            
env.from("t").partitionBy($("name")).asArgument("u"),
+                                            lit(1).asArgument("i")),
+                            "sink")
+                    .build();
+
     public static final TableTestProgram ML_PREDICT_MODEL_API =
             TableTestProgram.of("ml-predict-model-api", "ml-predict using 
model API")
                     .setupTableSource(SIMPLE_FEATURES_SOURCE)
diff --git 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
index 17a303b93f8..d68141be11e 100644
--- 
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
+++ 
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.java
@@ -21,7 +21,9 @@ package org.apache.flink.table.planner.plan.stream.sql;
 import org.apache.flink.table.annotation.ArgumentHint;
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.ExplainDetail;
+import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
 import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.functions.ProcessTableFunction;
 import org.apache.flink.table.functions.TableFunction;
@@ -58,8 +60,11 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.MethodSource;
 
+import javax.annotation.Nullable;
+
 import java.util.EnumSet;
 import java.util.Optional;
+import java.util.function.Function;
 import java.util.stream.Stream;
 
 import static java.util.Collections.singletonList;
@@ -69,6 +74,8 @@ import static 
org.apache.flink.table.annotation.ArgumentTrait.PASS_COLUMNS_THROU
 import static 
org.apache.flink.table.annotation.ArgumentTrait.ROW_SEMANTIC_TABLE;
 import static 
org.apache.flink.table.annotation.ArgumentTrait.SET_SEMANTIC_TABLE;
 import static org.apache.flink.table.annotation.ArgumentTrait.SUPPORT_UPDATES;
+import static org.apache.flink.table.api.Expressions.$;
+import static org.apache.flink.table.api.Expressions.row;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for the type inference and planning part of {@link 
ProcessTableFunction}. */
@@ -126,11 +133,27 @@ public class ProcessTableFunctionTest extends 
TableTestBase {
     @Test
     void testTypedRowSemanticTableIgnoringColumnNames() {
         util.addTemporarySystemFunction("f", 
TypedRowSemanticTableFunction.class);
-        // function expects <STRING name, INT score>
-        // but table is <STRING name, INT different>
+        // Function expects <STRING s, INT i> but table is <STRING name, INT 
different>
         util.verifyRelPlan("SELECT * FROM f(u => TABLE t_name_diff, i => 1)");
     }
 
+    @Test
+    void testTypedSetSemanticTableForwardingOriginalColumnName() {
+        util.addTemporarySystemFunction("f", 
TypedSetSemanticTableFunction.class);
+        // Function expects <STRING s, INT i> but table is <STRING name, INT 
different>.
+        // The partition key is a valid pass-through column and its name 
should be preserved.
+        util.verifyRelPlan("SELECT * FROM f(u => TABLE t_name_diff PARTITION 
BY name, i => 1)");
+    }
+
+    @Test
+    void testTypedSetSemanticTableForwardingOriginalColumnNameTableApi() {
+        util.addTemporarySystemFunction("f", 
TypedSetSemanticTableFunction.class);
+        // Function expects <STRING s, INT i> but table is <STRING name, INT 
different>.
+        // The partition key is a valid pass-through column and its name 
should be preserved.
+        util.verifyRelPlan(
+                
util.tableEnv().from("t_name_diff").partitionBy($("name")).process("f", 1));
+    }
+
     @Test
     void testDifferentPartitionKey() {
         util.addTemporarySystemFunction("f", SetSemanticTableFunction.class);
@@ -300,11 +323,14 @@ public class ProcessTableFunctionTest extends 
TableTestBase {
                         () -> {
                             if (spec.selectSql != null) {
                                 util.verifyExecPlan(spec.selectSql);
-                            } else {
+                            } else if (spec.insertIntoSql != null) {
                                 util.verifyExecPlan(
                                         util.tableEnv()
                                                 .createStatementSet()
                                                 
.addInsertSql(spec.insertIntoSql));
+                            } else {
+                                assert spec.selectTableApi != null;
+                                
util.verifyExecPlan(spec.selectTableApi.apply(util.tableEnv()));
                             }
                         })
                 .satisfies(anyCauseMatches(spec.errorMessage));
@@ -527,7 +553,18 @@ public class ProcessTableFunctionTest extends 
TableTestBase {
                                 + "in20 => TABLE t PARTITION BY score, "
                                 + "in21 => TABLE t PARTITION BY score"
                                 + ")",
-                        "Unsupported table argument count. Currently, the 
number of input tables is limited to 20."));
+                        "Unsupported table argument count. Currently, the 
number of input tables is limited to 20."),
+                ErrorSpec.ofTableApi(
+                        "invalid scalar args",
+                        TypedRowSemanticTableFunction.class,
+                        env -> env.from("t").process("f", 1.99),
+                        "Invalid argument type at position 1. Data type INT 
expected but DOUBLE NOT NULL passed."),
+                ErrorSpec.ofTableApi(
+                        "invalid table args in Table API",
+                        TypedRowSemanticTableFunction.class,
+                        env -> env.fromValues(row("Bob", 1.99)).process("f", 
1),
+                        "Invalid argument value. Argument 'u' expects a typed 
table, but the provided "
+                                + "table is incompatible and cannot be cast to 
the target type."));
     }
 
     /** Testing function. */
@@ -622,19 +659,22 @@ public class ProcessTableFunctionTest extends 
TableTestBase {
     private static class ErrorSpec {
         private final String description;
         private final Class<? extends UserDefinedFunction> functionClass;
-        private final String selectSql;
-        private final String insertIntoSql;
+        private final @Nullable String selectSql;
+        private final @Nullable String insertIntoSql;
+        private final @Nullable Function<TableEnvironment, Table> 
selectTableApi;
         private final String errorMessage;
 
         private ErrorSpec(
                 String description,
                 Class<? extends UserDefinedFunction> functionClass,
-                String selectSql,
-                String insertIntoSql,
+                @Nullable String selectSql,
+                @Nullable String insertIntoSql,
+                @Nullable Function<TableEnvironment, Table> selectTableApi,
                 String errorMessage) {
             this.description = description;
             this.functionClass = functionClass;
             this.selectSql = selectSql;
+            this.selectTableApi = selectTableApi;
             this.insertIntoSql = insertIntoSql;
             this.errorMessage = errorMessage;
         }
@@ -644,7 +684,7 @@ public class ProcessTableFunctionTest extends TableTestBase 
{
                 Class<? extends UserDefinedFunction> functionClass,
                 String selectSql,
                 String errorMessage) {
-            return new ErrorSpec(description, functionClass, selectSql, null, 
errorMessage);
+            return new ErrorSpec(description, functionClass, selectSql, null, 
null, errorMessage);
         }
 
         static ErrorSpec ofInsertInto(
@@ -652,7 +692,17 @@ public class ProcessTableFunctionTest extends 
TableTestBase {
                 Class<? extends UserDefinedFunction> functionClass,
                 String insertIntoSql,
                 String errorMessage) {
-            return new ErrorSpec(description, functionClass, null, 
insertIntoSql, errorMessage);
+            return new ErrorSpec(
+                    description, functionClass, null, insertIntoSql, null, 
errorMessage);
+        }
+
+        static ErrorSpec ofTableApi(
+                String description,
+                Class<? extends UserDefinedFunction> functionClass,
+                Function<TableEnvironment, Table> selectTableApi,
+                String errorMessage) {
+            return new ErrorSpec(
+                    description, functionClass, null, null, selectTableApi, 
errorMessage);
         }
 
         @Override
diff --git 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
index f0f340f4eb7..36c40cfc695 100644
--- 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
+++ 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ProcessTableFunctionTest.xml
@@ -291,6 +291,45 @@ LogicalProject(out=[$0])
 ProcessTableFunction(invocation=[f(TABLE(#0), 1, DEFAULT(), DEFAULT())], 
uid=[null], select=[out], rowType=[RecordType(VARCHAR(2147483647) out)])
 +- Calc(select=['Bob' AS name, 12 AS different])
    +- Values(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testTypedSetSemanticTableForwardingOriginalColumnName">
+    <Resource name="sql">
+      <![CDATA[SELECT * FROM f(u => TABLE t_name_diff PARTITION BY name, i => 
1)]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(name=[$0], out=[$1])
++- LogicalTableFunctionScan(invocation=[f(TABLE(#0) PARTITION BY($0), 1, 
DEFAULT(), DEFAULT())], rowType=[RecordType(CHAR(3) name, VARCHAR(2147483647) 
out)])
+   +- LogicalProject(name=[$0], different=[$1])
+      +- LogicalProject(name=[_UTF-16LE'Bob'], different=[12])
+         +- LogicalValues(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+ProcessTableFunction(invocation=[f(TABLE(#0) PARTITION BY($0), 1, DEFAULT(), 
DEFAULT())], uid=[f], select=[name,out], rowType=[RecordType(CHAR(3) name, 
VARCHAR(2147483647) out)])
++- Exchange(distribution=[hash[name]])
+   +- Calc(select=['Bob' AS name, 12 AS different])
+      +- Values(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase 
name="testTypedSetSemanticTableForwardingOriginalColumnNameTableApi">
+    <Resource name="ast">
+      <![CDATA[
+LogicalTableFunctionScan(invocation=[f(TABLE(#0) PARTITION BY($0), 1, 
DEFAULT(), _UTF-16LE'f')], rowType=[RecordType(CHAR(3) name, 
VARCHAR(2147483647) out)])
++- LogicalProject(name=[_UTF-16LE'Bob'], different=[12])
+   +- LogicalValues(tuples=[[{ 0 }]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+ProcessTableFunction(invocation=[f(TABLE(#0) PARTITION BY($0), 1, DEFAULT(), 
_UTF-16LE'f')], uid=[f], select=[name,out], rowType=[RecordType(CHAR(3) name, 
VARCHAR(2147483647) out)])
++- Exchange(distribution=[hash[name]])
+   +- Calc(select=['Bob' AS name, 12 AS different])
+      +- Values(tuples=[[{ 0 }]])
 ]]>
     </Resource>
   </TestCase>

Reply via email to