This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.9 by this push:
     new 632a6db  [FLINK-14200][table] Fix NPE for Temporal Table Function Join when left side is a query instead of a source (#10782)
632a6db is described below

commit 632a6dbebb81c950f0c72da64bdc7e5112820a9d
Author: Jark Wu <[email protected]>
AuthorDate: Tue Jan 7 17:48:46 2020 +0800

    [FLINK-14200][table] Fix NPE for Temporal Table Function Join when left side is a query instead of a source (#10782)
---
 .../catalog/QueryOperationCatalogViewTable.java    |  2 +-
 .../table/planner/calcite/FlinkRelBuilder.scala    |  4 +-
 ...relateToJoinFromTemporalTableFunctionRule.scala | 17 +++-
 .../plan/stream/sql/MiniBatchIntervalInferTest.xml | 42 ++++++++++
 .../plan/stream/sql/join/TemporalJoinTest.xml      | 27 ++++++
 .../stream/sql/MiniBatchIntervalInferTest.scala    |  3 +-
 .../plan/stream/sql/join/TemporalJoinTest.scala    | 14 ++++
 .../runtime/stream/sql/TemporalJoinITCase.scala    | 93 +++++++++++++++++++--
 .../catalog/QueryOperationCatalogViewTable.java    |  2 +-
 .../flink/table/calcite/FlinkRelBuilder.scala      |  4 +-
 .../LogicalCorrelateToTemporalTableJoinRule.scala  | 17 +++-
 .../api/stream/sql/TemporalTableJoinTest.scala     | 29 +++++++
 .../runtime/stream/sql/TemporalJoinITCase.scala    | 95 ++++++++++++++++++++--
 13 files changed, 322 insertions(+), 27 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
index 1251dc4..5d7be6a 100644
--- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
+++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
@@ -79,7 +79,7 @@ public class QueryOperationCatalogViewTable extends FlinkTable implements Transl
 
        @Override
        public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable 
relOptTable) {
-               FlinkRelBuilder relBuilder = 
FlinkRelBuilder.of(context.getCluster(), relOptTable);
+               FlinkRelBuilder relBuilder = 
FlinkRelBuilder.of(context.getCluster(), relOptTable.getRelOptSchema());
 
                RelNode relNode = 
relBuilder.queryOperation(catalogView.getQueryOperation()).build();
                return RelOptUtil.createCastRel(relNode, 
rowType.apply(relBuilder.getTypeFactory()), false);
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
index 31f8079..e2dff7f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
@@ -159,11 +159,11 @@ object FlinkRelBuilder {
     }
   }
 
-  def of(cluster: RelOptCluster, relTable: RelOptTable): FlinkRelBuilder = {
+  def of(cluster: RelOptCluster, relOptSchema: RelOptSchema): FlinkRelBuilder 
= {
     val clusterContext = cluster.getPlanner.getContext
     new FlinkRelBuilder(
       clusterContext,
       cluster,
-      relTable.getRelOptSchema)
+      relOptSchema)
   }
 }
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
index 62a4872..b40a7dd 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
@@ -32,8 +32,9 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{hasRoot, is
 import org.apache.flink.util.Preconditions.checkState
 
 import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.plan.hep.HepRelVertex
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptSchema}
+import org.apache.calcite.rel.{BiRel, RelNode, SingleRel}
 import org.apache.calcite.rel.core.{JoinRelType, TableFunctionScan}
 import org.apache.calcite.rel.logical.LogicalCorrelate
 import org.apache.calcite.rex._
@@ -95,7 +96,7 @@ class LogicalCorrelateToJoinFromTemporalTableFunctionRule
           .getUnderlyingHistoryTable
         val rexBuilder = cluster.getRexBuilder
 
-        val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable)
+        val relBuilder = FlinkRelBuilder.of(cluster, getRelOptSchema(leftNode))
         val temporalTable: RelNode = 
relBuilder.queryOperation(underlyingHistoryTable).build()
         // expand QueryOperationCatalogViewTable in TableScan
         val shuttle = new ExpandTableScanShuttle
@@ -145,6 +146,16 @@ class LogicalCorrelateToJoinFromTemporalTableFunctionRule
     rexBuilder.makeInputRef(
       rightDataTypeField.getType, rightReferencesOffset + 
rightDataTypeField.getIndex)
   }
