This is an automated email from the ASF dual-hosted git repository. lincoln pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 54b0d6a858c [FLINK-35827][table-planner] Fix equivalence comparison between row type fields and constants 54b0d6a858c is described below commit 54b0d6a858cd18e57fde60966a01cd673cb6da7a Author: Xuyang <xyzhong...@163.com> AuthorDate: Wed Sep 11 16:58:20 2024 +0800 [FLINK-35827][table-planner] Fix equivalence comparison between row type fields and constants This closes #25229 --- .../planner/codegen/EqualiserCodeGenerator.scala | 103 ++++++++++++++------- .../planner/codegen/calls/ScalarOperatorGens.scala | 26 +++++- .../table/planner/plan/stream/sql/CalcTest.xml | 81 +++++++++------- .../table/planner/plan/stream/table/CalcTest.xml | 14 +++ .../table/planner/expressions/RowTypeTest.scala | 7 ++ .../table/planner/plan/stream/sql/CalcTest.scala | 17 +++- .../table/planner/plan/stream/table/CalcTest.scala | 18 ++++ 7 files changed, 202 insertions(+), 64 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala index b1f2adc267e..4ae1a443336 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.planner.codegen import org.apache.flink.configuration.Configuration import org.apache.flink.table.planner.codegen.CodeGenUtils._ +import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator.generateRecordEqualiserCode import org.apache.flink.table.planner.codegen.Indenter.toISC import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.generateEquals import org.apache.flink.table.runtime.generated.{GeneratedRecordEqualiser, RecordEqualiser} @@ -30,14 +31,21 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldTyp import scala.annotation.tailrec import scala.collection.JavaConverters._ -class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassLoader) { +class EqualiserCodeGenerator( + leftFieldTypes: Array[LogicalType], + rightFieldTypes: Array[LogicalType], + classLoader: ClassLoader) { private val RECORD_EQUALISER = className[RecordEqualiser] private val LEFT_INPUT = "left" private val RIGHT_INPUT = "right" def this(rowType: RowType, classLoader: ClassLoader) = { - this(rowType.getChildren.asScala.toArray, classLoader) + this(rowType.getChildren.asScala.toArray, rowType.getChildren.asScala.toArray, classLoader) + } + + def this(fieldTypes: Array[LogicalType], classLoader: ClassLoader) = { + this(fieldTypes, fieldTypes, classLoader) } def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = { @@ -45,8 +53,8 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL val ctx = new CodeGeneratorContext(new Configuration, classLoader) val className = newName(ctx, name) - val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx) - val equalsMethodCalls = for (idx <- fieldTypes.indices) yield { + val equalsMethodCodes = for (idx <- leftFieldTypes.indices) yield generateEqualsMethod(ctx, idx) + val equalsMethodCalls = for (idx <- leftFieldTypes.indices) yield { val methodName = getEqualsMethodName(idx) s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);""" } @@ -93,18 +101,28 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL ("boolean", "isNullRight") ) - val fieldType = fieldTypes(idx) - val fieldTypeTerm = primitiveTypeTermForType(fieldType) + val leftFieldType = leftFieldTypes(idx) + val leftFieldTypeTerm = primitiveTypeTermForType(leftFieldType) + val rightFieldType = rightFieldTypes(idx) + val rightFieldTypeTerm = primitiveTypeTermForType(rightFieldType) + val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables( - (fieldTypeTerm, "leftField"), - (fieldTypeTerm, "rightField") + (leftFieldTypeTerm, "leftField"), + (rightFieldTypeTerm, "rightField") ) - val leftReadCode = rowFieldReadAccess(idx, LEFT_INPUT, fieldType) - val rightReadCode = rowFieldReadAccess(idx, RIGHT_INPUT, fieldType) + val leftReadCode = rowFieldReadAccess(idx, LEFT_INPUT, leftFieldType) + val rightReadCode = rowFieldReadAccess(idx, RIGHT_INPUT, rightFieldType) val (equalsCode, equalsResult) = - generateEqualsCode(ctx, fieldType, leftFieldTerm, rightFieldTerm, leftNullTerm, rightNullTerm) + generateEqualsCode( + ctx, + leftFieldType, + rightFieldType, + leftFieldTerm, + rightFieldTerm, + leftNullTerm, + rightNullTerm) s""" |private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) { @@ -131,33 +149,27 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL private def generateEqualsCode( ctx: CodeGeneratorContext, - fieldType: LogicalType, + leftFieldType: LogicalType, + rightFieldType: LogicalType, leftFieldTerm: String, rightFieldTerm: String, leftNullTerm: String, rightNullTerm: String) = { // TODO merge ScalarOperatorGens.generateEquals. - if (isInternalPrimitive(fieldType)) { + if (isInternalPrimitive(leftFieldType) && isInternalPrimitive(rightFieldType)) { ("", s"$leftFieldTerm == $rightFieldTerm") - } else if (isCompositeType(fieldType)) { - val equaliserGenerator = - new EqualiserCodeGenerator(getFieldTypes(fieldType).asScala.toArray, ctx.classLoader) - val generatedEqualiser = equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser") - val generatedEqualiserTerm = - ctx.addReusableObject(generatedEqualiser, "fieldGeneratedEqualiser") - val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName - val equaliserTerm = newName(ctx, "equaliser") - ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") - ctx.addReusableInitStatement( - s""" - |$equaliserTerm = ($equaliserTypeTerm) - | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); - |""".stripMargin) - ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") + } else if (isCompositeType(leftFieldType) && isCompositeType(rightFieldType)) { + generateRecordEqualiserCode( + ctx, + leftFieldType, + rightFieldType, + leftFieldTerm, + rightFieldTerm, + "fieldGeneratedEqualiser") } else { - val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType) - val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType) - val resultType = new BooleanType(fieldType.isNullable) + val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", leftFieldType) + val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", rightFieldType) + val resultType = new BooleanType(leftFieldType.isNullable || rightFieldType.isNullable) val gen = generateEquals(ctx, left, right, resultType) (gen.code, gen.resultTerm) } @@ -174,3 +186,32 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL case _ => false } } + +object EqualiserCodeGenerator { + + def generateRecordEqualiserCode( + ctx: CodeGeneratorContext, + leftFieldType: LogicalType, + rightFieldType: LogicalType, + leftFieldTerm: String, + rightFieldTerm: String, + generatedEqualiserName: String): (String, String) = { + val equaliserGenerator = + new EqualiserCodeGenerator( + getFieldTypes(leftFieldType).asScala.toArray, + getFieldTypes(rightFieldType).asScala.toArray, + ctx.classLoader) + val generatedEqualiser = equaliserGenerator.generateRecordEqualiser(generatedEqualiserName) + val generatedEqualiserTerm = + ctx.addReusableObject(generatedEqualiser, generatedEqualiserName) + val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName + val equaliserTerm = newName(ctx, "equaliser") + ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;") + ctx.addReusableInitStatement( + s""" + |$equaliserTerm = ($equaliserTypeTerm) + | $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader()); + |""".stripMargin) + ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)") + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala index 69081918fd2..7590299ba2b 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala @@ -23,7 +23,7 @@ import org.apache.flink.table.data.binary.BinaryArrayData import org.apache.flink.table.data.util.MapDataUtil import org.apache.flink.table.data.utils.CastExecutor import org.apache.flink.table.data.writer.{BinaryArrayWriter, BinaryRowWriter} -import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenException, GeneratedExpression} +import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenException, EqualiserCodeGenerator, GeneratedExpression} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE} import org.apache.flink.table.planner.codegen.GenerateUtils._ @@ -413,6 +413,10 @@ object ScalarOperatorGens { resultType), resultType) } + // row types + else if (isRow(left.resultType) && canEqual) { + wrapExpressionIfNonEq(nonEq, generateRowComparison(ctx, left, right, resultType), resultType) + } // multiset types else if (isMultiset(left.resultType) && canEqual) { val multisetType = left.resultType.asInstanceOf[MultisetType] @@ -1818,6 +1822,26 @@ object ScalarOperatorGens { (stmt, resultTerm) } + private def generateRowComparison( + ctx: CodeGeneratorContext, + left: GeneratedExpression, + right: GeneratedExpression, + resultType: LogicalType): GeneratedExpression = { + generateCallWithStmtIfArgsNotNull(ctx, resultType, Seq(left, right)) { + args => + val leftTerm = args.head + val rightTerm = args(1) + + EqualiserCodeGenerator.generateRecordEqualiserCode( + ctx, + left.resultType, + right.resultType, + leftTerm, + rightTerm, + "rowGeneratedEqualiser") + } + } + // ------------------------------------------------------------------------------------------ private def generateUnaryOperatorIfNotNull( diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml index 215a6aa8eea..20aac897cd2 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml @@ -30,6 +30,37 @@ LogicalProject(EXPR$0=[ARRAY(_UTF-16LE'Hi':VARCHAR(2147483647) CHARACTER SET "UT <![CDATA[ Calc(select=[ARRAY('Hi', 'Hello', c) AS EXPR$0]) +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + </Resource> + </TestCase> + <TestCase name="testCalcMergeWithCorrelate"> + <Resource name="sql"> + <![CDATA[ +SELECT a, r FROM ( + SELECT a, random_udf(b) r FROM ( + select a, b, c1 FROM MyTable, LATERAL TABLE(str_split(c)) AS T(c1) + ) t +) +WHERE r > 10 +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(a=[$0], r=[$1]) ++- LogicalFilter(condition=[>($1, 10)]) + +- LogicalProject(a=[$0], r=[random_udf($1)]) + +- LogicalProject(a=[$0], b=[$1], c1=[$3]) + +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}]) + :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]]) + +- LogicalTableFunctionScan(invocation=[str_split($cor0.c)], rowType=[RecordType(VARCHAR(2147483647) EXPR$0)]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[a, r], where=[>(r, 10)]) ++- Calc(select=[a, random_udf(b) AS r]) + +- Correlate(invocation=[str_split($cor0.c)], correlate=[table(str_split($cor0.c))], select=[a,b,c,EXPR$0], rowType=[RecordType(BIGINT a, INTEGER b, VARCHAR(2147483647) c, VARCHAR(2147483647) EXPR$0)], joinType=[INNER]) + +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> </Resource> </TestCase> @@ -69,37 +100,6 @@ LogicalProject(a=[$0]) <![CDATA[ Calc(select=[a], where=[>(random_udf(b), 10)]) +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) -]]> - </Resource> - </TestCase> - <TestCase name="testCalcMergeWithCorrelate"> - <Resource name="sql"> - <![CDATA[ -SELECT a, r FROM ( - SELECT a, random_udf(b) r FROM ( - select a, b, c1 FROM MyTable, LATERAL TABLE(str_split(c)) AS T(c1) - ) t -) -WHERE r > 10 -]]> - </Resource> - <Resource name="ast"> - <![CDATA[ -LogicalProject(a=[$0], r=[$1]) -+- LogicalFilter(condition=[>($1, 10)]) - +- LogicalProject(a=[$0], r=[random_udf($1)]) - +- LogicalProject(a=[$0], b=[$1], c1=[$3]) - +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}]) - :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]]) - +- LogicalTableFunctionScan(invocation=[str_split($cor0.c)], rowType=[RecordType(VARCHAR(2147483647) EXPR$0)]) -]]> - </Resource> - <Resource name="optimized rel plan"> - <![CDATA[ -Calc(select=[a, r], where=[>(r, 10)]) -+- Calc(select=[a, random_udf(b) AS r]) - +- Correlate(invocation=[str_split($cor0.c)], correlate=[table(str_split($cor0.c))], select=[a,b,c,EXPR$0], rowType=[RecordType(BIGINT a, INTEGER b, VARCHAR(2147483647) c, VARCHAR(2147483647) EXPR$0)], joinType=[INNER]) - +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) ]]> </Resource> </TestCase> @@ -496,6 +496,25 @@ LogicalProject(1-_./Ü=[$0], b=[$1], c=[$2]) <Resource name="optimized exec plan"> <![CDATA[ LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + </Resource> + </TestCase> + <TestCase name="testRowTypeEquality"> + <Resource name="sql"> + <![CDATA[ +SELECT my_row = ROW(1, 'str') from src +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(EXPR$0=[=(CAST($0):RecordType(INTEGER a, VARCHAR(2147483647) CHARACTER SET "UTF-16LE" b), CAST(ROW(1, _UTF-16LE'str')):RecordType(INTEGER a, VARCHAR(2147483647) CHARACTER SET "UTF-16LE" b) NOT NULL)]) ++- LogicalTableScan(table=[[default_catalog, default_database, src]]) +]]> + </Resource> + <Resource name="optimized exec plan"> + <![CDATA[ +Calc(select=[(CAST(my_row AS RecordType(INTEGER a, VARCHAR(2147483647) b)) = CAST(ROW(1, 'str') AS RecordType(INTEGER a, VARCHAR(2147483647) b))) AS EXPR$0]) ++- TableSourceScan(table=[[default_catalog, default_database, src]], fields=[my_row]) ]]> </Resource> </TestCase> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml index f68e26af411..60f5f30a142 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml @@ -116,6 +116,20 @@ LogicalFilter(condition=[OR(SEARCH($1, Sarg[(-∞..1), (1..2), (2..3), (3..4), ( <![CDATA[ Calc(select=[a, b, c], where=[(SEARCH(b, Sarg[(-∞..1), (1..2), (2..3), (3..4), (4..5), (5..6), (6..7), (7..8), (8..9), (9..10), (10..11), (11..12), (12..13), (13..14), (14..15), (15..16), (16..17), (17..18), (18..19), (19..20), (20..21), (21..22), (22..23), (23..24), (24..25), (25..26), (26..27), (27..28), (28..29), (29..30), (30..+∞)]) OR (c <> 'xx'))]) +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c]) +]]> + </Resource> + </TestCase> + <TestCase name="testRowTypeEquality"> + <Resource name="ast"> + <![CDATA[ +LogicalProject(_c0=[=($0, ROW(1, _UTF-16LE'str'))]) ++- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) +]]> + </Resource> + <Resource name="optimized exec plan"> + <![CDATA[ +Calc(select=[(my_row = ROW(1, 'str')) AS _c0]) ++- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[my_row]) ]]> </Resource> </TestCase> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala index e179be85802..58ccdf1104d 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala @@ -110,4 +110,11 @@ class RowTypeTest extends RowTypeTestBase { )) .withMessageContaining("Cast function cannot convert value") } + + @Test + def testRowTypeEquality(): Unit = { + testAllApis('f2 === row(2, "foo", true), "f2 = row(2, 'foo', true)", "TRUE") + + testAllApis('f3 === row(3, row(2, "foo", true)), "f3 = row(3, row(2, 'foo', true))", "TRUE") + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala index 1c62fc054e2..7df283295db 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala @@ -23,7 +23,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.planner.plan.utils.MyPojo import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf -import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.{JavaTableFunc1, StringSplit} +import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit import org.apache.flink.table.planner.utils.TableTestBase import org.assertj.core.api.Assertions.assertThatExceptionOfType @@ -217,4 +217,19 @@ class CalcTest extends TableTestBase { |""".stripMargin util.verifyRelPlan(sqlQuery) } + + @Test + def testRowTypeEquality(): Unit = { + util.addTable(s""" + |CREATE TABLE src ( + | my_row ROW(a INT, b STRING) + |) WITH ( + | 'connector' = 'values' + | ) + |""".stripMargin) + + util.verifyExecPlan(s""" + |SELECT my_row = ROW(1, 'str') from src + |""".stripMargin) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala index 554eeba0a38..e5cacf54ded 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala @@ -155,4 +155,22 @@ class CalcTest extends TableTestBase { util.verifyExecPlan(resultTable) } + + @Test + def testRowTypeEquality(): Unit = { + val util = streamTestUtil() + util.addTable(s""" + |CREATE TABLE MyTable ( + | my_row ROW(a INT, b STRING) + |) WITH ( + | 'connector' = 'values' + | ) + |""".stripMargin) + + val resultTable = util.tableEnv + .from("MyTable") + .select('my_row === row(1, "str")) + + util.verifyExecPlan(resultTable) + } }