This is an automated email from the ASF dual-hosted git repository.

lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 54b0d6a858c [FLINK-35827][table-planner] Fix equivalence comparison between row type fields and constants
54b0d6a858c is described below

commit 54b0d6a858cd18e57fde60966a01cd673cb6da7a
Author: Xuyang <xyzhong...@163.com>
AuthorDate: Wed Sep 11 16:58:20 2024 +0800

    [FLINK-35827][table-planner] Fix equivalence comparison between row type fields and constants
    
    This closes #25229
---
 .../planner/codegen/EqualiserCodeGenerator.scala   | 103 ++++++++++++++-------
 .../planner/codegen/calls/ScalarOperatorGens.scala |  26 +++++-
 .../table/planner/plan/stream/sql/CalcTest.xml     |  81 +++++++++-------
 .../table/planner/plan/stream/table/CalcTest.xml   |  14 +++
 .../table/planner/expressions/RowTypeTest.scala    |   7 ++
 .../table/planner/plan/stream/sql/CalcTest.scala   |  17 +++-
 .../table/planner/plan/stream/table/CalcTest.scala |  18 ++++
 7 files changed, 202 insertions(+), 64 deletions(-)
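
For readers skimming the diff: the fix routes equality between a ROW-typed field and a ROW constant through a code-generated RecordEqualiser instead of failing in codegen. A minimal, hedged Scala sketch of the comparison the new tests exercise (the TableEnvironment and the MyTable registration are assumed, exactly as in the new stream/table/CalcTest case further down):

    import org.apache.flink.api.scala._
    import org.apache.flink.table.api._

    // Assumes a registered table MyTable with a single column my_row ROW<a INT, b STRING>,
    // mirroring the new test case below; tableEnv is the surrounding TableEnvironment.
    val result = tableEnv
      .from("MyTable")
      .select('my_row === row(1, "str")) // the row-vs-constant comparison this commit fixes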

diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
index b1f2adc267e..4ae1a443336 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/EqualiserCodeGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.flink.table.planner.codegen
 
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
+import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator.generateRecordEqualiserCode
 import org.apache.flink.table.planner.codegen.Indenter.toISC
 import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens.generateEquals
 import org.apache.flink.table.runtime.generated.{GeneratedRecordEqualiser, RecordEqualiser}
@@ -30,14 +31,21 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{getFieldTyp
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 
-class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassLoader) {
+class EqualiserCodeGenerator(
+    leftFieldTypes: Array[LogicalType],
+    rightFieldTypes: Array[LogicalType],
+    classLoader: ClassLoader) {
 
   private val RECORD_EQUALISER = className[RecordEqualiser]
   private val LEFT_INPUT = "left"
   private val RIGHT_INPUT = "right"
 
   def this(rowType: RowType, classLoader: ClassLoader) = {
-    this(rowType.getChildren.asScala.toArray, classLoader)
+    this(rowType.getChildren.asScala.toArray, rowType.getChildren.asScala.toArray, classLoader)
+  }
+
+  def this(fieldTypes: Array[LogicalType], classLoader: ClassLoader) = {
+    this(fieldTypes, fieldTypes, classLoader)
   }
 
   def generateRecordEqualiser(name: String): GeneratedRecordEqualiser = {
@@ -45,8 +53,8 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL
     val ctx = new CodeGeneratorContext(new Configuration, classLoader)
     val className = newName(ctx, name)
 
-    val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx)
-    val equalsMethodCalls = for (idx <- fieldTypes.indices) yield {
+    val equalsMethodCodes = for (idx <- leftFieldTypes.indices) yield generateEqualsMethod(ctx, idx)
+    val equalsMethodCalls = for (idx <- leftFieldTypes.indices) yield {
       val methodName = getEqualsMethodName(idx)
       s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);"""
     }
@@ -93,18 +101,28 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL
       ("boolean", "isNullRight")
     )
 
-    val fieldType = fieldTypes(idx)
-    val fieldTypeTerm = primitiveTypeTermForType(fieldType)
+    val leftFieldType = leftFieldTypes(idx)
+    val leftFieldTypeTerm = primitiveTypeTermForType(leftFieldType)
+    val rightFieldType = rightFieldTypes(idx)
+    val rightFieldTypeTerm = primitiveTypeTermForType(rightFieldType)
+
     val Seq(leftFieldTerm, rightFieldTerm) = ctx.addReusableLocalVariables(
-      (fieldTypeTerm, "leftField"),
-      (fieldTypeTerm, "rightField")
+      (leftFieldTypeTerm, "leftField"),
+      (rightFieldTypeTerm, "rightField")
     )
 
-    val leftReadCode = rowFieldReadAccess(idx, LEFT_INPUT, fieldType)
-    val rightReadCode = rowFieldReadAccess(idx, RIGHT_INPUT, fieldType)
+    val leftReadCode = rowFieldReadAccess(idx, LEFT_INPUT, leftFieldType)
+    val rightReadCode = rowFieldReadAccess(idx, RIGHT_INPUT, rightFieldType)
 
     val (equalsCode, equalsResult) =
-      generateEqualsCode(ctx, fieldType, leftFieldTerm, rightFieldTerm, leftNullTerm, rightNullTerm)
+      generateEqualsCode(
+        ctx,
+        leftFieldType,
+        rightFieldType,
+        leftFieldTerm,
+        rightFieldTerm,
+        leftNullTerm,
+        rightNullTerm)
 
     s"""
       |private boolean $methodName($ROW_DATA $LEFT_INPUT, $ROW_DATA $RIGHT_INPUT) {
@@ -131,33 +149,27 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL
 
   private def generateEqualsCode(
       ctx: CodeGeneratorContext,
-      fieldType: LogicalType,
+      leftFieldType: LogicalType,
+      rightFieldType: LogicalType,
       leftFieldTerm: String,
       rightFieldTerm: String,
       leftNullTerm: String,
       rightNullTerm: String) = {
     // TODO merge ScalarOperatorGens.generateEquals.
-    if (isInternalPrimitive(fieldType)) {
+    if (isInternalPrimitive(leftFieldType) && isInternalPrimitive(rightFieldType)) {
       ("", s"$leftFieldTerm == $rightFieldTerm")
-    } else if (isCompositeType(fieldType)) {
-      val equaliserGenerator =
-        new EqualiserCodeGenerator(getFieldTypes(fieldType).asScala.toArray, ctx.classLoader)
-      val generatedEqualiser = equaliserGenerator.generateRecordEqualiser("fieldGeneratedEqualiser")
-      val generatedEqualiserTerm =
-        ctx.addReusableObject(generatedEqualiser, "fieldGeneratedEqualiser")
-      val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
-      val equaliserTerm = newName(ctx, "equaliser")
-      ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;")
-      ctx.addReusableInitStatement(
-        s"""
-           |$equaliserTerm = ($equaliserTypeTerm)
-           |  $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
-           |""".stripMargin)
-      ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
+    } else if (isCompositeType(leftFieldType) && isCompositeType(rightFieldType)) {
+      generateRecordEqualiserCode(
+        ctx,
+        leftFieldType,
+        rightFieldType,
+        leftFieldTerm,
+        rightFieldTerm,
+        "fieldGeneratedEqualiser")
     } else {
-      val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", fieldType)
-      val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", fieldType)
-      val resultType = new BooleanType(fieldType.isNullable)
+      val left = GeneratedExpression(leftFieldTerm, leftNullTerm, "", leftFieldType)
+      val right = GeneratedExpression(rightFieldTerm, rightNullTerm, "", rightFieldType)
+      val resultType = new BooleanType(leftFieldType.isNullable || rightFieldType.isNullable)
       val gen = generateEquals(ctx, left, right, resultType)
       (gen.code, gen.resultTerm)
     }
@@ -174,3 +186,32 @@ class EqualiserCodeGenerator(fieldTypes: Array[LogicalType], classLoader: ClassL
     case _ => false
   }
 }
+
+object EqualiserCodeGenerator {
+
+  def generateRecordEqualiserCode(
+      ctx: CodeGeneratorContext,
+      leftFieldType: LogicalType,
+      rightFieldType: LogicalType,
+      leftFieldTerm: String,
+      rightFieldTerm: String,
+      generatedEqualiserName: String): (String, String) = {
+    val equaliserGenerator =
+      new EqualiserCodeGenerator(
+        getFieldTypes(leftFieldType).asScala.toArray,
+        getFieldTypes(rightFieldType).asScala.toArray,
+        ctx.classLoader)
+    val generatedEqualiser = equaliserGenerator.generateRecordEqualiser(generatedEqualiserName)
+    val generatedEqualiserTerm =
+      ctx.addReusableObject(generatedEqualiser, generatedEqualiserName)
+    val equaliserTypeTerm = classOf[RecordEqualiser].getCanonicalName
+    val equaliserTerm = newName(ctx, "equaliser")
+    ctx.addReusableMember(s"private $equaliserTypeTerm $equaliserTerm = null;")
+    ctx.addReusableInitStatement(
+      s"""
+         |$equaliserTerm = ($equaliserTypeTerm)
+         |  $generatedEqualiserTerm.newInstance(Thread.currentThread().getContextClassLoader());
+         |""".stripMargin)
+    ("", s"$equaliserTerm.equals($leftFieldTerm, $rightFieldTerm)")
+  }
+}
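
The generator now carries separate field types for the two sides, so a nullable row field can be compared with a NOT NULL row constant without forcing both sides onto one LogicalType array. A rough usage sketch under assumed types (the type arrays and the equaliser name below are illustrative, not taken from the commit):

    import org.apache.flink.table.planner.codegen.EqualiserCodeGenerator
    import org.apache.flink.table.types.logical.{IntType, LogicalType, VarCharType}

    // Left side: nullable field types of the table column; right side: the NOT NULL
    // types of a ROW constant such as ROW(1, 'str').
    val leftTypes: Array[LogicalType] = Array(new IntType(), new VarCharType(VarCharType.MAX_LENGTH))
    val rightTypes: Array[LogicalType] = Array(new IntType(false), new VarCharType(false, VarCharType.MAX_LENGTH))

    val generator = new EqualiserCodeGenerator(leftTypes, rightTypes, Thread.currentThread().getContextClassLoader)
    val equaliser = generator.generateRecordEqualiser("RowConstantEqualiser")
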
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
index 69081918fd2..7590299ba2b 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala
@@ -23,7 +23,7 @@ import org.apache.flink.table.data.binary.BinaryArrayData
 import org.apache.flink.table.data.util.MapDataUtil
 import org.apache.flink.table.data.utils.CastExecutor
 import org.apache.flink.table.data.writer.{BinaryArrayWriter, BinaryRowWriter}
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenException, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenException, EqualiserCodeGenerator, GeneratedExpression}
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
 import org.apache.flink.table.planner.codegen.GeneratedExpression.{ALWAYS_NULL, NEVER_NULL, NO_CODE}
 import org.apache.flink.table.planner.codegen.GenerateUtils._
@@ -413,6 +413,10 @@ object ScalarOperatorGens {
           resultType),
         resultType)
     }
+    // row types
+    else if (isRow(left.resultType) && canEqual) {
+      wrapExpressionIfNonEq(nonEq, generateRowComparison(ctx, left, right, resultType), resultType)
+    }
     // multiset types
     else if (isMultiset(left.resultType) && canEqual) {
       val multisetType = left.resultType.asInstanceOf[MultisetType]
@@ -1818,6 +1822,26 @@ object ScalarOperatorGens {
         (stmt, resultTerm)
     }
 
+  private def generateRowComparison(
+      ctx: CodeGeneratorContext,
+      left: GeneratedExpression,
+      right: GeneratedExpression,
+      resultType: LogicalType): GeneratedExpression = {
+    generateCallWithStmtIfArgsNotNull(ctx, resultType, Seq(left, right)) {
+      args =>
+        val leftTerm = args.head
+        val rightTerm = args(1)
+
+        EqualiserCodeGenerator.generateRecordEqualiserCode(
+          ctx,
+          left.resultType,
+          right.resultType,
+          leftTerm,
+          rightTerm,
+          "rowGeneratedEqualiser")
+    }
+  }
+
   // ------------------------------------------------------------------------------------------
 
   private def generateUnaryOperatorIfNotNull(
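
With the new row branch, generateEquals dispatches ROW operands to generateRowComparison, which in turn reuses EqualiserCodeGenerator.generateRecordEqualiserCode. A hedged sketch of driving that entry point directly (the context, row type, and term names are assumptions for illustration only):

    import org.apache.flink.configuration.Configuration
    import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
    import org.apache.flink.table.planner.codegen.calls.ScalarOperatorGens
    import org.apache.flink.table.types.logical.{BooleanType, IntType, RowType, VarCharType}

    val ctx = new CodeGeneratorContext(new Configuration, Thread.currentThread().getContextClassLoader)
    val rowType = RowType.of(new IntType(), new VarCharType(VarCharType.MAX_LENGTH))

    // "leftRow"/"rightRow" and the null terms are placeholder names, not planner output.
    val left = GeneratedExpression("leftRow", "leftIsNull", "", rowType)
    val right = GeneratedExpression("rightRow", "rightIsNull", "", rowType)

    val eq = ScalarOperatorGens.generateEquals(ctx, left, right, new BooleanType(rowType.isNullable))
    // The RecordEqualiser is registered on ctx as a reusable member; eq.resultTerm
    // references the boolean comparison result in the generated code.
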
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml
index 215a6aa8eea..20aac897cd2 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml
@@ -30,6 +30,37 @@ LogicalProject(EXPR$0=[ARRAY(_UTF-16LE'Hi':VARCHAR(2147483647) CHARACTER SET "UT
       <![CDATA[
 Calc(select=[ARRAY('Hi', 'Hello', c) AS EXPR$0])
 +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testCalcMergeWithCorrelate">
+    <Resource name="sql">
+      <![CDATA[
+SELECT a, r FROM (
+ SELECT a, random_udf(b) r FROM (
+  select a, b, c1 FROM MyTable, LATERAL TABLE(str_split(c)) AS T(c1)
+ ) t
+)
+WHERE r > 10
+]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(a=[$0], r=[$1])
++- LogicalFilter(condition=[>($1, 10)])
+   +- LogicalProject(a=[$0], r=[random_udf($1)])
+      +- LogicalProject(a=[$0], b=[$1], c1=[$3])
+         +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
+            :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
+            +- LogicalTableFunctionScan(invocation=[str_split($cor0.c)], rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Calc(select=[a, r], where=[>(r, 10)])
++- Calc(select=[a, random_udf(b) AS r])
+   +- Correlate(invocation=[str_split($cor0.c)], correlate=[table(str_split($cor0.c))], select=[a,b,c,EXPR$0], rowType=[RecordType(BIGINT a, INTEGER b, VARCHAR(2147483647) c, VARCHAR(2147483647) EXPR$0)], joinType=[INNER])
+      +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
 ]]>
     </Resource>
   </TestCase>
@@ -69,37 +100,6 @@ LogicalProject(a=[$0])
       <![CDATA[
 Calc(select=[a], where=[>(random_udf(b), 10)])
 +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
-]]>
-    </Resource>
-  </TestCase>
-  <TestCase name="testCalcMergeWithCorrelate">
-    <Resource name="sql">
-      <![CDATA[
-SELECT a, r FROM (
- SELECT a, random_udf(b) r FROM (
-  select a, b, c1 FROM MyTable, LATERAL TABLE(str_split(c)) AS T(c1)
- ) t
-)
-WHERE r > 10
-]]>
-    </Resource>
-    <Resource name="ast">
-      <![CDATA[
-LogicalProject(a=[$0], r=[$1])
-+- LogicalFilter(condition=[>($1, 10)])
-   +- LogicalProject(a=[$0], r=[random_udf($1)])
-      +- LogicalProject(a=[$0], b=[$1], c1=[$3])
-         +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{2}])
-            :- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]])
-            +- LogicalTableFunctionScan(invocation=[str_split($cor0.c)], rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
-]]>
-    </Resource>
-    <Resource name="optimized rel plan">
-       <![CDATA[
-Calc(select=[a, r], where=[>(r, 10)])
-+- Calc(select=[a, random_udf(b) AS r])
-   +- Correlate(invocation=[str_split($cor0.c)], correlate=[table(str_split($cor0.c))], select=[a,b,c,EXPR$0], rowType=[RecordType(BIGINT a, INTEGER b, VARCHAR(2147483647) c, VARCHAR(2147483647) EXPR$0)], joinType=[INNER])
-      +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
 ]]>
     </Resource>
   </TestCase>
@@ -496,6 +496,25 @@ LogicalProject(1-_./Ü=[$0], b=[$1], c=[$2])
     <Resource name="optimized exec plan">
       <![CDATA[
 LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testRowTypeEquality">
+    <Resource name="sql">
+      <![CDATA[
+SELECT my_row = ROW(1, 'str') from src
+]]>
+    </Resource>
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(EXPR$0=[=(CAST($0):RecordType(INTEGER a, VARCHAR(2147483647) CHARACTER SET "UTF-16LE" b), CAST(ROW(1, _UTF-16LE'str')):RecordType(INTEGER a, VARCHAR(2147483647) CHARACTER SET "UTF-16LE" b) NOT NULL)])
++- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[(CAST(my_row AS RecordType(INTEGER a, VARCHAR(2147483647) b)) = CAST(ROW(1, 'str') AS RecordType(INTEGER a, VARCHAR(2147483647) b))) AS EXPR$0])
++- TableSourceScan(table=[[default_catalog, default_database, src]], fields=[my_row])
 ]]>
     </Resource>
   </TestCase>
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml
index f68e26af411..60f5f30a142 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/CalcTest.xml
@@ -116,6 +116,20 @@ LogicalFilter(condition=[OR(SEARCH($1, Sarg[(-∞..1), (1..2), (2..3), (3..4), (
       <![CDATA[
 Calc(select=[a, b, c], where=[(SEARCH(b, Sarg[(-∞..1), (1..2), (2..3), (3..4), (4..5), (5..6), (6..7), (7..8), (8..9), (9..10), (10..11), (11..12), (12..13), (13..14), (14..15), (15..16), (16..17), (17..18), (18..19), (19..20), (20..21), (21..22), (22..23), (23..24), (24..25), (25..26), (26..27), (27..28), (28..29), (29..30), (30..+∞)]) OR (c <> 'xx'))])
 +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testRowTypeEquality">
+    <Resource name="ast">
+      <![CDATA[
+LogicalProject(_c0=[=($0, ROW(1, _UTF-16LE'str'))])
++- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized exec plan">
+      <![CDATA[
+Calc(select=[(my_row = ROW(1, 'str')) AS _c0])
++- TableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[my_row])
 ]]>
     </Resource>
   </TestCase>
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala
index e179be85802..58ccdf1104d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/RowTypeTest.scala
@@ -110,4 +110,11 @@ class RowTypeTest extends RowTypeTestBase {
           ))
       .withMessageContaining("Cast function cannot convert value")
   }
+
+  @Test
+  def testRowTypeEquality(): Unit = {
+    testAllApis('f2 === row(2, "foo", true), "f2 = row(2, 'foo', true)", "TRUE")
+
+    testAllApis('f3 === row(3, row(2, "foo", true)), "f3 = row(3, row(2, 'foo', true))", "TRUE")
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala
index 1c62fc054e2..7df283295db 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/CalcTest.scala
@@ -23,7 +23,7 @@ import org.apache.flink.api.scala._
 import org.apache.flink.table.api._
 import org.apache.flink.table.planner.plan.utils.MyPojo
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.NonDeterministicUdf
-import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.{JavaTableFunc1, StringSplit}
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedTableFunctions.StringSplit
 import org.apache.flink.table.planner.utils.TableTestBase
 
 import org.assertj.core.api.Assertions.assertThatExceptionOfType
@@ -217,4 +217,19 @@ class CalcTest extends TableTestBase {
         |""".stripMargin
     util.verifyRelPlan(sqlQuery)
   }
+
+  @Test
+  def testRowTypeEquality(): Unit = {
+    util.addTable(s"""
+                     |CREATE TABLE src (
+                     |  my_row ROW(a INT, b STRING)
+                     |) WITH (
+                     |  'connector' = 'values'
+                     |  )
+                     |""".stripMargin)
+
+    util.verifyExecPlan(s"""
+                           |SELECT my_row = ROW(1, 'str') from src
+                           |""".stripMargin)
+  }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala
index 554eeba0a38..e5cacf54ded 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/table/CalcTest.scala
@@ -155,4 +155,22 @@ class CalcTest extends TableTestBase {
 
     util.verifyExecPlan(resultTable)
   }
+
+  @Test
+  def testRowTypeEquality(): Unit = {
+    val util = streamTestUtil()
+    util.addTable(s"""
+                     |CREATE TABLE MyTable (
+                     |  my_row ROW(a INT, b STRING)
+                     |) WITH (
+                     |  'connector' = 'values'
+                     |  )
+                     |""".stripMargin)
+
+    val resultTable = util.tableEnv
+      .from("MyTable")
+      .select('my_row === row(1, "str"))
+
+    util.verifyExecPlan(resultTable)
+  }
 }