+
+  /**
+    * Gets [[RelOptSchema]] from the leaf [[RelNode]] which holds a non-null [[RelOptSchema]].
+    */
+  private def getRelOptSchema(relNode: RelNode): RelOptSchema = relNode match {
+    case hep: HepRelVertex => getRelOptSchema(hep.getCurrentRel)
+    case single: SingleRel => getRelOptSchema(single.getInput)
+    case bi: BiRel => getRelOptSchema(bi.getLeft)
+    case _ => relNode.getTable.getRelOptSchema
+  }
 }
 
 object LogicalCorrelateToJoinFromTemporalTableFunctionRule {
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
index b82a3d5..0b5f851 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
@@ -448,6 +448,48 @@ GlobalGroupAggregate(groupBy=[cnt], select=[cnt, COUNT(count$0) AS EXPR$1])
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testTemporalTableFunctionJoinWithMiniBatch">
+    <Resource name="sql">
+      <![CDATA[
+ SELECT r_a, COUNT(o_a)
+ FROM (
+   SELECT o.a as o_a, r.a as r_a
+   FROM Orders As o,
+   LATERAL TABLE (Rates(o.rowtime)) as r
+   WHERE o.b = r.b
+ )
+ GROUP BY r_a
+      ]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1)])
++- LogicalProject(r_a=[$5], o_a=[$0])
+   +- LogicalFilter(condition=[=($1, $6)])
+      +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{4}])
+         :- LogicalWatermarkAssigner(fields=[a, b, c, proctime, rowtime], 
rowtimeField=[rowtime], watermarkDelay=[0])
+         :  +- LogicalTableScan(table=[[default_catalog, default_database, 
MyTable1]])
+         +- LogicalTableFunctionScan(invocation=[Rates($cor0.rowtime)], 
rowType=[RecordType(INTEGER a, VARCHAR(2147483647) b, BIGINT c, TIME 
ATTRIBUTE(PROCTIME) proctime, TIME ATTRIBUTE(ROWTIME) rowtime)], 
elementType=[class [Ljava.lang.Object;])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+GlobalGroupAggregate(groupBy=[r_a], select=[r_a, COUNT(count$0) AS EXPR$1])
++- Exchange(distribution=[hash[r_a]])
+   +- LocalGroupAggregate(groupBy=[r_a], select=[r_a, COUNT(o_a) AS count$0])
+      +- Calc(select=[a0 AS r_a, a AS o_a])
+         +- TemporalJoin(joinType=[InnerJoin], 
where=[AND(__TEMPORAL_JOIN_CONDITION(rowtime, rowtime0, b0), =(b, b0))], 
select=[a, b, rowtime, a0, b0, rowtime0])
+            :- Exchange(distribution=[hash[b]])
+            :  +- Calc(select=[a, b, rowtime])
+            :     +- WatermarkAssigner(fields=[a, b, c, proctime, rowtime], 
rowtimeField=[rowtime], watermarkDelay=[0], miniBatchInterval=[Rowtime, 1000ms])
+            :        +- DataStreamScan(table=[[default_catalog, 
default_database, MyTable1]], fields=[a, b, c, proctime, rowtime])
+            +- Exchange(distribution=[hash[b]])
+               +- Calc(select=[a, b, rowtime])
+                  +- WatermarkAssigner(fields=[a, b, c, proctime, rowtime], 
rowtimeField=[rowtime], watermarkDelay=[0], miniBatchInterval=[Rowtime, 1000ms])
+                     +- DataStreamScan(table=[[default_catalog, 
default_database, MyTable2]], fields=[a, b, c, proctime, rowtime])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testWindowJoinWithMiniBatch">
     <Resource name="sql">
       <![CDATA[
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
index 9533bb2..31b93bf 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
@@ -74,6 +74,33 @@ Join(joinType=[InnerJoin], where=[=(t3_secondary_key, secondary_key)], select=[r
 ]]>
     </Resource>
   </TestCase>
+  <TestCase name="testJoinOnQueryLeft">
+    <Resource name="sql">
+      <![CDATA[SELECT o_amount * rate as rate FROM Orders2 AS o, LATERAL TABLE 
(Rates(o.o_rowtime)) AS r WHERE currency = o_currency]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(rate=[*($0, $4)])
++- LogicalFilter(condition=[=($3, $1)])
+   +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{2}])
+      :- LogicalProject(o_amount=[$0], o_currency=[$1], o_rowtime=[$2])
+      :  +- LogicalFilter(condition=[>($0, 1000)])
+      :     +- LogicalTableScan(table=[[default_catalog, default_database, 
Orders]])
+      +- LogicalTableFunctionScan(invocation=[Rates($cor0.o_rowtime)], 
rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME 
ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+Calc(select=[*(o_amount, rate) AS rate])
++- TemporalJoin(joinType=[InnerJoin], 
where=[AND(__TEMPORAL_JOIN_CONDITION(o_rowtime, rowtime, currency), =(currency, 
o_currency))], select=[o_amount, o_currency, o_rowtime, currency, rate, 
rowtime])
+   :- Exchange(distribution=[hash[o_currency]])
+   :  +- Calc(select=[o_amount, o_currency, o_rowtime], where=[>(o_amount, 
1000)])
+   :     +- DataStreamScan(table=[[default_catalog, default_database, 
Orders]], fields=[o_amount, o_currency, o_rowtime])
+   +- Exchange(distribution=[hash[currency]])
+      +- DataStreamScan(table=[[default_catalog, default_database, 
RatesHistory]], fields=[currency, rate, rowtime])
+]]>
+    </Resource>
+  </TestCase>
   <TestCase name="testSimpleProctimeJoin">
     <Resource name="sql">
       <![CDATA[SELECT o_amount * rate as rate FROM ProctimeOrders AS o, 
LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r WHERE currency = o_currency]]>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
index 9de3886..299c440 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
@@ -150,8 +150,7 @@ class MiniBatchIntervalInferTest extends TableTestBase {
     util.verifyPlan(sql)
   }
 
-  @Test(expected = classOf[NullPointerException])
-  // TODO remove the exception after TableImpl implements 
createTemporalTableFunction
+  @Test
   def testTemporalTableFunctionJoinWithMiniBatch(): Unit = {
     util.addTableWithWatermark("Orders", util.tableEnv.scan("MyTable1"), 
"rowtime", 0)
     util.addTableWithWatermark("RatesHistory", util.tableEnv.scan("MyTable2"), 
"rowtime", 0)
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
index 4391d1b..0bb49ef 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
@@ -73,6 +73,20 @@ class TemporalJoinTest extends TableTestBase {
     util.verifyPlan(sqlQuery)
   }
 
+  @Test
+  def testJoinOnQueryLeft(): Unit = {
+    val orders = util.tableEnv.sqlQuery("SELECT * FROM Orders WHERE o_amount > 
1000")
+    util.tableEnv.registerTable("Orders2", orders)
+
+    val sqlQuery = "SELECT " +
+      "o_amount * rate as rate " +
+      "FROM Orders2 AS o, " +
+      "LATERAL TABLE (Rates(o.o_rowtime)) AS r " +
+      "WHERE currency = o_currency"
+
+    util.verifyPlan(sqlQuery)
+  }
+
   /**
     * Test versioned joins with more complicated query.
     * Important thing here is that we have complex OR join condition
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
index 3e7fcec..3b2e204 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
@@ -131,11 +131,11 @@ class TemporalJoinITCase(state: StateBackendMode)
 
     val orders = env
       .fromCollection(ordersData)
-      .assignTimestampsAndWatermarks(new TimestampExtractor[Long, String]())
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, 
Timestamp)]())
       .toTable(tEnv, 'amount, 'currency, 'rowtime.rowtime)
     val ratesHistory = env
       .fromCollection(ratesHistoryData)
-      .assignTimestampsAndWatermarks(new TimestampExtractor[String, Long]())
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, 
Timestamp)]())
       .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
 
     tEnv.registerTable("Orders", orders)
@@ -158,11 +158,92 @@ class TemporalJoinITCase(state: StateBackendMode)
 
     assertEquals(expectedOutput, sink.getAppendResults.toSet)
   }
+
+  @Test
+  def testNestedTemporalJoin(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    val sqlQuery =
+      """
+        |SELECT
+        |  o.orderId,
+        |  (o.amount * p.price * r.rate) as total_price
+        |FROM
+        |  Orders AS o,
+        |  LATERAL TABLE (Prices(o.rowtime)) AS p,
+        |  LATERAL TABLE (Rates(o.rowtime)) AS r
+        |WHERE
+        |  o.productId = p.productId AND
+        |  r.currency = p.currency
+        |""".stripMargin
+
+    val ordersData = new mutable.MutableList[(Long, String, Long, Timestamp)]
+    ordersData.+=((1L, "A1", 2L, new Timestamp(2L)))
+    ordersData.+=((2L, "A2", 1L, new Timestamp(3L)))
+    ordersData.+=((3L, "A4", 50L, new Timestamp(4L)))
+    ordersData.+=((4L, "A1", 3L, new Timestamp(5L)))
+    val orders = env
+      .fromCollection(ordersData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, 
Long, Timestamp)]())
+      .toTable(tEnv, 'orderId, 'productId, 'amount, 'rowtime.rowtime)
+
+    val ratesHistoryData = new mutable.MutableList[(String, Long, Timestamp)]
+    ratesHistoryData.+=(("US Dollar", 102L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Euro", 114L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Yen", 1L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Euro", 116L, new Timestamp(5L)))
+    ratesHistoryData.+=(("Euro", 119L, new Timestamp(7L)))
+    val ratesHistory = env
+      .fromCollection(ratesHistoryData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, 
Timestamp)]())
+      .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
+
+    val pricesHistoryData = new mutable.MutableList[(String, String, Double, 
Timestamp)]
+    pricesHistoryData.+=(("A2", "US Dollar", 10.2D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.4D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A4", "Yen", 1D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.6D, new Timestamp(5L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.9D, new Timestamp(7L)))
+    val pricesHistory = env
+      .fromCollection(pricesHistoryData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, String, 
Double, Timestamp)]())
+      .toTable(tEnv, 'productId, 'currency, 'price, 'rowtime.rowtime)
+
+    tEnv.registerTable("Orders", orders)
+    tEnv.registerTable("RatesHistory", ratesHistory)
+    tEnv.registerFunction(
+      "Rates",
+      ratesHistory.createTemporalTableFunction("rowtime", "currency"))
+    tEnv.registerFunction(
+      "Prices",
+      pricesHistory.createTemporalTableFunction("rowtime", "productId"))
+
+    tEnv.registerTable("TemporalJoinResult", tEnv.sqlQuery(sqlQuery))
+
+    // Scan from registered table to test for interplay between
+    // LogicalCorrelateToTemporalTableJoinRule and TableScanRule
+    val result = tEnv.scan("TemporalJoinResult").toAppendStream[Row]
+    val sink = new TestingAppendSink
+    result.addSink(sink)
+    env.execute()
+
+    val expected = List(
+      s"1,${2 * 114 * 11.4}",
+      s"2,${1 * 102 * 10.2}",
+      s"3,${50 * 1 * 1.0}",
+      s"4,${3 * 116 * 11.6}")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
 }
 
-class TimestampExtractor[T1, T2]
-  extends BoundedOutOfOrdernessTimestampExtractor[(T1, T2, 
Timestamp)](Time.seconds(10))  {
-  override def extractTimestamp(element: (T1, T2, Timestamp)): Long = {
-    element._3.getTime
+class TimestampExtractor[T <: Product]
+  extends BoundedOutOfOrdernessTimestampExtractor[T](Time.seconds(10))  {
+  override def extractTimestamp(element: T): Long = element match {
+    case (_, _, ts: Timestamp) => ts.getTime
+    case (_, _, _, ts: Timestamp) => ts.getTime
+    case _ => throw new IllegalArgumentException(
+      "Expected the last element in a tuple to be of a Timestamp type.")
   }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
index 9dd2690..cbb3c5d 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
@@ -60,7 +60,7 @@ public class QueryOperationCatalogViewTable extends AbstractTable implements Tra
 
        @Override
        public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable 
relOptTable) {
-               FlinkRelBuilder relBuilder = 
FlinkRelBuilder.of(context.getCluster(), relOptTable);
+               FlinkRelBuilder relBuilder = 
FlinkRelBuilder.of(context.getCluster(), relOptTable.getRelOptSchema());
 
                RelNode relNode = 
relBuilder.tableOperation(catalogView.getQueryOperation()).build();
                return RelOptUtil.createCastRel(relNode, 
rowType.apply(relBuilder.getTypeFactory()), false);
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
index 8368667..0a6cef7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
@@ -120,13 +120,13 @@ object FlinkRelBuilder {
     */
   case class NamedWindowProperty(name: String, property: WindowProperty)
 
-  def of(cluster: RelOptCluster, relTable: RelOptTable): FlinkRelBuilder = {
+  def of(cluster: RelOptCluster, relOptSchema: RelOptSchema): FlinkRelBuilder 
= {
     val clusterContext = cluster.getPlanner.getContext
 
     new FlinkRelBuilder(
       clusterContext,
       cluster,
-      relTable.getRelOptSchema,
+      relOptSchema,
       clusterContext.unwrap(classOf[ExpressionBridge[PlannerExpression]]))
   }
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
index 72e6290..e87e356 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
@@ -19,8 +19,9 @@
 package org.apache.flink.table.plan.rules.logical
 
 import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.plan.hep.HepRelVertex
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptSchema}
+import org.apache.calcite.rel.{BiRel, RelNode, SingleRel}
 import org.apache.calcite.rel.core.TableFunctionScan
 import org.apache.calcite.rel.logical.LogicalCorrelate
 import org.apache.calcite.rex._
@@ -88,7 +89,7 @@ class LogicalCorrelateToTemporalTableJoinRule
           .getUnderlyingHistoryTable
         val rexBuilder = cluster.getRexBuilder
 
-        val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable)
+        val relBuilder = FlinkRelBuilder.of(cluster, getRelOptSchema(leftNode))
         val rightNode: RelNode = 
relBuilder.tableOperation(underlyingHistoryTable).build()
 
         val rightTimeIndicatorExpression = createRightExpression(
@@ -139,6 +140,16 @@ class LogicalCorrelateToTemporalTableJoinRule
     rexBuilder.makeInputRef(
       rightDataTypeField.getType, rightReferencesOffset + 
rightDataTypeField.getIndex)
   }
+
+  /**
+    * Gets [[RelOptSchema]] from the leaf [[RelNode]] which holds a non-null [[RelOptSchema]].
+    */
+  private def getRelOptSchema(relNode: RelNode): RelOptSchema = relNode match {
+    case hep: HepRelVertex => getRelOptSchema(hep.getCurrentRel)
+    case single: SingleRel => getRelOptSchema(single.getInput)
+    case bi: BiRel => getRelOptSchema(bi.getLeft)
+    case _ => relNode.getTable.getRelOptSchema
+  }
 }
 
 object LogicalCorrelateToTemporalTableJoinRule {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
index 6051249..05874bc 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
@@ -74,6 +74,35 @@ class TemporalTableJoinTest extends TableTestBase {
     util.verifySql(sqlQuery, getExpectedSimpleProctimeJoinPlan())
   }
 
+  @Test
+  def testJoinOnQueryLeft(): Unit = {
+    val sqlQuery = "SELECT " +
+      "o_amount * rate as rate " +
+      "FROM (SELECT * FROM Orders WHERE o_amount > 1000) AS o, " +
+      "LATERAL TABLE (Rates(o.o_rowtime)) AS r " +
+      "WHERE currency = o_currency"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      binaryNode(
+        "DataStreamTemporalTableJoin",
+        unaryNode("DataStreamCalc",
+          streamTableNode(orders),
+          term("select", "o_amount, o_currency, o_rowtime"),
+          term("where", ">(o_amount, 1000)")),
+        streamTableNode(ratesHistory),
+        term("where",
+          "AND(" +
+            s"${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, 
currency), " +
+            "=(currency, o_currency))"),
+        term("join", "o_amount", "o_currency", "o_rowtime", "currency", 
"rate", "rowtime"),
+        term("joinType", "InnerJoin")
+      ),
+      term("select", "*(o_amount, rate) AS rate")
+    )
+    util.verifySql(sqlQuery, expected)
+  }
+
   /**
     * Test versioned joins with more complicated query.
     * Important thing here is that we have complex OR join condition
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
index 463f3f06..6466e1e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
@@ -19,7 +19,6 @@
 package org.apache.flink.table.runtime.stream.sql
 
 import java.sql.Timestamp
-
 import org.apache.flink.api.scala._
 import org.apache.flink.streaming.api.TimeCharacteristic
 import 
org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor
@@ -130,11 +129,11 @@ class TemporalJoinITCase extends StreamingWithStateTestBase {
 
     val orders = env
       .fromCollection(ordersData)
-      .assignTimestampsAndWatermarks(new TimestampExtractor[Long, String]())
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, 
Timestamp)]())
       .toTable(tEnv, 'amount, 'currency, 'rowtime.rowtime)
     val ratesHistory = env
       .fromCollection(ratesHistoryData)
-      .assignTimestampsAndWatermarks(new TimestampExtractor[String, Long]())
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, 
Timestamp)]())
       .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
 
     tEnv.registerTable("Orders", orders)
@@ -153,11 +152,93 @@ class TemporalJoinITCase extends StreamingWithStateTestBase {
 
     assertEquals(expectedOutput, StreamITCase.testResults.toSet)
   }
+
+  @Test
+  def testNestedTemporalJoin(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = StreamTableEnvironment.create(env)
+    env.setStateBackend(getStateBackend)
+    StreamITCase.clear
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    val sqlQuery =
+      """
+        |SELECT
+        |  o.orderId,
+        |  (o.amount * p.price * r.rate) as total_price
+        |FROM
+        |  Orders AS o,
+        |  LATERAL TABLE (Prices(o.rowtime)) AS p,
+        |  LATERAL TABLE (Rates(o.rowtime)) AS r
+        |WHERE
+        |  o.productId = p.productId AND
+        |  r.currency = p.currency
+        |""".stripMargin
+
+    val ordersData = new mutable.MutableList[(Long, String, Long, Timestamp)]
+    ordersData.+=((1L, "A1", 2L, new Timestamp(2L)))
+    ordersData.+=((2L, "A2", 1L, new Timestamp(3L)))
+    ordersData.+=((3L, "A4", 50L, new Timestamp(4L)))
+    ordersData.+=((4L, "A1", 3L, new Timestamp(5L)))
+    val orders = env
+      .fromCollection(ordersData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, 
Long, Timestamp)]())
+      .toTable(tEnv, 'orderId, 'productId, 'amount, 'rowtime.rowtime)
+
+    val ratesHistoryData = new mutable.MutableList[(String, Long, Timestamp)]
+    ratesHistoryData.+=(("US Dollar", 102L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Euro", 114L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Yen", 1L, new Timestamp(1L)))
+    ratesHistoryData.+=(("Euro", 116L, new Timestamp(5L)))
+    ratesHistoryData.+=(("Euro", 119L, new Timestamp(7L)))
+    val ratesHistory = env
+      .fromCollection(ratesHistoryData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, 
Timestamp)]())
+      .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
+
+    val pricesHistoryData = new mutable.MutableList[(String, String, Double, 
Timestamp)]
+    pricesHistoryData.+=(("A2", "US Dollar", 10.2D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.4D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A4", "Yen", 1D, new Timestamp(1L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.6D, new Timestamp(5L)))
+    pricesHistoryData.+=(("A1", "Euro", 11.9D, new Timestamp(7L)))
+    val pricesHistory = env
+      .fromCollection(pricesHistoryData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, String, 
Double, Timestamp)]())
+      .toTable(tEnv, 'productId, 'currency, 'price, 'rowtime.rowtime)
+
+    tEnv.registerTable("Orders", orders)
+    tEnv.registerTable("RatesHistory", ratesHistory)
+    tEnv.registerFunction(
+      "Rates",
+      ratesHistory.createTemporalTableFunction("rowtime", "currency"))
+    tEnv.registerFunction(
+      "Prices",
+      pricesHistory.createTemporalTableFunction("rowtime", "productId"))
+
+    tEnv.registerTable("TemporalJoinResult", tEnv.sqlQuery(sqlQuery))
+
+    // Scan from registered table to test for interplay between
+    // LogicalCorrelateToTemporalTableJoinRule and TableScanRule
+    val result = tEnv.scan("TemporalJoinResult").toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List(
+      s"1,${2 * 114 * 11.4}",
+      s"2,${1 * 102 * 10.2}",
+      s"3,${50 * 1 * 1.0}",
+      s"4,${3 * 116 * 11.6}")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
 }
 
-class TimestampExtractor[T1, T2]
-  extends BoundedOutOfOrdernessTimestampExtractor[(T1, T2, 
Timestamp)](Time.seconds(10))  {
-  override def extractTimestamp(element: (T1, T2, Timestamp)): Long = {
-    element._3.getTime
+class TimestampExtractor[T <: Product]
+  extends BoundedOutOfOrdernessTimestampExtractor[T](Time.seconds(10))  {
+  override def extractTimestamp(element: T): Long = element match {
+    case (_, _, ts: Timestamp) => ts.getTime
+    case (_, _, _, ts: Timestamp) => ts.getTime
+    case _ => throw new IllegalArgumentException(
+      "Expected the last element in a tuple to be of a Timestamp type.")
   }
 }

Reply via email to