This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch release-1.9
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push:
new 632a6db [FLINK-14200][table] Fix NPE for Temporal Table Function Join
when left side is a query instead of a source (#10782)
632a6db is described below
commit 632a6dbebb81c950f0c72da64bdc7e5112820a9d
Author: Jark Wu <[email protected]>
AuthorDate: Tue Jan 7 17:48:46 2020 +0800
[FLINK-14200][table] Fix NPE for Temporal Table Function Join when left
side is a query instead of a source (#10782)
---
.../catalog/QueryOperationCatalogViewTable.java | 2 +-
.../table/planner/calcite/FlinkRelBuilder.scala | 4 +-
...relateToJoinFromTemporalTableFunctionRule.scala | 17 +++-
.../plan/stream/sql/MiniBatchIntervalInferTest.xml | 42 ++++++++++
.../plan/stream/sql/join/TemporalJoinTest.xml | 27 ++++++
.../stream/sql/MiniBatchIntervalInferTest.scala | 3 +-
.../plan/stream/sql/join/TemporalJoinTest.scala | 14 ++++
.../runtime/stream/sql/TemporalJoinITCase.scala | 93 +++++++++++++++++++--
.../catalog/QueryOperationCatalogViewTable.java | 2 +-
.../flink/table/calcite/FlinkRelBuilder.scala | 4 +-
.../LogicalCorrelateToTemporalTableJoinRule.scala | 17 +++-
.../api/stream/sql/TemporalTableJoinTest.scala | 29 +++++++
.../runtime/stream/sql/TemporalJoinITCase.scala | 95 ++++++++++++++++++++--
13 files changed, 322 insertions(+), 27 deletions(-)
diff --git
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
index 1251dc4..5d7be6a 100644
---
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
+++
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/catalog/QueryOperationCatalogViewTable.java
@@ -79,7 +79,7 @@ public class QueryOperationCatalogViewTable extends
FlinkTable implements Transl
@Override
public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable
relOptTable) {
- FlinkRelBuilder relBuilder =
FlinkRelBuilder.of(context.getCluster(), relOptTable);
+ FlinkRelBuilder relBuilder =
FlinkRelBuilder.of(context.getCluster(), relOptTable.getRelOptSchema());
RelNode relNode =
relBuilder.queryOperation(catalogView.getQueryOperation()).build();
return RelOptUtil.createCastRel(relNode,
rowType.apply(relBuilder.getTypeFactory()), false);
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
index 31f8079..e2dff7f 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/FlinkRelBuilder.scala
@@ -159,11 +159,11 @@ object FlinkRelBuilder {
}
}
- def of(cluster: RelOptCluster, relTable: RelOptTable): FlinkRelBuilder = {
+  /**
+   * Creates a [[FlinkRelBuilder]] for the given cluster.
+   *
+   * The planner context of `cluster` becomes the builder's context and `relOptSchema` is
+   * handed to the builder as its schema. Accepting the schema directly (instead of a
+   * `RelOptTable`) lets callers create a builder even when their input is not a table scan.
+   *
+   * @param cluster      cluster the created relational expressions belong to
+   * @param relOptSchema schema passed to the builder
+   */
+ def of(cluster: RelOptCluster, relOptSchema: RelOptSchema): FlinkRelBuilder
= {
val clusterContext = cluster.getPlanner.getContext
new FlinkRelBuilder(
clusterContext,
cluster,
- relTable.getRelOptSchema)
+ relOptSchema)
}
}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
index 62a4872..b40a7dd 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
@@ -32,8 +32,9 @@ import
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{hasRoot, is
import org.apache.flink.util.Preconditions.checkState
import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.plan.hep.HepRelVertex
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptSchema}
+import org.apache.calcite.rel.{BiRel, RelNode, SingleRel}
import org.apache.calcite.rel.core.{JoinRelType, TableFunctionScan}
import org.apache.calcite.rel.logical.LogicalCorrelate
import org.apache.calcite.rex._
@@ -95,7 +96,7 @@ class LogicalCorrelateToJoinFromTemporalTableFunctionRule
.getUnderlyingHistoryTable
val rexBuilder = cluster.getRexBuilder
- val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable)
+ val relBuilder = FlinkRelBuilder.of(cluster, getRelOptSchema(leftNode))
val temporalTable: RelNode =
relBuilder.queryOperation(underlyingHistoryTable).build()
// expand QueryOperationCatalogViewTable in TableScan
val shuttle = new ExpandTableScanShuttle
@@ -145,6 +146,16 @@ class LogicalCorrelateToJoinFromTemporalTableFunctionRule
rexBuilder.makeInputRef(
rightDataTypeField.getType, rightReferencesOffset +
rightDataTypeField.getIndex)
}
+
+  /**
+   * Gets the [[RelOptSchema]] from the input tree of `relNode`.
+   *
+   * The left side of a temporal table function join is not necessarily a table scan: it can be
+   * an arbitrary query (e.g. a filter or projection over a registered table) whose own table is
+   * null. The tree is therefore searched depth-first, left input before right input, for the
+   * first node that holds a non-null [[RelOptSchema]].
+   *
+   * @throws IllegalStateException if no node in the tree holds a [[RelOptSchema]]
+   */
+  private def getRelOptSchema(relNode: RelNode): RelOptSchema =
+    findRelOptSchema(relNode).getOrElse(
+      throw new IllegalStateException(
+        s"Could not find a RelOptSchema in the input tree of ${relNode.getRelTypeName}."))
+
+  /** Depth-first search for a non-null [[RelOptSchema]]; `None` if the tree holds none. */
+  private def findRelOptSchema(relNode: RelNode): Option[RelOptSchema] = relNode match {
+    // Hep planner vertices wrap the actual node.
+    case hep: HepRelVertex => findRelOptSchema(hep.getCurrentRel)
+    case single: SingleRel => findRelOptSchema(single.getInput)
+    case bi: BiRel => findRelOptSchema(bi.getLeft).orElse(findRelOptSchema(bi.getRight))
+    case _ =>
+      // Leaves such as table scans carry the table themselves; other multi-input nodes
+      // (e.g. set operations, which are neither SingleRel nor BiRel) are searched input by input.
+      val inputs = relNode.getInputs
+      (0 until inputs.size()).iterator
+        .map(i => findRelOptSchema(inputs.get(i)))
+        .collectFirst { case Some(schema) => schema }
+        .orElse(Option(relNode.getTable).flatMap(table => Option(table.getRelOptSchema)))
+  }
}
object LogicalCorrelateToJoinFromTemporalTableFunctionRule {
diff --git
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
index b82a3d5..0b5f851 100644
---
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
+++
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.xml
@@ -448,6 +448,48 @@ GlobalGroupAggregate(groupBy=[cnt], select=[cnt,
COUNT(count$0) AS EXPR$1])
]]>
</Resource>
</TestCase>
+ <TestCase name="testTemporalTableFunctionJoinWithMiniBatch">
+ <Resource name="sql">
+ <![CDATA[
+ SELECT r_a, COUNT(o_a)
+ FROM (
+ SELECT o.a as o_a, r.a as r_a
+ FROM Orders As o,
+ LATERAL TABLE (Rates(o.rowtime)) as r
+ WHERE o.b = r.b
+ )
+ GROUP BY r_a
+ ]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[COUNT($1)])
++- LogicalProject(r_a=[$5], o_a=[$0])
+ +- LogicalFilter(condition=[=($1, $6)])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{4}])
+ :- LogicalWatermarkAssigner(fields=[a, b, c, proctime, rowtime],
rowtimeField=[rowtime], watermarkDelay=[0])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
MyTable1]])
+ +- LogicalTableFunctionScan(invocation=[Rates($cor0.rowtime)],
rowType=[RecordType(INTEGER a, VARCHAR(2147483647) b, BIGINT c, TIME
ATTRIBUTE(PROCTIME) proctime, TIME ATTRIBUTE(ROWTIME) rowtime)],
elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+GlobalGroupAggregate(groupBy=[r_a], select=[r_a, COUNT(count$0) AS EXPR$1])
++- Exchange(distribution=[hash[r_a]])
+ +- LocalGroupAggregate(groupBy=[r_a], select=[r_a, COUNT(o_a) AS count$0])
+ +- Calc(select=[a0 AS r_a, a AS o_a])
+ +- TemporalJoin(joinType=[InnerJoin],
where=[AND(__TEMPORAL_JOIN_CONDITION(rowtime, rowtime0, b0), =(b, b0))],
select=[a, b, rowtime, a0, b0, rowtime0])
+ :- Exchange(distribution=[hash[b]])
+ : +- Calc(select=[a, b, rowtime])
+ : +- WatermarkAssigner(fields=[a, b, c, proctime, rowtime],
rowtimeField=[rowtime], watermarkDelay=[0], miniBatchInterval=[Rowtime, 1000ms])
+ : +- DataStreamScan(table=[[default_catalog,
default_database, MyTable1]], fields=[a, b, c, proctime, rowtime])
+ +- Exchange(distribution=[hash[b]])
+ +- Calc(select=[a, b, rowtime])
+ +- WatermarkAssigner(fields=[a, b, c, proctime, rowtime],
rowtimeField=[rowtime], watermarkDelay=[0], miniBatchInterval=[Rowtime, 1000ms])
+ +- DataStreamScan(table=[[default_catalog,
default_database, MyTable2]], fields=[a, b, c, proctime, rowtime])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testWindowJoinWithMiniBatch">
<Resource name="sql">
<![CDATA[
diff --git
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
index 9533bb2..31b93bf 100644
---
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
+++
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.xml
@@ -74,6 +74,33 @@ Join(joinType=[InnerJoin], where=[=(t3_secondary_key,
secondary_key)], select=[r
]]>
</Resource>
</TestCase>
+ <TestCase name="testJoinOnQueryLeft">
+ <Resource name="sql">
+ <![CDATA[SELECT o_amount * rate as rate FROM Orders2 AS o, LATERAL TABLE
(Rates(o.o_rowtime)) AS r WHERE currency = o_currency]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(rate=[*($0, $4)])
++- LogicalFilter(condition=[=($3, $1)])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{2}])
+ :- LogicalProject(o_amount=[$0], o_currency=[$1], o_rowtime=[$2])
+ : +- LogicalFilter(condition=[>($0, 1000)])
+ : +- LogicalTableScan(table=[[default_catalog, default_database,
Orders]])
+ +- LogicalTableFunctionScan(invocation=[Rates($cor0.o_rowtime)],
rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME
ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[*(o_amount, rate) AS rate])
++- TemporalJoin(joinType=[InnerJoin],
where=[AND(__TEMPORAL_JOIN_CONDITION(o_rowtime, rowtime, currency), =(currency,
o_currency))], select=[o_amount, o_currency, o_rowtime, currency, rate,
rowtime])
+ :- Exchange(distribution=[hash[o_currency]])
+ : +- Calc(select=[o_amount, o_currency, o_rowtime], where=[>(o_amount,
1000)])
+ : +- DataStreamScan(table=[[default_catalog, default_database,
Orders]], fields=[o_amount, o_currency, o_rowtime])
+ +- Exchange(distribution=[hash[currency]])
+ +- DataStreamScan(table=[[default_catalog, default_database,
RatesHistory]], fields=[currency, rate, rowtime])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testSimpleProctimeJoin">
<Resource name="sql">
<![CDATA[SELECT o_amount * rate as rate FROM ProctimeOrders AS o,
LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r WHERE currency = o_currency]]>
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
index 9de3886..299c440 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/MiniBatchIntervalInferTest.scala
@@ -150,8 +150,7 @@ class MiniBatchIntervalInferTest extends TableTestBase {
util.verifyPlan(sql)
}
- @Test(expected = classOf[NullPointerException])
- // TODO remove the exception after TableImpl implements
createTemporalTableFunction
+ @Test
def testTemporalTableFunctionJoinWithMiniBatch(): Unit = {
util.addTableWithWatermark("Orders", util.tableEnv.scan("MyTable1"),
"rowtime", 0)
util.addTableWithWatermark("RatesHistory", util.tableEnv.scan("MyTable2"),
"rowtime", 0)
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
index 4391d1b..0bb49ef 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/TemporalJoinTest.scala
@@ -73,6 +73,20 @@ class TemporalJoinTest extends TableTestBase {
util.verifyPlan(sqlQuery)
}
+  @Test
+  def testJoinOnQueryLeft(): Unit = {
+    // The left side of the temporal join is a registered query (a filter over Orders),
+    // not a plain source table.
+    util.tableEnv.registerTable(
+      "Orders2",
+      util.tableEnv.sqlQuery("SELECT * FROM Orders WHERE o_amount > 1000"))
+
+    util.verifyPlan(
+      Seq(
+        "SELECT",
+        "o_amount * rate as rate",
+        "FROM Orders2 AS o,",
+        "LATERAL TABLE (Rates(o.o_rowtime)) AS r",
+        "WHERE currency = o_currency"
+      ).mkString(" "))
+  }
+
/**
* Test versioned joins with more complicated query.
* Important thing here is that we have complex OR join condition
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
index 3e7fcec..3b2e204 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TemporalJoinITCase.scala
@@ -131,11 +131,11 @@ class TemporalJoinITCase(state: StateBackendMode)
val orders = env
.fromCollection(ordersData)
- .assignTimestampsAndWatermarks(new TimestampExtractor[Long, String]())
+ .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String,
Timestamp)]())
.toTable(tEnv, 'amount, 'currency, 'rowtime.rowtime)
val ratesHistory = env
.fromCollection(ratesHistoryData)
- .assignTimestampsAndWatermarks(new TimestampExtractor[String, Long]())
+ .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long,
Timestamp)]())
.toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
tEnv.registerTable("Orders", orders)
@@ -158,11 +158,92 @@ class TemporalJoinITCase(state: StateBackendMode)
assertEquals(expectedOutput, sink.getAppendResults.toSet)
}
+
+  @Test
+  def testNestedTemporalJoin(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = StreamTableEnvironment.create(env, TableTestUtil.STREAM_SETTING)
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    // Each order is joined with the product price and then with the currency rate that were
+    // valid at the order's rowtime.
+    val sqlQuery =
+      """
+        |SELECT
+        | o.orderId,
+        | (o.amount * p.price * r.rate) as total_price
+        |FROM
+        | Orders AS o,
+        | LATERAL TABLE (Prices(o.rowtime)) AS p,
+        | LATERAL TABLE (Rates(o.rowtime)) AS r
+        |WHERE
+        | o.productId = p.productId AND
+        | r.currency = p.currency
+        |""".stripMargin
+
+    val ordersData = Seq[(Long, String, Long, Timestamp)](
+      (1L, "A1", 2L, new Timestamp(2L)),
+      (2L, "A2", 1L, new Timestamp(3L)),
+      (3L, "A4", 50L, new Timestamp(4L)),
+      (4L, "A1", 3L, new Timestamp(5L)))
+    val orders = env
+      .fromCollection(ordersData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, Long, Timestamp)]())
+      .toTable(tEnv, 'orderId, 'productId, 'amount, 'rowtime.rowtime)
+
+    val ratesHistoryData = Seq[(String, Long, Timestamp)](
+      ("US Dollar", 102L, new Timestamp(1L)),
+      ("Euro", 114L, new Timestamp(1L)),
+      ("Yen", 1L, new Timestamp(1L)),
+      ("Euro", 116L, new Timestamp(5L)),
+      ("Euro", 119L, new Timestamp(7L)))
+    val ratesHistory = env
+      .fromCollection(ratesHistoryData)
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, Timestamp)]())
+      .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
+
+    val pricesHistoryData = Seq[(String, String, Double, Timestamp)](
+      ("A2", "US Dollar", 10.2D, new Timestamp(1L)),
+      ("A1", "Euro", 11.4D, new Timestamp(1L)),
+      ("A4", "Yen", 1D, new Timestamp(1L)),
+      ("A1", "Euro", 11.6D, new Timestamp(5L)),
+      ("A1", "Euro", 11.9D, new Timestamp(7L)))
+    val pricesHistory = env
+      .fromCollection(pricesHistoryData)
+      .assignTimestampsAndWatermarks(
+        new TimestampExtractor[(String, String, Double, Timestamp)]())
+      .toTable(tEnv, 'productId, 'currency, 'price, 'rowtime.rowtime)
+
+    tEnv.registerTable("Orders", orders)
+    tEnv.registerTable("RatesHistory", ratesHistory)
+    tEnv.registerFunction(
+      "Rates",
+      ratesHistory.createTemporalTableFunction("rowtime", "currency"))
+    tEnv.registerFunction(
+      "Prices",
+      pricesHistory.createTemporalTableFunction("rowtime", "productId"))
+
+    tEnv.registerTable("TemporalJoinResult", tEnv.sqlQuery(sqlQuery))
+
+    // Scan from the registered table to test the interplay between
+    // LogicalCorrelateToJoinFromTemporalTableFunctionRule and the expansion of the
+    // registered view.
+    val result = tEnv.scan("TemporalJoinResult").toAppendStream[Row]
+    val sink = new TestingAppendSink
+    result.addSink(sink)
+    env.execute()
+
+    val expected = List(
+      s"1,${2 * 114 * 11.4}",
+      s"2,${1 * 102 * 10.2}",
+      s"3,${50 * 1 * 1.0}",
+      s"4,${3 * 116 * 11.6}")
+    assertEquals(expected.sorted, sink.getAppendResults.sorted)
+  }
}
-class TimestampExtractor[T1, T2]
- extends BoundedOutOfOrdernessTimestampExtractor[(T1, T2,
Timestamp)](Time.seconds(10)) {
- override def extractTimestamp(element: (T1, T2, Timestamp)): Long = {
- element._3.getTime
+/**
+ * Assigns event-time timestamps taken from the last element of a tuple (or any other
+ * [[Product]]), created with a bound of `Time.seconds(10)` for out-of-orderness.
+ *
+ * @tparam T the record type; its last element must be a [[java.sql.Timestamp]]
+ */
+class TimestampExtractor[T <: Product]
+  extends BoundedOutOfOrdernessTimestampExtractor[T](Time.seconds(10)) {
+  override def extractTimestamp(element: T): Long = {
+    val arity = element.productArity
+    // Look only at the last element so that tuples of any arity are supported;
+    // an empty product has no timestamp and falls through to the error below.
+    val last = if (arity > 0) element.productElement(arity - 1) else null
+    last match {
+      case ts: Timestamp => ts.getTime
+      case _ => throw new IllegalArgumentException(
+        "Expected the last element in a tuple to be of a Timestamp type.")
+    }
}
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
index 9dd2690..cbb3c5d 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/catalog/QueryOperationCatalogViewTable.java
@@ -60,7 +60,7 @@ public class QueryOperationCatalogViewTable extends
AbstractTable implements Tra
@Override
public RelNode toRel(RelOptTable.ToRelContext context, RelOptTable
relOptTable) {
- FlinkRelBuilder relBuilder =
FlinkRelBuilder.of(context.getCluster(), relOptTable);
+ FlinkRelBuilder relBuilder =
FlinkRelBuilder.of(context.getCluster(), relOptTable.getRelOptSchema());
RelNode relNode =
relBuilder.tableOperation(catalogView.getQueryOperation()).build();
return RelOptUtil.createCastRel(relNode,
rowType.apply(relBuilder.getTypeFactory()), false);
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
index 8368667..0a6cef7 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/calcite/FlinkRelBuilder.scala
@@ -120,13 +120,13 @@ object FlinkRelBuilder {
*/
case class NamedWindowProperty(name: String, property: WindowProperty)
- def of(cluster: RelOptCluster, relTable: RelOptTable): FlinkRelBuilder = {
+  /**
+   * Creates a [[FlinkRelBuilder]] for the given cluster.
+   *
+   * The planner context of `cluster` becomes the builder's context, `relOptSchema` is handed
+   * to the builder as its schema, and the `ExpressionBridge[PlannerExpression]` is unwrapped
+   * from that same context. Accepting the schema directly (instead of a `RelOptTable`) lets
+   * callers create a builder even when their input is not a table scan.
+   *
+   * @param cluster      cluster the created relational expressions belong to
+   * @param relOptSchema schema passed to the builder
+   */
+ def of(cluster: RelOptCluster, relOptSchema: RelOptSchema): FlinkRelBuilder
= {
val clusterContext = cluster.getPlanner.getContext
new FlinkRelBuilder(
clusterContext,
cluster,
- relTable.getRelOptSchema,
+ relOptSchema,
clusterContext.unwrap(classOf[ExpressionBridge[PlannerExpression]]))
}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
index 72e6290..e87e356 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
@@ -19,8 +19,9 @@
package org.apache.flink.table.plan.rules.logical
import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.rel.RelNode
+import org.apache.calcite.plan.hep.HepRelVertex
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptSchema}
+import org.apache.calcite.rel.{BiRel, RelNode, SingleRel}
import org.apache.calcite.rel.core.TableFunctionScan
import org.apache.calcite.rel.logical.LogicalCorrelate
import org.apache.calcite.rex._
@@ -88,7 +89,7 @@ class LogicalCorrelateToTemporalTableJoinRule
.getUnderlyingHistoryTable
val rexBuilder = cluster.getRexBuilder
- val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable)
+ val relBuilder = FlinkRelBuilder.of(cluster, getRelOptSchema(leftNode))
val rightNode: RelNode =
relBuilder.tableOperation(underlyingHistoryTable).build()
val rightTimeIndicatorExpression = createRightExpression(
@@ -139,6 +140,16 @@ class LogicalCorrelateToTemporalTableJoinRule
rexBuilder.makeInputRef(
rightDataTypeField.getType, rightReferencesOffset +
rightDataTypeField.getIndex)
}
+
+  /**
+   * Gets the [[RelOptSchema]] from the input tree of `relNode`.
+   *
+   * The left side of a temporal table function join is not necessarily a table scan: it can be
+   * an arbitrary query (e.g. a filter or projection over a registered table) whose own table is
+   * null. The tree is therefore searched depth-first, left input before right input, for the
+   * first node that holds a non-null [[RelOptSchema]].
+   *
+   * @throws IllegalStateException if no node in the tree holds a [[RelOptSchema]]
+   */
+  private def getRelOptSchema(relNode: RelNode): RelOptSchema =
+    findRelOptSchema(relNode).getOrElse(
+      throw new IllegalStateException(
+        s"Could not find a RelOptSchema in the input tree of ${relNode.getRelTypeName}."))
+
+  /** Depth-first search for a non-null [[RelOptSchema]]; `None` if the tree holds none. */
+  private def findRelOptSchema(relNode: RelNode): Option[RelOptSchema] = relNode match {
+    // Hep planner vertices wrap the actual node.
+    case hep: HepRelVertex => findRelOptSchema(hep.getCurrentRel)
+    case single: SingleRel => findRelOptSchema(single.getInput)
+    case bi: BiRel => findRelOptSchema(bi.getLeft).orElse(findRelOptSchema(bi.getRight))
+    case _ =>
+      // Leaves such as table scans carry the table themselves; other multi-input nodes
+      // (e.g. set operations, which are neither SingleRel nor BiRel) are searched input by input.
+      val inputs = relNode.getInputs
+      (0 until inputs.size()).iterator
+        .map(i => findRelOptSchema(inputs.get(i)))
+        .collectFirst { case Some(schema) => schema }
+        .orElse(Option(relNode.getTable).flatMap(table => Option(table.getRelOptSchema)))
+  }
}
object LogicalCorrelateToTemporalTableJoinRule {
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
index 6051249..05874bc 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/TemporalTableJoinTest.scala
@@ -74,6 +74,35 @@ class TemporalTableJoinTest extends TableTestBase {
util.verifySql(sqlQuery, getExpectedSimpleProctimeJoinPlan())
}
+  @Test
+  def testJoinOnQueryLeft(): Unit = {
+    // The left side of the temporal join is a filtered sub-query rather than a plain table.
+    val query = Seq(
+      "SELECT o_amount * rate as rate",
+      "FROM (SELECT * FROM Orders WHERE o_amount > 1000) AS o,",
+      "LATERAL TABLE (Rates(o.o_rowtime)) AS r",
+      "WHERE currency = o_currency"
+    ).mkString(" ")
+
+    val filteredOrders = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(orders),
+      term("select", "o_amount, o_currency, o_rowtime"),
+      term("where", ">(o_amount, 1000)"))
+
+    val joinCondition =
+      s"AND(${TEMPORAL_JOIN_CONDITION.getName}(o_rowtime, rowtime, currency), " +
+        "=(currency, o_currency))"
+
+    val temporalJoin = binaryNode(
+      "DataStreamTemporalTableJoin",
+      filteredOrders,
+      streamTableNode(ratesHistory),
+      term("where", joinCondition),
+      term("join", "o_amount", "o_currency", "o_rowtime", "currency", "rate", "rowtime"),
+      term("joinType", "InnerJoin"))
+
+    util.verifySql(
+      query,
+      unaryNode("DataStreamCalc", temporalJoin, term("select", "*(o_amount, rate) AS rate")))
+  }
+
/**
* Test versioned joins with more complicated query.
* Important thing here is that we have complex OR join condition
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
index 463f3f06..6466e1e 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
@@ -19,7 +19,6 @@
package org.apache.flink.table.runtime.stream.sql
import java.sql.Timestamp
-
import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import
org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor
@@ -130,11 +129,11 @@ class TemporalJoinITCase extends
StreamingWithStateTestBase {
val orders = env
.fromCollection(ordersData)
- .assignTimestampsAndWatermarks(new TimestampExtractor[Long, String]())
+ .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String,
Timestamp)]())
.toTable(tEnv, 'amount, 'currency, 'rowtime.rowtime)
val ratesHistory = env
.fromCollection(ratesHistoryData)
- .assignTimestampsAndWatermarks(new TimestampExtractor[String, Long]())
+ .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long,
Timestamp)]())
.toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
tEnv.registerTable("Orders", orders)
@@ -153,11 +152,93 @@ class TemporalJoinITCase extends
StreamingWithStateTestBase {
assertEquals(expectedOutput, StreamITCase.testResults.toSet)
}
+
+  @Test
+  def testNestedTemporalJoin(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = StreamTableEnvironment.create(env)
+    env.setStateBackend(getStateBackend)
+    StreamITCase.clear
+    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
+
+    // Orders are joined with the price of their product and then with the rate of the
+    // product's currency, both as of the order's rowtime.
+    val sqlQuery =
+      """
+        |SELECT
+        | o.orderId,
+        | (o.amount * p.price * r.rate) as total_price
+        |FROM
+        | Orders AS o,
+        | LATERAL TABLE (Prices(o.rowtime)) AS p,
+        | LATERAL TABLE (Rates(o.rowtime)) AS r
+        |WHERE
+        | o.productId = p.productId AND
+        | r.currency = p.currency
+        |""".stripMargin
+
+    val orders = env
+      .fromCollection(Seq[(Long, String, Long, Timestamp)](
+        (1L, "A1", 2L, new Timestamp(2L)),
+        (2L, "A2", 1L, new Timestamp(3L)),
+        (3L, "A4", 50L, new Timestamp(4L)),
+        (4L, "A1", 3L, new Timestamp(5L))))
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(Long, String, Long, Timestamp)]())
+      .toTable(tEnv, 'orderId, 'productId, 'amount, 'rowtime.rowtime)
+
+    val ratesHistory = env
+      .fromCollection(Seq[(String, Long, Timestamp)](
+        ("US Dollar", 102L, new Timestamp(1L)),
+        ("Euro", 114L, new Timestamp(1L)),
+        ("Yen", 1L, new Timestamp(1L)),
+        ("Euro", 116L, new Timestamp(5L)),
+        ("Euro", 119L, new Timestamp(7L))))
+      .assignTimestampsAndWatermarks(new TimestampExtractor[(String, Long, Timestamp)]())
+      .toTable(tEnv, 'currency, 'rate, 'rowtime.rowtime)
+
+    val pricesHistory = env
+      .fromCollection(Seq[(String, String, Double, Timestamp)](
+        ("A2", "US Dollar", 10.2D, new Timestamp(1L)),
+        ("A1", "Euro", 11.4D, new Timestamp(1L)),
+        ("A4", "Yen", 1D, new Timestamp(1L)),
+        ("A1", "Euro", 11.6D, new Timestamp(5L)),
+        ("A1", "Euro", 11.9D, new Timestamp(7L))))
+      .assignTimestampsAndWatermarks(
+        new TimestampExtractor[(String, String, Double, Timestamp)]())
+      .toTable(tEnv, 'productId, 'currency, 'price, 'rowtime.rowtime)
+
+    tEnv.registerTable("Orders", orders)
+    tEnv.registerTable("RatesHistory", ratesHistory)
+    tEnv.registerFunction("Rates", ratesHistory.createTemporalTableFunction("rowtime", "currency"))
+    tEnv.registerFunction(
+      "Prices", pricesHistory.createTemporalTableFunction("rowtime", "productId"))
+    tEnv.registerTable("TemporalJoinResult", tEnv.sqlQuery(sqlQuery))
+
+    // Scan from registered table to test for interplay between
+    // LogicalCorrelateToTemporalTableJoinRule and TableScanRule
+    tEnv.scan("TemporalJoinResult")
+      .toAppendStream[Row]
+      .addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expectedSorted = List(
+      s"1,${2 * 114 * 11.4}",
+      s"2,${1 * 102 * 10.2}",
+      s"3,${50 * 1 * 1.0}",
+      s"4,${3 * 116 * 11.6}").sorted
+    assertEquals(expectedSorted, StreamITCase.testResults.sorted)
+  }
}
-class TimestampExtractor[T1, T2]
- extends BoundedOutOfOrdernessTimestampExtractor[(T1, T2,
Timestamp)](Time.seconds(10)) {
- override def extractTimestamp(element: (T1, T2, Timestamp)): Long = {
- element._3.getTime
+/**
+ * Assigns event-time timestamps taken from the last element of a tuple (or any other
+ * [[Product]]), created with a bound of `Time.seconds(10)` for out-of-orderness.
+ *
+ * @tparam T the record type; its last element must be a [[java.sql.Timestamp]]
+ */
+class TimestampExtractor[T <: Product]
+  extends BoundedOutOfOrdernessTimestampExtractor[T](Time.seconds(10)) {
+  override def extractTimestamp(element: T): Long = {
+    val arity = element.productArity
+    // Look only at the last element so that tuples of any arity are supported;
+    // an empty product has no timestamp and falls through to the error below.
+    val last = if (arity > 0) element.productElement(arity - 1) else null
+    last match {
+      case ts: Timestamp => ts.getTime
+      case _ => throw new IllegalArgumentException(
+        "Expected the last element in a tuple to be of a Timestamp type.")
+    }
}
}