This is an automated email from the ASF dual-hosted git repository.
jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 4b7ca24 [FLINK-13003][table-planner-blink] Support Temporal
TableFunction Join in processing time and event time
4b7ca24 is described below
commit 4b7ca247431e727ac0533cf767d43615bbaf07c0
Author: Jark Wu <[email protected]>
AuthorDate: Sun Jun 23 20:19:17 2019 +0800
[FLINK-13003][table-planner-blink] Support Temporal TableFunction Join in
processing time and event time
This closes #8901
---
.../flink/table/api/StreamTableEnvironment.scala | 16 +-
.../apache/flink/table/api/TableEnvironment.scala | 15 +-
.../org/apache/flink/table/api/TableImpl.scala | 18 +-
.../table/calcite/RelTimeIndicatorConverter.scala | 48 ++-
.../logical/FlinkLogicalTableFunctionScan.scala | 24 +-
.../physical/stream/StreamExecTemporalJoin.scala | 425 +++++++++++++++++++++
.../table/plan/rules/FlinkBatchRuleSets.scala | 2 +-
.../table/plan/rules/FlinkStreamRuleSets.scala | 4 +-
...relateToJoinFromTemporalTableFunctionRule.scala | 230 +++++++++++
...icalCorrelateToJoinFromTemporalTableRule.scala} | 15 +-
.../rules/physical/stream/StreamExecJoinRule.scala | 6 +-
.../stream/StreamExecTemporalJoinRule.scala | 102 +++++
.../flink/table/plan/util/RexDefaultVisitor.scala | 66 ++++
.../flink/table/plan/util/TemporalJoinUtil.scala | 105 +++++
.../flink/table/plan/util/WindowJoinUtil.scala | 8 +-
.../plan/stream/sql/join/TemporalJoinTest.xml | 101 +++++
.../plan/batch/sql/join/TemporalJoinTest.scala | 110 ++++++
.../plan/stream/sql/join/TemporalJoinTest.scala | 130 +++++++
...AbstractTwoInputStreamOperatorWithTTLTest.scala | 185 +++++++++
.../table/runtime/harness/HarnessTestBase.scala | 10 +-
.../runtime/stream/sql/TemporalJoinITCase.scala | 168 ++++++++
.../flink/table/dataformat/util/BaseRowUtil.java | 4 +
.../join/stream/AbstractStreamingJoinOperator.java | 4 +-
...seTwoInputStreamOperatorWithStateRetention.java | 162 ++++++++
.../temporal/TemporalProcessTimeJoinOperator.java | 126 ++++++
.../join/temporal/TemporalRowTimeJoinOperator.java | 400 +++++++++++++++++++
26 files changed, 2428 insertions(+), 56 deletions(-)
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
index 24b7a20..3a5edd3 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
@@ -487,14 +487,14 @@ abstract class StreamTableEnvironment(
}
}
- fields.zipWithIndex.foreach {
- case ("rowtime", idx) =>
- extractRowtime(idx, "rowtime", None)
-
- case ("proctime", idx) =>
- extractProctime(idx, "proctime")
-
- case (name, _) => fieldNames = name :: fieldNames
+ fields.zipWithIndex.foreach { case (name, idx) =>
+ if (name.endsWith("rowtime")) {
+ extractRowtime(idx, name, None)
+ } else if (name.endsWith("proctime")) {
+ extractProctime(idx, name)
+ } else {
+ fieldNames = name :: fieldNames
+ }
}
if (rowtime.isDefined && fieldNames.contains(rowtime.get._2)) {
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
index 9b43156..01fed17 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
@@ -696,14 +696,15 @@ abstract class TableEnvironment(
// determine schema definition mode (by position or by name)
val isRefByPos = isReferenceByPosition(t, fields)
- fields.zipWithIndex flatMap {
- case ("proctime" | "rowtime", _) =>
- None
- case (name, idx) =>
- if (isRefByPos) {
- Some((idx, name))
+ fields.zipWithIndex flatMap { case (name, idx) =>
+ if (name.endsWith("rowtime") || name.endsWith("proctime")) {
+ None
} else {
- referenceByName(name, t).map((_, name))
+ if (isRefByPos) {
+ Some((idx, name))
+ } else {
+ referenceByName(name, t).map((_, name))
+ }
}
}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
index dda0a7f..4dadc1a 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/api/TableImpl.scala
@@ -18,8 +18,8 @@
package org.apache.flink.table.api
-import org.apache.flink.table.expressions.Expression
-import org.apache.flink.table.functions.TemporalTableFunction
+import org.apache.flink.table.expressions.{Expression,
FieldReferenceExpression}
+import org.apache.flink.table.functions.{TemporalTableFunction,
TemporalTableFunctionImpl}
import org.apache.flink.table.operations.QueryOperation
/**
@@ -48,8 +48,18 @@ class TableImpl(val tableEnv: TableEnvironment,
operationTree: QueryOperation) e
override def select(fields: Expression*): Table = ???
override def createTemporalTableFunction(
- timeAttribute: String,
- primaryKey: String): TemporalTableFunction = ???
+ timeAttribute: String,
+ primaryKey: String): TemporalTableFunction = {
+ val resolvedTimeAttribute = resolveExpression(timeAttribute)
+ val resolvedPrimaryKey = resolveExpression(primaryKey)
+ TemporalTableFunctionImpl.create(operationTree, resolvedTimeAttribute,
resolvedPrimaryKey)
+ }
+
+ private def resolveExpression(name: String): FieldReferenceExpression = {
+ val idx = tableSchema.getFieldNames.indexOf(name)
+ val fieldType = tableSchema.getFieldDataTypes()(idx)
+ new FieldReferenceExpression(name, fieldType, 0, idx)
+ }
override def createTemporalTableFunction(
timeAttribute: Expression,
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
index 0447e79..289b149 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/calcite/RelTimeIndicatorConverter.scala
@@ -24,7 +24,6 @@ import
org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.plan.nodes.calcite._
import org.apache.flink.table.plan.schema.TimeIndicatorRelDataType
import org.apache.flink.table.types.logical.TimestampType
-
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core._
import org.apache.calcite.rel.logical._
@@ -32,6 +31,7 @@ import org.apache.calcite.rel.{RelNode, RelShuttle}
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.plan.util.TemporalJoinUtil
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -168,28 +168,38 @@ class RelTimeIndicatorConverter(rexBuilder: RexBuilder)
extends RelShuttle {
val left = join.getLeft.accept(this)
val right = join.getRight.accept(this)
- // TODO supports temporal join
- val newCondition = join.getCondition.accept(new RexShuttle {
- private val leftFieldCount = left.getRowType.getFieldCount
- private val leftFields = left.getRowType.getFieldList.toList
- private val leftRightFields =
- (left.getRowType.getFieldList ++ right.getRowType.getFieldList).toList
-
- override def visitInputRef(inputRef: RexInputRef): RexNode = {
- if (isTimeIndicatorType(inputRef.getType)) {
- val fields = if (inputRef.getIndex < leftFieldCount) {
- leftFields
+ if (TemporalJoinUtil.containsTemporalJoinCondition(join.getCondition)) {
+ // temporal table function join
+ val rewrittenTemporalJoin = join.copy(join.getTraitSet, List(left,
right))
+
+ // Materialize all of the time attributes from the right side of
temporal join
+ val indicesToMaterialize = (left.getRowType.getFieldCount until
+ rewrittenTemporalJoin.getRowType.getFieldCount).toSet
+
+ materializerUtils.projectAndMaterializeFields(rewrittenTemporalJoin,
indicesToMaterialize)
+ } else {
+ val newCondition = join.getCondition.accept(new RexShuttle {
+ private val leftFieldCount = left.getRowType.getFieldCount
+ private val leftFields = left.getRowType.getFieldList.toList
+ private val leftRightFields =
+ (left.getRowType.getFieldList ++
right.getRowType.getFieldList).toList
+
+ override def visitInputRef(inputRef: RexInputRef): RexNode = {
+ if (isTimeIndicatorType(inputRef.getType)) {
+ val fields = if (inputRef.getIndex < leftFieldCount) {
+ leftFields
+ } else {
+ leftRightFields
+ }
+ RexInputRef.of(inputRef.getIndex, fields)
} else {
- leftRightFields
+ super.visitInputRef(inputRef)
}
- RexInputRef.of(inputRef.getIndex, fields)
- } else {
- super.visitInputRef(inputRef)
}
- }
- })
+ })
- LogicalJoin.create(left, right, newCondition, join.getVariablesSet,
join.getJoinType)
+ LogicalJoin.create(left, right, newCondition, join.getVariablesSet,
join.getJoinType)
+ }
}
override def visit(correlate: LogicalCorrelate): RelNode = {
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
index 22cb665..2b96e44 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalTableFunctionScan.scala
@@ -19,7 +19,6 @@
package org.apache.flink.table.plan.nodes.logical
import org.apache.flink.table.plan.nodes.FlinkConventions
-
import com.google.common.collect.ImmutableList
import org.apache.calcite.plan.{Convention, RelOptCluster, RelOptRuleCall,
RelTraitSet}
import org.apache.calcite.rel.RelNode
@@ -28,9 +27,11 @@ import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.core.TableFunctionScan
import org.apache.calcite.rel.logical.LogicalTableFunctionScan
import org.apache.calcite.rel.metadata.RelColumnMapping
-import org.apache.calcite.rex.{RexLiteral, RexNode, RexUtil}
+import org.apache.calcite.rex.{RexCall, RexLiteral, RexNode, RexUtil}
import org.apache.calcite.sql.SemiJoinType
import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.table.functions.TemporalTableFunction
+import org.apache.flink.table.functions.utils.TableSqlFunction
import java.lang.reflect.Type
import java.util
@@ -85,8 +86,23 @@ class FlinkLogicalTableFunctionScanConverter
"FlinkLogicalTableFunctionScanConverter") {
override def matches(call: RelOptRuleCall): Boolean = {
- // TODO This rule do not match to TemporalTableFunction
- super.matches(call)
+ val logicalTableFunction: LogicalTableFunctionScan = call.rel(0)
+
+ !isTemporalTableFunctionCall(logicalTableFunction)
+ }
+
+ private def isTemporalTableFunctionCall(
+ logicalTableFunction: LogicalTableFunctionScan): Boolean = {
+
+ if (!logicalTableFunction.getCall.isInstanceOf[RexCall]) {
+ return false
+ }
+ val rexCall = logicalTableFunction.getCall.asInstanceOf[RexCall]
+ if (!rexCall.getOperator.isInstanceOf[TableSqlFunction]) {
+ return false
+ }
+ val tableFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+ tableFunction.getTableFunction.isInstanceOf[TemporalTableFunction]
}
def convert(rel: RelNode): RelNode = {
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTemporalJoin.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTemporalJoin.scala
new file mode 100644
index 0000000..e165395
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/nodes/physical/stream/StreamExecTemporalJoin.scala
@@ -0,0 +1,425 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.physical.stream
+
+import org.apache.flink.api.dag.Transformation
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator
+import org.apache.flink.streaming.api.transformations.TwoInputTransformation
+import org.apache.flink.table.api.{StreamTableEnvironment, TableConfig,
TableException, ValidationException}
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import
org.apache.flink.table.calcite.FlinkTypeFactory.{isProctimeIndicatorType,
isRowtimeIndicatorType}
+import org.apache.flink.table.codegen.{CodeGeneratorContext,
ExprCodeGenerator, FunctionCodeGenerator}
+import org.apache.flink.table.dataformat.BaseRow
+import org.apache.flink.table.generated.GeneratedJoinCondition
+import org.apache.flink.table.plan.nodes.common.CommonPhysicalJoin
+import org.apache.flink.table.plan.nodes.exec.{ExecNode, StreamExecNode}
+import
org.apache.flink.table.plan.util.TemporalJoinUtil.TEMPORAL_JOIN_CONDITION
+import org.apache.flink.table.plan.util.{InputRefVisitor, KeySelectorUtil,
RelExplainUtil, TemporalJoinUtil}
+import
org.apache.flink.table.runtime.join.temporal.{TemporalProcessTimeJoinOperator,
TemporalRowTimeJoinOperator}
+import org.apache.flink.table.runtime.keyselector.BaseRowKeySelector
+import org.apache.flink.table.types.logical.RowType
+import org.apache.flink.table.typeutils.BaseRowTypeInfo
+import org.apache.flink.util.Preconditions.checkState
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.{Join, JoinInfo, JoinRelType}
+import org.apache.calcite.rex._
+
+import java.util
+
+import scala.collection.JavaConversions._
+
+class StreamExecTemporalJoin(
+ cluster: RelOptCluster,
+ traitSet: RelTraitSet,
+ leftRel: RelNode,
+ rightRel: RelNode,
+ condition: RexNode,
+ joinType: JoinRelType)
+ extends CommonPhysicalJoin(cluster, traitSet, leftRel, rightRel, condition,
joinType)
+ with StreamPhysicalRel
+ with StreamExecNode[BaseRow] {
+
+ override def producesUpdates: Boolean = false
+
+ override def needsUpdatesAsRetraction(input: RelNode): Boolean = false
+
+ override def consumesRetractions: Boolean = false
+
+ override def producesRetractions: Boolean = false
+
+ override def requireWatermark: Boolean = {
+ val nonEquiJoinRex = getJoinInfo.getRemaining(cluster.getRexBuilder)
+
+ var rowtimeJoin: Boolean = false
+ val visitor = new RexVisitorImpl[Unit](true) {
+ override def visitCall(call: RexCall): Unit = {
+ if (call.getOperator == TEMPORAL_JOIN_CONDITION) {
+ rowtimeJoin = TemporalJoinUtil.isRowtimeCall(call)
+ } else {
+ call.getOperands.foreach(node => node.accept(this))
+ }
+ }
+ }
+ nonEquiJoinRex.accept(visitor)
+ rowtimeJoin
+ }
+
+ override def copy(
+ traitSet: RelTraitSet,
+ conditionExpr: RexNode,
+ left: RelNode,
+ right: RelNode,
+ joinType: JoinRelType,
+ semiJoinDone: Boolean): Join = {
+ new StreamExecTemporalJoin(
+ cluster,
+ traitSet,
+ left,
+ right,
+ conditionExpr,
+ joinType)
+ }
+
+ //~ ExecNode methods
-----------------------------------------------------------
+
+ override def getInputNodes: util.List[ExecNode[StreamTableEnvironment, _]] =
{
+ getInputs.map(_.asInstanceOf[ExecNode[StreamTableEnvironment, _]])
+ }
+
+ override def replaceInputNode(
+ ordinalInParent: Int,
+ newInputNode: ExecNode[StreamTableEnvironment, _]): Unit = {
+ replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
+ }
+
+ override protected def translateToPlanInternal(
+ tableEnv: StreamTableEnvironment): Transformation[BaseRow] = {
+
+ validateKeyTypes()
+
+ val returnType = FlinkTypeFactory.toLogicalRowType(getRowType)
+
+ val joinTranslator = StreamExecTemporalJoinToCoProcessTranslator.create(
+ this.toString,
+ tableEnv.getConfig,
+ returnType,
+ leftRel,
+ rightRel,
+ getJoinInfo,
+ cluster.getRexBuilder)
+
+ val joinOperator = joinTranslator.getJoinOperator(joinType,
returnType.getFieldNames)
+ val leftKeySelector = joinTranslator.getLeftKeySelector
+ val rightKeySelector = joinTranslator.getRightKeySelector
+
+ val leftTransform = getInputNodes.get(0).translateToPlan(tableEnv)
+ .asInstanceOf[Transformation[BaseRow]]
+ val rightTransform = getInputNodes.get(1).translateToPlan(tableEnv)
+ .asInstanceOf[Transformation[BaseRow]]
+
+ val ret = new TwoInputTransformation[BaseRow, BaseRow, BaseRow](
+ leftTransform,
+ rightTransform,
+ getJoinOperatorName,
+ joinOperator,
+ BaseRowTypeInfo.of(returnType),
+ getResource.getParallelism)
+
+ if (getResource.getMaxParallelism > 0) {
+ ret.setMaxParallelism(getResource.getMaxParallelism)
+ }
+
+ // set KeyType and Selector for state
+ ret.setStateKeySelectors(leftKeySelector, rightKeySelector)
+
ret.setStateKeyType(leftKeySelector.asInstanceOf[ResultTypeQueryable[_]].getProducedType)
+ ret
+ }
+
+ private def validateKeyTypes(): Unit = {
+ // at least one equality expression
+ val leftFields = left.getRowType.getFieldList
+ val rightFields = right.getRowType.getFieldList
+
+ getJoinInfo.pairs().toList.foreach(pair => {
+ val leftKeyType = leftFields.get(pair.source).getType.getSqlTypeName
+ val rightKeyType = rightFields.get(pair.target).getType.getSqlTypeName
+ // check if keys are compatible
+ if (leftKeyType != rightKeyType) {
+ throw new TableException(
+ "Equality join predicate on incompatible types.\n" +
+ s"\tLeft: $left,\n" +
+ s"\tRight: $right,\n" +
+ s"\tCondition: (${RelExplainUtil.expressionToString(
+ getCondition, inputRowType, getExpressionString)})"
+ )
+ }
+ })
+ }
+
+ private def getJoinOperatorName: String = {
+ val where = RelExplainUtil.expressionToString(getCondition, inputRowType,
getExpressionString)
+ val select = getRowType.getFieldNames.mkString(", ")
+ s"TemporalTableJoin(where: ($where), select: ($select)"
+ }
+}
+
+
+/**
+ * @param rightTimeAttributeInputReference is defined only for event time
joins.
+ */
+class StreamExecTemporalJoinToCoProcessTranslator private (
+ textualRepresentation: String,
+ config: TableConfig,
+ returnType: RowType,
+ leftInputType: RowType,
+ rightInputType: RowType,
+ joinInfo: JoinInfo,
+ rexBuilder: RexBuilder,
+ leftTimeAttributeInputReference: Int,
+ rightTimeAttributeInputReference: Option[Int],
+ remainingNonEquiJoinPredicates: RexNode) {
+
+ val nonEquiJoinPredicates: Option[RexNode] =
Some(remainingNonEquiJoinPredicates)
+
+ def getLeftKeySelector: BaseRowKeySelector = {
+ KeySelectorUtil.getBaseRowSelector(
+ joinInfo.leftKeys.toIntArray,
+ BaseRowTypeInfo.of(leftInputType)
+ )
+ }
+
+ def getRightKeySelector: BaseRowKeySelector = {
+ KeySelectorUtil.getBaseRowSelector(
+ joinInfo.rightKeys.toIntArray,
+ BaseRowTypeInfo.of(rightInputType)
+ )
+ }
+
+ def getJoinOperator(
+ joinType: JoinRelType,
+ returnFieldNames: Seq[String]): TwoInputStreamOperator[BaseRow, BaseRow,
BaseRow] = {
+
+ // input must not be nullable, because the runtime join function will make
sure
+ // the code-generated function won't process null inputs
+ val ctx = CodeGeneratorContext(config)
+ val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
+ .bindInput(leftInputType)
+ .bindSecondInput(rightInputType)
+
+ val body = if (nonEquiJoinPredicates.isEmpty) {
+ // only equality condition
+ "return true;"
+ } else {
+ val condition =
exprGenerator.generateExpression(nonEquiJoinPredicates.get)
+ s"""
+ |${condition.code}
+ |return ${condition.resultTerm};
+ |""".stripMargin
+ }
+
+ val generatedJoinCondition = FunctionCodeGenerator.generateJoinCondition(
+ ctx,
+ "ConditionFunction",
+ body)
+
+ createJoinOperator(config, joinType, generatedJoinCondition)
+ }
+
+ protected def createJoinOperator(
+ tableConfig: TableConfig,
+ joinType: JoinRelType,
+ generatedJoinCondition: GeneratedJoinCondition)
+ : TwoInputStreamOperator[BaseRow, BaseRow, BaseRow] = {
+
+ val minRetentionTime = tableConfig.getMinIdleStateRetentionTime
+ val maxRetentionTime = tableConfig.getMaxIdleStateRetentionTime
+ joinType match {
+ case JoinRelType.INNER =>
+ if (rightTimeAttributeInputReference.isDefined) {
+ new TemporalRowTimeJoinOperator(
+ BaseRowTypeInfo.of(leftInputType),
+ BaseRowTypeInfo.of(rightInputType),
+ generatedJoinCondition,
+ leftTimeAttributeInputReference,
+ rightTimeAttributeInputReference.get,
+ minRetentionTime,
+ maxRetentionTime)
+ } else {
+ new TemporalProcessTimeJoinOperator(
+ BaseRowTypeInfo.of(rightInputType),
+ generatedJoinCondition,
+ minRetentionTime,
+ maxRetentionTime)
+ }
+ case _ =>
+ throw new ValidationException(
+ s"Only ${JoinRelType.INNER} temporal join is supported in
[$textualRepresentation]")
+ }
+ }
+}
+
+object StreamExecTemporalJoinToCoProcessTranslator {
+ def create(
+ textualRepresentation: String,
+ config: TableConfig,
+ returnType: RowType,
+ leftInput: RelNode,
+ rightInput: RelNode,
+ joinInfo: JoinInfo,
+ rexBuilder: RexBuilder): StreamExecTemporalJoinToCoProcessTranslator = {
+
+ checkState(
+ !joinInfo.isEqui,
+ "Missing %s in join condition",
+ TEMPORAL_JOIN_CONDITION)
+
+ val leftType = FlinkTypeFactory.toLogicalRowType(leftInput.getRowType)
+ val rightType = FlinkTypeFactory.toLogicalRowType(rightInput.getRowType)
+ val nonEquiJoinRex: RexNode = joinInfo.getRemaining(rexBuilder)
+ val temporalJoinConditionExtractor = new TemporalJoinConditionExtractor(
+ textualRepresentation,
+ leftType.getFieldCount,
+ joinInfo,
+ rexBuilder)
+
+ val remainingNonEquiJoinPredicates =
temporalJoinConditionExtractor.apply(nonEquiJoinRex)
+
+ checkState(
+ temporalJoinConditionExtractor.leftTimeAttribute.isDefined &&
+ temporalJoinConditionExtractor.rightPrimaryKeyExpression.isDefined,
+ "Missing %s in join condition",
+ TEMPORAL_JOIN_CONDITION)
+
+ new StreamExecTemporalJoinToCoProcessTranslator(
+ textualRepresentation,
+ config,
+ returnType,
+ leftType,
+ rightType,
+ joinInfo,
+ rexBuilder,
+ extractInputReference(
+ temporalJoinConditionExtractor.leftTimeAttribute.get,
+ textualRepresentation),
+ temporalJoinConditionExtractor.rightTimeAttribute.map(
+ rightTimeAttribute =>
+ extractInputReference(
+ rightTimeAttribute,
+ textualRepresentation
+ ) - leftType.getFieldCount),
+ remainingNonEquiJoinPredicates)
+ }
+
+ private def extractInputReference(rexNode: RexNode, textualRepresentation:
String): Int = {
+ val inputReferenceVisitor = new InputRefVisitor
+ rexNode.accept(inputReferenceVisitor)
+ checkState(
+ inputReferenceVisitor.getFields.length == 1,
+ "Failed to find input reference in [%s]",
+ textualRepresentation)
+ inputReferenceVisitor.getFields.head
+ }
+
+ private class TemporalJoinConditionExtractor(
+ textualRepresentation: String,
+ rightKeysStartingOffset: Int,
+ joinInfo: JoinInfo,
+ rexBuilder: RexBuilder)
+
+ extends RexShuttle {
+
+ var leftTimeAttribute: Option[RexNode] = None
+
+ var rightTimeAttribute: Option[RexNode] = None
+
+ var rightPrimaryKeyExpression: Option[RexNode] = None
+
+ override def visitCall(call: RexCall): RexNode = {
+ if (call.getOperator != TEMPORAL_JOIN_CONDITION) {
+ return super.visitCall(call)
+ }
+
+ checkState(
+ leftTimeAttribute.isEmpty
+ && rightPrimaryKeyExpression.isEmpty
+ && rightTimeAttribute.isEmpty,
+ "Multiple %s functions in [%s]",
+ TEMPORAL_JOIN_CONDITION,
+ textualRepresentation)
+
+ if (TemporalJoinUtil.isRowtimeCall(call)) {
+ leftTimeAttribute = Some(call.getOperands.get(0))
+ rightTimeAttribute = Some(call.getOperands.get(1))
+
+ rightPrimaryKeyExpression =
Some(validateRightPrimaryKey(call.getOperands.get(2)))
+
+ if (!isRowtimeIndicatorType(rightTimeAttribute.get.getType)) {
+ throw new ValidationException(
+ s"Non rowtime timeAttribute [${rightTimeAttribute.get.getType}] " +
+ s"used to create TemporalTableFunction")
+ }
+ if (!isRowtimeIndicatorType(leftTimeAttribute.get.getType)) {
+ throw new ValidationException(
+ s"Non rowtime timeAttribute [${leftTimeAttribute.get.getType}] " +
+ s"passed as the argument to TemporalTableFunction")
+ }
+ }
+ else if (TemporalJoinUtil.isProctimeCall(call)) {
+ leftTimeAttribute = Some(call.getOperands.get(0))
+ rightPrimaryKeyExpression =
Some(validateRightPrimaryKey(call.getOperands.get(1)))
+
+ if (!isProctimeIndicatorType(leftTimeAttribute.get.getType)) {
+ throw new ValidationException(
+ s"Non processing timeAttribute [${leftTimeAttribute.get.getType}]
" +
+ s"passed as the argument to TemporalTableFunction")
+ }
+ }
+ else {
+ throw new IllegalStateException(
+ s"Unsupported invocation $call in [$textualRepresentation]")
+ }
+ rexBuilder.makeLiteral(true)
+ }
+
+ private def validateRightPrimaryKey(rightPrimaryKey: RexNode): RexNode = {
+ if (joinInfo.rightKeys.size() != 1) {
+ throw new ValidationException(
+ s"Only single column join key is supported. " +
+ s"Found ${joinInfo.rightKeys} in [$textualRepresentation]")
+ }
+ val rightJoinKeyInputReference = joinInfo.rightKeys.get(0) +
rightKeysStartingOffset
+
+ val rightPrimaryKeyInputReference = extractInputReference(
+ rightPrimaryKey,
+ textualRepresentation)
+
+ if (rightPrimaryKeyInputReference != rightJoinKeyInputReference) {
+ throw new ValidationException(
+ s"Join key [$rightJoinKeyInputReference] must be the same as " +
+ s"temporal table's primary key [$rightPrimaryKey] " +
+ s"in [$textualRepresentation]")
+ }
+
+ rightPrimaryKey
+ }
+ }
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
index 5c00719..75f1d5a 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkBatchRuleSets.scala
@@ -54,7 +54,7 @@ object FlinkBatchRuleSets {
* can create new plan nodes.
*/
val EXPAND_PLAN_RULES: RuleSet = RuleSets.ofList(
- LogicalCorrelateToTemporalTableJoinRule.INSTANCE,
+ LogicalCorrelateToJoinFromTemporalTableRule.INSTANCE,
TableScanRule.INSTANCE)
val POST_EXPAND_CLEAN_UP_RULES: RuleSet = RuleSets.ofList(
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
index 2f37753..0f8b219 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/FlinkStreamRuleSets.scala
@@ -54,7 +54,8 @@ object FlinkStreamRuleSets {
* can create new plan nodes.
*/
val EXPAND_PLAN_RULES: RuleSet = RuleSets.ofList(
- LogicalCorrelateToTemporalTableJoinRule.INSTANCE,
+ LogicalCorrelateToJoinFromTemporalTableRule.INSTANCE,
+ LogicalCorrelateToJoinFromTemporalTableFunctionRule.INSTANCE,
TableScanRule.INSTANCE)
val POST_EXPAND_CLEAN_UP_RULES: RuleSet = RuleSets.ofList(
@@ -348,6 +349,7 @@ object FlinkStreamRuleSets {
// join
StreamExecJoinRule.INSTANCE,
StreamExecWindowJoinRule.INSTANCE,
+ StreamExecTemporalJoinRule.INSTANCE,
StreamExecLookupJoinRule.SNAPSHOT_ON_TABLESCAN,
StreamExecLookupJoinRule.SNAPSHOT_ON_CALC_TABLESCAN,
// correlate
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
new file mode 100644
index 0000000..6b8f075
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.logical
+
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.calcite.FlinkRelBuilder
+import org.apache.flink.table.expressions.{FieldReferenceExpression, _}
+import org.apache.flink.table.functions.utils.TableSqlFunction
+import org.apache.flink.table.functions.{TemporalTableFunction,
TemporalTableFunctionImpl}
+import org.apache.flink.table.operations.QueryOperation
+import org.apache.flink.table.plan.util.{ExpandTableScanShuttle,
RexDefaultVisitor}
+import
org.apache.flink.table.plan.util.TemporalJoinUtil.{makeProcTimeTemporalJoinConditionCall,
makeRowTimeTemporalJoinConditionCall}
+import
org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.{hasRoot,
isProctimeAttribute}
+import org.apache.flink.util.Preconditions.checkState
+import org.apache.calcite.plan.RelOptRule.{any, none, operand, some}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.{JoinRelType, TableFunctionScan}
+import org.apache.calcite.rel.logical.LogicalCorrelate
+import org.apache.calcite.rex._
+
+/**
+ * The initial temporal TableFunction join (LATERAL
TemporalTableFunction(o.proctime)) is
+ * a correlate. Rewrite it into a Join with a special temporal join condition
that wraps the
+ * time attribute and primary key information. The join will be translated into
+ *
[[org.apache.flink.table.plan.nodes.physical.stream.StreamExecTemporalJoin]] in
the physical phase.
+ */
+class LogicalCorrelateToJoinFromTemporalTableFunctionRule
+ extends RelOptRule(
+ operand(classOf[LogicalCorrelate],
+ some(
+ operand(classOf[RelNode], any()),
+ operand(classOf[TableFunctionScan], none()))),
+ "LogicalCorrelateToJoinFromTemporalTableFunctionRule") {
+
+ private def extractNameFromTimeAttribute(timeAttribute: Expression): String
= {
+ timeAttribute match {
+ case f : FieldReferenceExpression
+ if hasRoot(f.getOutputDataType.getLogicalType,
TIMESTAMP_WITHOUT_TIME_ZONE) =>
+ f.getName
+ case _ => throw new ValidationException(
+ s"Invalid timeAttribute [$timeAttribute] in TemporalTableFunction")
+ }
+ }
+
+ private def isProctimeReference(temporalTableFunction:
TemporalTableFunctionImpl): Boolean = {
+ val fieldRef =
temporalTableFunction.getTimeAttribute.asInstanceOf[FieldReferenceExpression]
+ isProctimeAttribute(fieldRef.getOutputDataType.getLogicalType)
+ }
+
+ private def extractNameFromPrimaryKeyAttribute(expression: Expression):
String = {
+ expression match {
+ case f: FieldReferenceExpression =>
+ f.getName
+ case _ => throw new ValidationException(
+ s"Unsupported expression [$expression] as primary key. " +
+ s"Only top-level (not nested) field references are supported.")
+ }
+ }
+
+ override def onMatch(call: RelOptRuleCall): Unit = {
+ val logicalCorrelate: LogicalCorrelate = call.rel(0)
+ val leftNode: RelNode = call.rel(1)
+ val rightTableFunctionScan: TableFunctionScan = call.rel(2)
+
+ val cluster = logicalCorrelate.getCluster
+
+ new GetTemporalTableFunctionCall(cluster.getRexBuilder, leftNode)
+ .visit(rightTableFunctionScan.getCall) match {
+ case None =>
+ // Do nothing and handle standard TableFunction
+ case Some(TemporalTableFunctionCall(
+ rightTemporalTableFunction: TemporalTableFunctionImpl,
leftTimeAttribute)) =>
+
+ // If TemporalTableFunction was found, rewrite LogicalCorrelate to
TemporalJoin
+ val underlyingHistoryTable: QueryOperation = rightTemporalTableFunction
+ .getUnderlyingHistoryTable
+ val rexBuilder = cluster.getRexBuilder
+
+ val relBuilder = FlinkRelBuilder.of(cluster, leftNode.getTable)
+ val temporalTable: RelNode =
relBuilder.queryOperation(underlyingHistoryTable).build()
+ // expand QueryOperationCatalogViewTable in TableScan
+ val shuttle = new ExpandTableScanShuttle
+ val rightNode = temporalTable.accept(shuttle)
+
+ val rightTimeIndicatorExpression = createRightExpression(
+ rexBuilder,
+ leftNode,
+ rightNode,
+
extractNameFromTimeAttribute(rightTemporalTableFunction.getTimeAttribute))
+
+ val rightPrimaryKeyExpression = createRightExpression(
+ rexBuilder,
+ leftNode,
+ rightNode,
+
extractNameFromPrimaryKeyAttribute(rightTemporalTableFunction.getPrimaryKey))
+
+ relBuilder.push(leftNode)
+ relBuilder.push(rightNode)
+
+ val condition =
+ if (isProctimeReference(rightTemporalTableFunction)) {
+ makeProcTimeTemporalJoinConditionCall(
+ rexBuilder,
+ leftTimeAttribute,
+ rightPrimaryKeyExpression)
+ } else {
+ makeRowTimeTemporalJoinConditionCall(
+ rexBuilder,
+ leftTimeAttribute,
+ rightTimeIndicatorExpression,
+ rightPrimaryKeyExpression)
+ }
+ relBuilder.join(JoinRelType.INNER, condition)
+
+ call.transformTo(relBuilder.build())
+ }
+ }
+
+ private def createRightExpression(
+ rexBuilder: RexBuilder,
+ leftNode: RelNode,
+ rightNode: RelNode,
+ field: String): RexNode = {
+ val rightReferencesOffset = leftNode.getRowType.getFieldCount
+ val rightDataTypeField = rightNode.getRowType.getField(field, false, false)
+ rexBuilder.makeInputRef(
+ rightDataTypeField.getType, rightReferencesOffset +
rightDataTypeField.getIndex)
+ }
+}
+
+object LogicalCorrelateToJoinFromTemporalTableFunctionRule {
+ val INSTANCE: RelOptRule = new
LogicalCorrelateToJoinFromTemporalTableFunctionRule
+}
+
+/**
+ * Simple pojo class for extracted [[TemporalTableFunction]] with time
attribute
+ * extracted from RexNode with [[TemporalTableFunction]] call.
+ */
+case class TemporalTableFunctionCall(
+ var temporalTableFunction: TemporalTableFunction,
+ var timeAttribute: RexNode) {
+}
+
+/**
+ * Finds a [[TemporalTableFunction]] call and runs
[[CorrelatedFieldAccessRemoval]] on its operand.
+ */
+class GetTemporalTableFunctionCall(
+ var rexBuilder: RexBuilder,
+ var leftSide: RelNode)
+ extends RexVisitorImpl[TemporalTableFunctionCall](false) {
+
+ def visit(node: RexNode): Option[TemporalTableFunctionCall] = {
+ val result = node.accept(this)
+ if (result == null) {
+ return None
+ }
+ Some(result)
+ }
+
+ override def visitCall(rexCall: RexCall): TemporalTableFunctionCall = {
+ if (!rexCall.getOperator.isInstanceOf[TableSqlFunction]) {
+ return null
+ }
+ val tableFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
+
+ if (!tableFunction.getTableFunction.isInstanceOf[TemporalTableFunction]) {
+ return null
+ }
+ val temporalTableFunction =
+ tableFunction.getTableFunction.asInstanceOf[TemporalTableFunctionImpl]
+
+ checkState(
+ rexCall.getOperands.size().equals(1),
+ "TemporalTableFunction call [%s] must have exactly one argument",
+ rexCall)
+ val correlatedFieldAccessRemoval =
+ new CorrelatedFieldAccessRemoval(temporalTableFunction, rexBuilder,
leftSide)
+ TemporalTableFunctionCall(
+ temporalTableFunction,
+ rexCall.getOperands.get(0).accept(correlatedFieldAccessRemoval))
+ }
+}
+
+/**
+ * This converts field accesses like `$cor0.o_rowtime` to valid input
references
+ * for join condition context without `$cor` reference.
+ */
+class CorrelatedFieldAccessRemoval(
+ var temporalTableFunction: TemporalTableFunctionImpl,
+ var rexBuilder: RexBuilder,
+ var leftSide: RelNode) extends RexDefaultVisitor[RexNode] {
+
+ override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
+ val leftIndex =
leftSide.getRowType.getFieldList.indexOf(fieldAccess.getField)
+ if (leftIndex < 0) {
+ throw new IllegalStateException(
+ s"Failed to find reference to field [${fieldAccess.getField}] in node
[$leftSide]")
+ }
+ rexBuilder.makeInputRef(leftSide, leftIndex)
+ }
+
+ override def visitInputRef(inputRef: RexInputRef): RexNode = {
+ inputRef
+ }
+
+ override def visitNode(rexNode: RexNode): RexNode = {
+ throw new ValidationException(
+ s"Unsupported argument [$rexNode] " +
+ s"in ${classOf[TemporalTableFunction].getSimpleName} call of " +
+ s"[${temporalTableFunction.getUnderlyingHistoryTable}] table")
+ }
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableRule.scala
similarity index 74%
rename from
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
rename to
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableRule.scala
index 0f91a8a..27fc011 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToTemporalTableJoinRule.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableRule.scala
@@ -23,16 +23,19 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.logical.{LogicalCorrelate, LogicalFilter,
LogicalSnapshot}
/**
- * The initial temporal table join is a Correlate, rewrite it into a Join to
make the
- * join condition push-down into the Join
+ * The initial temporal table join (FOR SYSTEM_TIME AS OF) is a Correlate,
rewrite it into a Join
+ * so that the join condition can be pushed down. The join will be translated into
+ * [[org.apache.flink.table.plan.nodes.physical.stream.StreamExecLookupJoin]]
in physical and
+ * might be translated into
+ *
[[org.apache.flink.table.plan.nodes.physical.stream.StreamExecTemporalJoin]] in
the future.
*/
-class LogicalCorrelateToTemporalTableJoinRule
+class LogicalCorrelateToJoinFromTemporalTableRule
extends RelOptRule(
operand(classOf[LogicalFilter],
operand(classOf[LogicalCorrelate], some(
operand(classOf[RelNode], any()),
operand(classOf[LogicalSnapshot], any())))),
- "LogicalCorrelateToTemporalTableJoinRule") {
+ "LogicalCorrelateToJoinFromTemporalTableRule") {
override def onMatch(call: RelOptRuleCall): Unit = {
val filterOnCorrelate: LogicalFilter = call.rel(0)
@@ -52,6 +55,6 @@ class LogicalCorrelateToTemporalTableJoinRule
}
-object LogicalCorrelateToTemporalTableJoinRule {
- val INSTANCE = new LogicalCorrelateToTemporalTableJoinRule
+object LogicalCorrelateToJoinFromTemporalTableRule {
+ val INSTANCE = new LogicalCorrelateToJoinFromTemporalTableRule
}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
index cf6514a..6986eab 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecJoinRule.scala
@@ -24,8 +24,7 @@ import
org.apache.flink.table.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.plan.nodes.FlinkConventions
import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalJoin,
FlinkLogicalRel, FlinkLogicalSnapshot}
import org.apache.flink.table.plan.nodes.physical.stream.StreamExecJoin
-import org.apache.flink.table.plan.util.WindowJoinUtil
-
+import org.apache.flink.table.plan.util.{TemporalJoinUtil, WindowJoinUtil}
import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
import org.apache.calcite.rel.RelNode
@@ -62,7 +61,8 @@ class StreamExecJoinRule
}
// this rule shouldn't match temporal table join
- if (right.isInstanceOf[FlinkLogicalSnapshot]) {
+ if (right.isInstanceOf[FlinkLogicalSnapshot] ||
+ TemporalJoinUtil.containsTemporalJoinCondition(join.getCondition)) {
return false
}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
new file mode 100644
index 0000000..fb1da2c
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.rules.physical.stream
+
+import org.apache.flink.table.plan.`trait`.FlinkRelDistribution
+import org.apache.flink.table.plan.nodes.FlinkConventions
+import org.apache.flink.table.plan.nodes.logical._
+import org.apache.flink.table.plan.nodes.physical.stream.StreamExecTemporalJoin
+import
org.apache.flink.table.plan.util.TemporalJoinUtil.containsTemporalJoinCondition
+import org.apache.flink.table.plan.util.{FlinkRelOptUtil, WindowJoinUtil}
+
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.core.JoinRelType
+
+import java.util
+
+class StreamExecTemporalJoinRule
+ extends RelOptRule(
+ operand(
+ classOf[FlinkLogicalJoin],
+ operand(classOf[FlinkLogicalRel], any()),
+ operand(classOf[FlinkLogicalRel], any())),
+ "StreamExecTemporalJoinRule") {
+
+ override def matches(call: RelOptRuleCall): Boolean = {
+ val join = call.rel[FlinkLogicalJoin](0)
+ val joinInfo = join.analyzeCondition
+
+ if (!containsTemporalJoinCondition(join.getCondition)) {
+ return false
+ }
+
+ val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(join)
+ val (windowBounds, _) = WindowJoinUtil.extractWindowBoundsFromPredicate(
+ joinInfo.getRemaining(join.getCluster.getRexBuilder),
+ join.getLeft.getRowType.getFieldCount,
+ join.getRowType,
+ join.getCluster.getRexBuilder,
+ tableConfig)
+
+ windowBounds.isEmpty && join.getJoinType == JoinRelType.INNER
+ }
+
+ override def onMatch(call: RelOptRuleCall): Unit = {
+ val join = call.rel[FlinkLogicalJoin](0)
+ val left = call.rel[FlinkLogicalRel](1)
+ val right = call.rel[FlinkLogicalRel](2)
+
+ val traitSet: RelTraitSet =
join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
+ val joinInfo = join.analyzeCondition
+
+ def toHashTraitByColumns(columns: util.Collection[_ <: Number],
inputTraitSets: RelTraitSet) = {
+ val distribution = if (columns.size() == 0) {
+ FlinkRelDistribution.SINGLETON
+ } else {
+ FlinkRelDistribution.hash(columns)
+ }
+ inputTraitSets.
+ replace(FlinkConventions.STREAM_PHYSICAL).
+ replace(distribution)
+ }
+ val (leftRequiredTrait, rightRequiredTrait) = (
+ toHashTraitByColumns(joinInfo.leftKeys, left.getTraitSet),
+ toHashTraitByColumns(joinInfo.rightKeys, right.getTraitSet))
+
+ val convLeft: RelNode = RelOptRule.convert(left, leftRequiredTrait)
+ val convRight: RelNode = RelOptRule.convert(right, rightRequiredTrait)
+
+
+ val temporalJoin = new StreamExecTemporalJoin(
+ join.getCluster,
+ traitSet,
+ convLeft,
+ convRight,
+ join.getCondition,
+ join.getJoinType)
+
+ call.transformTo(temporalJoin)
+ }
+}
+
+object StreamExecTemporalJoinRule {
+ val INSTANCE: RelOptRule = new StreamExecTemporalJoinRule
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala
new file mode 100644
index 0000000..7c44616
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/RexDefaultVisitor.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import org.apache.calcite.rex._
+
+/**
+ * Implementation of [[RexVisitor]] that redirects all calls into generic
+ * [[RexDefaultVisitor#visitNode(org.apache.calcite.rex.RexNode)]] method.
+ */
+abstract class RexDefaultVisitor[R] extends RexVisitor[R] {
+
+ override def visitFieldAccess(fieldAccess: RexFieldAccess): R =
+ visitNode(fieldAccess)
+
+ override def visitCall(call: RexCall): R =
+ visitNode(call)
+
+ override def visitInputRef(inputRef: RexInputRef): R =
+ visitNode(inputRef)
+
+ override def visitOver(over: RexOver): R =
+ visitNode(over)
+
+ override def visitCorrelVariable(correlVariable: RexCorrelVariable): R =
+ visitNode(correlVariable)
+
+ override def visitLocalRef(localRef: RexLocalRef): R =
+ visitNode(localRef)
+
+ override def visitDynamicParam(dynamicParam: RexDynamicParam): R =
+ visitNode(dynamicParam)
+
+ override def visitRangeRef(rangeRef: RexRangeRef): R =
+ visitNode(rangeRef)
+
+ override def visitTableInputRef(tableRef: RexTableInputRef): R =
+ visitNode(tableRef)
+
+ override def visitPatternFieldRef(patternFieldRef: RexPatternFieldRef): R =
+ visitNode(patternFieldRef)
+
+ override def visitSubQuery(subQuery: RexSubQuery): R =
+ visitNode(subQuery)
+
+ override def visitLiteral(literal: RexLiteral): R =
+ visitNode(literal)
+
+ def visitNode(rexNode: RexNode): R
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/TemporalJoinUtil.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/TemporalJoinUtil.scala
new file mode 100644
index 0000000..996a775
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/TemporalJoinUtil.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.util
+
+import org.apache.calcite.rex._
+import org.apache.calcite.sql.`type`.{OperandTypes, ReturnTypes}
+import org.apache.calcite.sql.{SqlFunction, SqlFunctionCategory, SqlKind}
+import org.apache.flink.util.Preconditions.checkArgument
+
+/**
+ * Utilities for temporal table join
+ */
+object TemporalJoinUtil {
+
+ //
----------------------------------------------------------------------------------------
+ // Temporal TableFunction Join Utilities
+ //
----------------------------------------------------------------------------------------
+
+ /**
+ * [[TEMPORAL_JOIN_CONDITION]] is a specific condition which correctly
defines
+ * references to rightTimeAttribute, rightPrimaryKeyExpression and
leftTimeAttribute.
+ * The condition is used to mark this is a temporal tablefunction join.
+ * Later rightTimeAttribute, rightPrimaryKeyExpression and
leftTimeAttribute will be
+ * extracted from the condition.
+ */
+ val TEMPORAL_JOIN_CONDITION = new SqlFunction(
+ "__TEMPORAL_JOIN_CONDITION",
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.BOOLEAN_NOT_NULL,
+ null,
+ OperandTypes.or(
+ OperandTypes.sequence(
+ "'(LEFT_TIME_ATTRIBUTE, RIGHT_TIME_ATTRIBUTE, PRIMARY_KEY)'",
+ OperandTypes.DATETIME,
+ OperandTypes.DATETIME,
+ OperandTypes.ANY),
+ OperandTypes.sequence(
+ "'(LEFT_TIME_ATTRIBUTE, PRIMARY_KEY)'",
+ OperandTypes.DATETIME,
+ OperandTypes.ANY)),
+ SqlFunctionCategory.SYSTEM)
+
+ def isRowtimeCall(call: RexCall): Boolean = {
+ checkArgument(call.getOperator == TEMPORAL_JOIN_CONDITION)
+ call.getOperands.size() == 3
+ }
+
+ def isProctimeCall(call: RexCall): Boolean = {
+ checkArgument(call.getOperator == TEMPORAL_JOIN_CONDITION)
+ call.getOperands.size() == 2
+ }
+
+ def makeRowTimeTemporalJoinConditionCall(
+ rexBuilder: RexBuilder,
+ leftTimeAttribute: RexNode,
+ rightTimeAttribute: RexNode,
+ rightPrimaryKeyExpression: RexNode): RexNode = {
+ rexBuilder.makeCall(
+ TEMPORAL_JOIN_CONDITION,
+ leftTimeAttribute,
+ rightTimeAttribute,
+ rightPrimaryKeyExpression)
+ }
+
+ def makeProcTimeTemporalJoinConditionCall(
+ rexBuilder: RexBuilder,
+ leftTimeAttribute: RexNode,
+ rightPrimaryKeyExpression: RexNode): RexNode = {
+ rexBuilder.makeCall(
+ TEMPORAL_JOIN_CONDITION,
+ leftTimeAttribute,
+ rightPrimaryKeyExpression)
+ }
+
+ def containsTemporalJoinCondition(condition: RexNode): Boolean = {
+ var hasTemporalJoinCondition: Boolean = false
+ condition.accept(new RexVisitorImpl[Void](true) {
+ override def visitCall(call: RexCall): Void = {
+ if (call.getOperator != TEMPORAL_JOIN_CONDITION) {
+ super.visitCall(call)
+ } else {
+ hasTemporalJoinCondition = true
+ null
+ }
+ }
+ })
+ hasTemporalJoinCondition
+ }
+
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/WindowJoinUtil.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/WindowJoinUtil.scala
index e534a2c..ab7c4ee 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/WindowJoinUtil.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/plan/util/WindowJoinUtil.scala
@@ -392,7 +392,13 @@ object WindowJoinUtil {
private def accessesNonTimeAttribute(expr: RexNode, inputType: RelDataType):
Boolean = {
expr match {
case ref: RexInputRef =>
- val accessedType = inputType.getFieldList.get(ref.getIndex).getType
+ var accessedType: RelDataType = null
+ try {
+ accessedType = inputType.getFieldList.get(ref.getIndex).getType
+ } catch {
+ case e =>
+ e.printStackTrace()
+ }
accessedType match {
case _: TimeIndicatorRelDataType => false
case _ => true
diff --git
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.xml
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.xml
new file mode 100644
index 0000000..9533bb2
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.xml
@@ -0,0 +1,101 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testSimpleJoin">
+ <Resource name="sql">
+ <![CDATA[SELECT o_amount * rate as rate FROM Orders AS o, LATERAL TABLE
(Rates(o.o_rowtime)) AS r WHERE currency = o_currency]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(rate=[*($0, $4)])
++- LogicalFilter(condition=[=($3, $1)])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{2}])
+ :- LogicalTableScan(table=[[default_catalog, default_database, Orders]])
+ +- LogicalTableFunctionScan(invocation=[Rates($cor0.o_rowtime)],
rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME
ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[*(o_amount, rate) AS rate])
++- TemporalJoin(joinType=[InnerJoin],
where=[AND(__TEMPORAL_JOIN_CONDITION(o_rowtime, rowtime, currency), =(currency,
o_currency))], select=[o_amount, o_currency, o_rowtime, currency, rate,
rowtime])
+ :- Exchange(distribution=[hash[o_currency]])
+ : +- DataStreamScan(table=[[default_catalog, default_database, Orders]],
fields=[o_amount, o_currency, o_rowtime])
+ +- Exchange(distribution=[hash[currency]])
+ +- DataStreamScan(table=[[default_catalog, default_database,
RatesHistory]], fields=[currency, rate, rowtime])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testComplexJoin">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM (SELECT o_amount * rate as rate, secondary_key as
secondary_key FROM Orders AS o, LATERAL TABLE (Rates(o_rowtime)) AS r WHERE
currency = o_currency OR secondary_key = o_secondary_key), Table3 WHERE
t3_secondary_key = secondary_key]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(rate=[$0], secondary_key=[$1], t3_comment=[$2],
t3_secondary_key=[$3])
++- LogicalFilter(condition=[=($3, $1)])
+ +- LogicalJoin(condition=[true], joinType=[inner])
+ :- LogicalProject(rate=[*($2, $8)], secondary_key=[$9])
+ : +- LogicalFilter(condition=[OR(=($7, $3), =($9, $4))])
+ : +- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{0}])
+ : :- LogicalTableScan(table=[[default_catalog, default_database,
Orders]])
+ : +-
LogicalTableFunctionScan(invocation=[Rates($cor0.o_rowtime)],
rowType=[RecordType(TIME ATTRIBUTE(ROWTIME) rowtime, VARCHAR(2147483647)
comment, VARCHAR(2147483647) currency, INTEGER rate, INTEGER secondary_key)],
elementType=[class [Ljava.lang.Object;])
+ +- LogicalTableScan(table=[[default_catalog, default_database, Table3]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Join(joinType=[InnerJoin], where=[=(t3_secondary_key, secondary_key)],
select=[rate, secondary_key, t3_comment, t3_secondary_key],
leftInputSpec=[NoUniqueKey], rightInputSpec=[NoUniqueKey])
+:- Exchange(distribution=[hash[secondary_key]])
+: +- Calc(select=[*(o_amount, rate) AS rate, secondary_key])
+: +- TemporalJoin(joinType=[InnerJoin],
where=[AND(__TEMPORAL_JOIN_CONDITION(o_rowtime, rowtime, currency),
OR(=(currency, o_currency), =(secondary_key, o_secondary_key)))],
select=[o_rowtime, o_amount, o_currency, o_secondary_key, rowtime, currency,
rate, secondary_key])
+: :- Exchange(distribution=[single])
+: : +- Calc(select=[o_rowtime, o_amount, o_currency, o_secondary_key])
+: : +- DataStreamScan(table=[[default_catalog, default_database,
Orders]], fields=[o_rowtime, o_comment, o_amount, o_currency, o_secondary_key])
+: +- Exchange(distribution=[single])
+: +- Calc(select=[rowtime, currency, rate, secondary_key],
where=[>(rate, 110)])
+: +- DataStreamScan(table=[[default_catalog, default_database,
RatesHistory]], fields=[rowtime, comment, currency, rate, secondary_key])
++- Exchange(distribution=[hash[t3_secondary_key]])
+ +- DataStreamScan(table=[[default_catalog, default_database, Table3]],
fields=[t3_comment, t3_secondary_key])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testSimpleProctimeJoin">
+ <Resource name="sql">
+ <![CDATA[SELECT o_amount * rate as rate FROM ProctimeOrders AS o,
LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r WHERE currency = o_currency]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(rate=[*($0, $4)])
++- LogicalFilter(condition=[=($3, $1)])
+ +- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{2}])
+ :- LogicalTableScan(table=[[default_catalog, default_database,
ProctimeOrders]])
+ +-
LogicalTableFunctionScan(invocation=[ProctimeRates($cor0.o_proctime)],
rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME
ATTRIBUTE(PROCTIME) proctime)], elementType=[class [Ljava.lang.Object;])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+Calc(select=[*(o_amount, rate) AS rate])
++- TemporalJoin(joinType=[InnerJoin],
where=[AND(__TEMPORAL_JOIN_CONDITION(o_proctime, currency), =(currency,
o_currency))], select=[o_amount, o_currency, o_proctime, currency, rate,
proctime])
+ :- Exchange(distribution=[hash[o_currency]])
+ : +- DataStreamScan(table=[[default_catalog, default_database,
ProctimeOrders]], fields=[o_amount, o_currency, o_proctime])
+ +- Exchange(distribution=[hash[currency]])
+ +- DataStreamScan(table=[[default_catalog, default_database,
ProctimeRatesHistory]], fields=[currency, rate, proctime])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/TemporalJoinTest.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/TemporalJoinTest.scala
new file mode 100644
index 0000000..b709dca
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/batch/sql/join/TemporalJoinTest.scala
@@ -0,0 +1,110 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.batch.sql.join

import java.sql.Timestamp

import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableException
import org.apache.flink.table.util.{BatchTableTestUtil, TableTestBase}
import org.hamcrest.Matchers.containsString
import org.junit.Test

/**
 * Plan tests asserting that temporal table function joins are NOT supported in
 * batch mode: each query below must be rejected with a [[TableException]]
 * instead of producing an execution plan.
 */
class TemporalJoinTest extends TableTestBase {

  val util: BatchTableTestUtil = batchTestUtil()

  // Probe side: orders with an event-time column.
  val orders = util.addDataStream[(Long, String, Timestamp)](
    "Orders", 'o_amount, 'o_currency, 'o_rowtime)

  // Build side: the versioned rates history.
  val ratesHistory = util.addDataStream[(String, Int, Timestamp)](
    "RatesHistory", 'currency, 'rate, 'rowtime)

  // Temporal table function keyed on currency and versioned by rowtime.
  val rates = util.addFunction(
    "Rates",
    ratesHistory.createTemporalTableFunction("rowtime", "currency"))

  @Test
  def testSimpleJoin(): Unit = {
    expectedException.expect(classOf[TableException])
    expectedException.expectMessage("Cannot generate a valid execution plan for the given query")

    val query = Seq(
      "SELECT ",
      "o_amount * rate as rate ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(o_rowtime)) AS r ",
      "WHERE currency = o_currency").mkString

    util.verifyExplain(query)
  }

  /**
   * A more involved query: complex OR join condition and columns that are never
   * used (and therefore pruned). Must still be rejected in batch mode.
   */
  @Test(expected = classOf[TableException])
  def testComplexJoin(): Unit = {
    // A fresh util so the wide schemas do not clash with the class-level tables.
    val complexUtil = batchTestUtil()
    complexUtil.addDataStream[(String, Int)]("Table3", 't3_comment, 't3_secondary_key)
    complexUtil.addDataStream[(Timestamp, String, Long, String, Int)](
      "Orders", 'o_rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key)

    val history = complexUtil.addDataStream[(Timestamp, String, String, Int, Int)](
      "RatesHistory", 'rowtime, 'comment, 'currency, 'rate, 'secondary_key)
    complexUtil.addFunction(
      "Rates",
      history.createTemporalTableFunction("rowtime", "currency"))

    val query = Seq(
      "SELECT * FROM ",
      "(SELECT ",
      "o_amount * rate as rate, ",
      "secondary_key as secondary_key ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(o_rowtime)) AS r ",
      "WHERE currency = o_currency OR secondary_key = o_secondary_key), ",
      "Table3 ",
      "WHERE t3_secondary_key = secondary_key").mkString

    complexUtil.verifyExplain(query)
  }

  @Test
  def testUncorrelatedJoin(): Unit = {
    expectedException.expect(classOf[TableException])
    expectedException.expectMessage(containsString("Cannot generate a valid execution plan"))

    val query = Seq(
      "SELECT ",
      "o_amount * rate as rate ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123')) AS r ",
      "WHERE currency = o_currency").mkString

    util.verifyExplain(query)
  }

  @Test
  def testTemporalTableFunctionScan(): Unit = {
    expectedException.expect(classOf[TableException])
    expectedException.expectMessage(containsString("Cannot generate a valid execution plan"))

    util.verifyExplain(
      "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))")
  }
}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.scala
new file mode 100644
index 0000000..3ee0ee1
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/plan/stream/sql/join/TemporalJoinTest.scala
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.table.plan.stream.sql.join

import java.sql.Timestamp

import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableException
import org.apache.flink.table.util.{StreamTableTestUtil, TableTestBase}
import org.hamcrest.Matchers.containsString
import org.junit.Test

/**
 * Plan tests for temporal table function joins in streaming mode, covering
 * both event-time and processing-time variants as well as queries that must
 * be rejected.
 */
class TemporalJoinTest extends TableTestBase {

  val util: StreamTableTestUtil = streamTestUtil()

  // Event-time probe and build sides.
  val orders = util.addDataStream[(Long, String, Timestamp)](
    "Orders", 'o_amount, 'o_currency, 'o_rowtime)

  val ratesHistory = util.addDataStream[(String, Int, Timestamp)](
    "RatesHistory", 'currency, 'rate, 'rowtime)

  val rates = util.addFunction(
    "Rates",
    ratesHistory.createTemporalTableFunction("rowtime", "currency"))

  // Processing-time probe and build sides.
  val proctimeOrders = util.addDataStream[(Long, String)](
    "ProctimeOrders", 'o_amount, 'o_currency, 'o_proctime)

  val proctimeRatesHistory = util.addDataStream[(String, Int)](
    "ProctimeRatesHistory", 'currency, 'rate, 'proctime)

  val proctimeRates = util.addFunction(
    "ProctimeRates",
    proctimeRatesHistory.createTemporalTableFunction("proctime", "currency"))

  @Test
  def testSimpleJoin(): Unit = {
    val query = Seq(
      "SELECT ",
      "o_amount * rate as rate ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(o.o_rowtime)) AS r ",
      "WHERE currency = o_currency").mkString

    util.verifyPlan(query)
  }

  @Test
  def testSimpleProctimeJoin(): Unit = {
    val query = Seq(
      "SELECT ",
      "o_amount * rate as rate ",
      "FROM ProctimeOrders AS o, ",
      "LATERAL TABLE (ProctimeRates(o.o_proctime)) AS r ",
      "WHERE currency = o_currency").mkString

    util.verifyPlan(query)
  }

  /**
   * Temporal table join with a more complicated query: complex OR join
   * condition and columns that are never used (and therefore pruned).
   */
  @Test
  def testComplexJoin(): Unit = {
    // A fresh util so the wide schemas do not clash with the class-level tables.
    val complexUtil = streamTestUtil()
    complexUtil.addDataStream[(String, Int)]("Table3", 't3_comment, 't3_secondary_key)
    complexUtil.addDataStream[(Timestamp, String, Long, String, Int)](
      "Orders", 'o_rowtime, 'o_comment, 'o_amount, 'o_currency, 'o_secondary_key)

    complexUtil.addDataStream[(Timestamp, String, String, Int, Int)](
      "RatesHistory", 'rowtime, 'comment, 'currency, 'rate, 'secondary_key)
    // The temporal table function is defined over a filtered view of the history.
    complexUtil.addFunction(
      "Rates",
      complexUtil.tableEnv
        .sqlQuery("SELECT * FROM RatesHistory WHERE rate > 110")
        .createTemporalTableFunction("rowtime", "currency"))

    val query = Seq(
      "SELECT * FROM ",
      "(SELECT ",
      "o_amount * rate as rate, ",
      "secondary_key as secondary_key ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(o_rowtime)) AS r ",
      "WHERE currency = o_currency OR secondary_key = o_secondary_key), ",
      "Table3 ",
      "WHERE t3_secondary_key = secondary_key").mkString

    complexUtil.verifyPlan(query)
  }

  @Test
  def testUncorrelatedJoin(): Unit = {
    expectedException.expect(classOf[TableException])
    expectedException.expectMessage(containsString("Cannot generate a valid execution plan"))

    val query = Seq(
      "SELECT ",
      "o_amount * rate as rate ",
      "FROM Orders AS o, ",
      "LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123')) AS r ",
      "WHERE currency = o_currency").mkString

    util.verifyExplain(query)
  }

  @Test
  def testTemporalTableFunctionScan(): Unit = {
    expectedException.expect(classOf[TableException])
    expectedException.expectMessage(containsString("Cannot generate a valid execution plan"))

    util.verifyExplain(
      "SELECT * FROM LATERAL TABLE (Rates(TIMESTAMP '2016-06-27 10:10:42.123'))")
  }
}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/AbstractTwoInputStreamOperatorWithTTLTest.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/AbstractTwoInputStreamOperatorWithTTLTest.scala
new file mode 100644
index 0000000..cfa327f
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/AbstractTwoInputStreamOperatorWithTTLTest.scala
@@ -0,0 +1,185 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.harness

import org.apache.flink.api.common.time.Time
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.runtime.state.VoidNamespace
import org.apache.flink.streaming.api.operators.InternalTimer
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.runtime.harness.HarnessTestBase.TestingBaseRowKeySelector
import org.apache.flink.table.runtime.join.temporal.BaseTwoInputStreamOperatorWithStateRetention
import org.apache.flink.table.runtime.util.StreamRecordUtils.record
import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.HEAP_BACKEND
import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.{Description, TypeSafeMatcher}
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{After, Before, Test}

import java.lang.{Long => JLong}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Tests for the
 * [[org.apache.flink.table.runtime.join.temporal.BaseTwoInputStreamOperatorWithStateRetention]].
 *
 * The operator under test is configured with minRetentionTime = 2 ms and
 * maxRetentionTime = 4 ms, so an element processed at time t registers a
 * clean-up timer for t + 4, unless a timer later than t + 2 is already pending.
 */
class AbstractTwoInputStreamOperatorWithTTLTest
  extends HarnessTestBase(HEAP_BACKEND) {

  // Two records with the same key (1L) so they hit the same keyed state / timer.
  @transient
  private var recordAForFirstKey: StreamRecord[BaseRow] = _
  @transient
  private var recordBForFirstKey: StreamRecord[BaseRow] = _

  private val minRetentionTime = Time.milliseconds(2)
  private val maxRetentionTime = Time.milliseconds(4)

  private var operatorUnderTest: StubOperatorWithStateTTL = _

  private var testHarness
  : KeyedTwoInputStreamOperatorTestHarness[JLong, BaseRow, BaseRow, BaseRow] = _

  /** Builds and opens a fresh harness around a stub operator before every test. */
  @Before
  def createTestHarness(): Unit = {
    operatorUnderTest = new StubOperatorWithStateTTL(minRetentionTime, maxRetentionTime)
    testHarness = createTestHarness(operatorUnderTest)
    testHarness.open()
    recordAForFirstKey = record(1L: JLong, "hello")
    recordBForFirstKey = record(1L: JLong, "world")
  }

  @After
  def closeTestHarness(): Unit = {
    testHarness.close()
  }

  /** A single element at t=1 fires exactly one clean-up timer at 1 + maxRetention = 5. */
  @Test
  def normalScenarioWorks(): Unit = {
    testHarness.setProcessingTime(1L)
    testHarness.processElement1(recordAForFirstKey)

    testHarness.setProcessingTime(10L)

    assertThat(operatorUnderTest, hasFiredCleanUpTimersForTimestamps(5L))
  }

  /**
   * Second element at t=2: 2 + minRetention (= 4) is not greater than the
   * pending clean-up time (5), so the existing timer is kept and fires alone.
   */
  @Test
  def whenCurrentTimePlusMinRetentionSmallerThanCurrentCleanupTimeNoNewTimerRegistered(): Unit = {
    testHarness.setProcessingTime(1L)
    testHarness.processElement1(recordAForFirstKey)

    testHarness.setProcessingTime(2L)
    testHarness.processElement1(recordBForFirstKey)

    testHarness.setProcessingTime(20L)

    assertThat(operatorUnderTest, hasFiredCleanUpTimersForTimestamps(5L))
  }

  /**
   * Second element at t=4: 4 + minRetention (= 6) exceeds the pending clean-up
   * time (5), so the timer is moved to 4 + maxRetention = 8 and only that fires.
   */
  @Test
  def whenCurrentTimePlusMinRetentionLargerThanCurrentCleanupTimeTimerIsUpdated(): Unit = {
    testHarness.setProcessingTime(1L)
    testHarness.processElement1(recordAForFirstKey)

    testHarness.setProcessingTime(4L)
    testHarness.processElement1(recordBForFirstKey)

    testHarness.setProcessingTime(20L)

    assertThat(operatorUnderTest, hasFiredCleanUpTimersForTimestamps(8L))
  }

  /** Same as above, but the second element arrives on the other input: both inputs share the timer. */
  @Test
  def otherSideToSameKeyStateAlsoUpdatesCleanupTimer(): Unit = {
    testHarness.setProcessingTime(1L)
    testHarness.processElement1(recordAForFirstKey)

    testHarness.setProcessingTime(4L)
    testHarness.processElement2(recordBForFirstKey)

    testHarness.setProcessingTime(20L)

    assertThat(operatorUnderTest, hasFiredCleanUpTimersForTimestamps(8L))
  }

  // -------------------------------- Test Utilities --------------------------------

  // Both inputs are keyed on field 0 (the long key) with parallelism 1.
  private def createTestHarness(operator: BaseTwoInputStreamOperatorWithStateRetention) = {
    new KeyedTwoInputStreamOperatorTestHarness[JLong, BaseRow, BaseRow, BaseRow](
      operator,
      new TestingBaseRowKeySelector(0),
      new TestingBaseRowKeySelector(0),
      BasicTypeInfo.LONG_TYPE_INFO,
      1,
      1,
      0)
  }

  // -------------------------------- Matchers --------------------------------

  /** Matches when the operator's fired clean-up timers equal exactly the given timestamps, in order. */
  private def hasFiredCleanUpTimersForTimestamps(timers: JLong*) =
    new TypeSafeMatcher[StubOperatorWithStateTTL]() {

      override protected def matchesSafely(operator: StubOperatorWithStateTTL): Boolean = {
        // deep comparison gives element-wise equality on the arrays
        operator.firedCleanUpTimers.toArray.deep == timers.toArray.deep
      }

      def describeTo(description: Description): Unit = {
        description
          .appendText("a list of timers with timestamps=")
          .appendValue(timers.mkString(","))
      }
    }

  // -------------------------------- Test Classes --------------------------------

  /**
   * A mock [[BaseTwoInputStreamOperatorWithStateRetention]] which registers
   * the timestamps of the clean-up timers that fired (not the registered
   * ones, which can be deleted without firing).
   */
  class StubOperatorWithStateTTL(
      minRetentionTime: Time,
      maxRetentionTime: Time)
    extends BaseTwoInputStreamOperatorWithStateRetention(
      minRetentionTime.toMilliseconds, maxRetentionTime.toMilliseconds) {

    val firedCleanUpTimers: mutable.Buffer[JLong] = ArrayBuffer.empty

    // Record which clean-up timers actually fired instead of cleaning any state.
    override def cleanupState(time: Long): Unit = {
      firedCleanUpTimers.append(time)
    }

    // Elements on either input only (re-)register the clean-up timer.
    override def processElement1(element: StreamRecord[BaseRow]): Unit = {
      registerProcessingCleanupTimer()
    }

    override def processElement2(element: StreamRecord[BaseRow]): Unit = {
      registerProcessingCleanupTimer()
    }

    override def onEventTime(timer: InternalTimer[Object, VoidNamespace]): Unit = ()
  }
}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
index d31549c..904b2da 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala
@@ -18,7 +18,6 @@
package org.apache.flink.table.runtime.harness
import java.util
-
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.dag.Transformation
import org.apache.flink.api.java.functions.KeySelector
@@ -31,6 +30,7 @@ import org.apache.flink.streaming.api.scala.DataStream
import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness
+import org.apache.flink.table.JLong
import org.apache.flink.table.dataformat.BaseRow
import org.apache.flink.table.runtime.utils.StreamingTestBase
import
org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND,
ROCKSDB_BACKEND, StateBackendMode}
@@ -109,4 +109,12 @@ object HarnessTestBase {
def parameters(): util.Collection[Array[java.lang.Object]] = {
Seq[Array[AnyRef]](Array(HEAP_BACKEND), Array(ROCKSDB_BACKEND))
}
+
+ class TestingBaseRowKeySelector(
+ private val selectorField: Int) extends KeySelector[BaseRow, JLong] {
+
+ override def getKey(value: BaseRow): JLong = {
+ value.getLong(selectorField)
+ }
+ }
}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
new file mode 100644
index 0000000..65717af
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/stream/sql/TemporalJoinITCase.scala
@@ -0,0 +1,168 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.stream.sql

import java.sql.Timestamp

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.utils.{StreamingWithStateTestBase, TestingAppendSink}
import org.apache.flink.table.runtime.utils.StreamingWithStateTestBase.StateBackendMode
import org.apache.flink.types.Row
import org.junit.Assert.assertEquals
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import scala.collection.mutable

/** End-to-end ITCases for temporal table function joins in streaming SQL. */
@RunWith(classOf[Parameterized])
class TemporalJoinITCase(state: StateBackendMode)
  extends StreamingWithStateTestBase(state) {

  /**
   * With processing time we cannot easily assert on the produced values, so this
   * test only checks that a processing-time temporal join runs end-to-end without
   * throwing. Actual correctness is covered by unit tests.
   */
  @Test
  def testProcessTimeInnerJoin(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv: StreamTableEnvironment = TableEnvironment.getTableEnvironment(env)
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)

    val sqlQuery =
      """
        |SELECT
        |  o.amount * r.rate AS amount
        |FROM
        |  Orders AS o,
        |  LATERAL TABLE (Rates(o.proctime)) AS r
        |WHERE r.currency = o.currency
        |""".stripMargin

    val orderRows = Seq(
      (2L, "Euro"),
      (1L, "US Dollar"),
      (50L, "Yen"),
      (3L, "Euro"),
      (5L, "US Dollar"))

    val rateRows = Seq(
      ("US Dollar", 102L),
      ("Euro", 114L),
      ("Yen", 1L),
      ("Euro", 116L),
      ("Euro", 119L))

    val orders = env
      .fromCollection(orderRows)
      .toTable(tEnv, 'amount, 'currency, 'proctime)
    val ratesHistory = env
      .fromCollection(rateRows)
      .toTable(tEnv, 'currency, 'rate, 'proctime)

    tEnv.registerTable("Orders", orders)
    tEnv.registerTable("RatesHistory", ratesHistory)
    tEnv.registerFunction(
      "Rates",
      ratesHistory.createTemporalTableFunction("proctime", "currency"))

    // Only verify that the pipeline executes cleanly.
    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
    result.addSink(new TestingAppendSink)
    env.execute()
  }

  /** Event-time temporal join: each order picks the rate valid at its rowtime. */
  @Test
  def testEventTimeInnerJoin(): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv: StreamTableEnvironment = TableEnvironment.getTableEnvironment(env)
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)

    val sqlQuery =
      """
        |SELECT
        |  o.amount * r.rate AS amount
        |FROM
        |  Orders AS o,
        |  LATERAL TABLE (Rates(o.rowtime)) AS r
        |WHERE r.currency = o.currency
        |""".stripMargin

    val orderRows = Seq(
      (2L, "Euro", new Timestamp(2L)),
      (1L, "US Dollar", new Timestamp(3L)),
      (50L, "Yen", new Timestamp(4L)),
      (3L, "Euro", new Timestamp(5L)))

    val rateRows = Seq(
      ("US Dollar", 102L, new Timestamp(1L)),
      ("Euro", 114L, new Timestamp(1L)),
      ("Yen", 1L, new Timestamp(1L)),
      ("Euro", 116L, new Timestamp(5L)),
      ("Euro", 119L, new Timestamp(7L)))

    // Rates below 110 are filtered out, so only the Euro orders survive:
    // the t=2 order joins the 114 rate and the t=5 order joins the 116 rate.
    val expectedOutput = Set((2 * 114).toString, (3 * 116).toString)

    val orders = env
      .fromCollection(orderRows)
      .assignTimestampsAndWatermarks(new TimestampExtractor[Long, String]())
      .toTable(tEnv, 'amount, 'currency, 'rowtime)
    val ratesHistory = env
      .fromCollection(rateRows)
      .assignTimestampsAndWatermarks(new TimestampExtractor[String, Long]())
      .toTable(tEnv, 'currency, 'rate, 'rowtime)

    tEnv.registerTable("Orders", orders)
    tEnv.registerTable("RatesHistory", ratesHistory)
    tEnv.registerTable("FilteredRatesHistory",
      tEnv.sqlQuery("SELECT * FROM RatesHistory WHERE rate > 110"))
    tEnv.registerFunction(
      "Rates",
      tEnv
        .scan("FilteredRatesHistory")
        .createTemporalTableFunction("rowtime", "currency"))
    tEnv.registerTable("TemporalJoinResult", tEnv.sqlQuery(sqlQuery))

    // Scan from registered table to test for interplay between
    // LogicalCorrelateToTemporalTableJoinRule and TableScanRule
    val result = tEnv.scan("TemporalJoinResult").toAppendStream[Row]
    val sink = new TestingAppendSink
    result.addSink(sink)
    env.execute()

    assertEquals(expectedOutput, sink.getAppendResults.toSet)
  }
}
+
/**
 * Extracts the event time from the third tuple field, tolerating elements that
 * arrive up to 10 seconds out of order.
 */
class TimestampExtractor[T1, T2]
  extends BoundedOutOfOrdernessTimestampExtractor[(T1, T2, Timestamp)](Time.seconds(10)) {

  override def extractTimestamp(element: (T1, T2, Timestamp)): Long =
    element._3.getTime
}
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BaseRowUtil.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BaseRowUtil.java
index a15079b..96f07d6 100644
---
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BaseRowUtil.java
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/dataformat/util/BaseRowUtil.java
@@ -42,6 +42,10 @@ public final class BaseRowUtil {
return baseRow.getHeader() == ACCUMULATE_MSG;
}
+ public static boolean isRetractMsg(BaseRow baseRow) {
+ return baseRow.getHeader() == RETRACT_MSG;
+ }
+
public static BaseRow setAccumulate(BaseRow baseRow) {
baseRow.setHeader(ACCUMULATE_MSG);
return baseRow;
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/stream/AbstractStreamingJoinOperator.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/stream/AbstractStreamingJoinOperator.java
index b9c4642..588c1d7 100644
---
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/stream/AbstractStreamingJoinOperator.java
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/stream/AbstractStreamingJoinOperator.java
@@ -117,7 +117,9 @@ public abstract class AbstractStreamingJoinOperator extends
AbstractStreamOperat
@Override
public void close() throws Exception {
super.close();
- joinCondition.backingJoinCondition.close();
+ if (joinCondition != null) {
+ joinCondition.backingJoinCondition.close();
+ }
}
//
----------------------------------------------------------------------------------------
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/BaseTwoInputStreamOperatorWithStateRetention.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/BaseTwoInputStreamOperatorWithStateRetention.java
new file mode 100644
index 0000000..dd88950
--- /dev/null
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/BaseTwoInputStreamOperatorWithStateRetention.java
@@ -0,0 +1,162 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.runtime.join.temporal;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.streaming.api.SimpleTimerService;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.table.dataformat.BaseRow;

import java.io.IOException;
import java.util.Optional;

/**
 * An abstract {@link TwoInputStreamOperator} that allows its subclasses to clean
 * up their state based on a TTL. This TTL should be specified in the provided
 * {@code minRetentionTime} and {@code maxRetentionTime}.
 *
 * <p>For each known key, this operator registers a timer (in processing time) to
 * fire after the TTL expires. When the timer fires, the subclass can decide which
 * state to cleanup and what further action to take.
 *
 * <p>This class takes care of maintaining at most one timer per key.
 *
 * <p><b>IMPORTANT NOTE TO USERS:</b> When extending this class, do not use processing time
 * timers in your business logic. The reason is that:
 *
 * <p>1) if your timers collide with clean up timers and you delete them, then state
 * clean-up will not be performed, and
 *
 * <p>2) (this one is the reason why this class does not allow to override the onProcessingTime())
 * the onProcessingTime with your logic would be also executed on each clean up timer.
 */
@Internal
public abstract class BaseTwoInputStreamOperatorWithStateRetention
		extends AbstractStreamOperator<BaseRow>
		implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow>, Triggerable<Object, VoidNamespace> {

	private static final long serialVersionUID = -5953921797477294258L;

	private static final String CLEANUP_TIMESTAMP = "cleanup-timestamp";
	private static final String TIMERS_STATE_NAME = "timers";

	private final long minRetentionTime;
	private final long maxRetentionTime;

	// Clean-up is only active for a minRetentionTime strictly greater than 1 ms;
	// smaller values effectively disable the TTL mechanism.
	protected final boolean stateCleaningEnabled;

	// Per-key timestamp of the currently registered clean-up timer (null if none).
	private transient ValueState<Long> latestRegisteredCleanupTimer;
	private transient SimpleTimerService timerService;

	protected BaseTwoInputStreamOperatorWithStateRetention(long minRetentionTime, long maxRetentionTime) {
		this.minRetentionTime = minRetentionTime;
		this.maxRetentionTime = maxRetentionTime;
		this.stateCleaningEnabled = minRetentionTime > 1;
	}

	@Override
	public void open() throws Exception {
		initializeTimerService();

		if (stateCleaningEnabled) {
			ValueStateDescriptor<Long> cleanupStateDescriptor =
					new ValueStateDescriptor<>(CLEANUP_TIMESTAMP, Types.LONG);
			latestRegisteredCleanupTimer = getRuntimeContext().getState(cleanupStateDescriptor);
		}
	}

	// Wraps the internal timer service (this operator is the Triggerable callback).
	private void initializeTimerService() {
		InternalTimerService<VoidNamespace> internalTimerService = getInternalTimerService(
				TIMERS_STATE_NAME,
				VoidNamespaceSerializer.INSTANCE,
				this);

		timerService = new SimpleTimerService(internalTimerService);
	}

	/**
	 * If the user has specified a {@code minRetentionTime} and {@code maxRetentionTime}, this
	 * method registers a cleanup timer for {@code currentProcessingTime + minRetentionTime}.
	 *
	 * <p>When this timer fires, the {@link #cleanupState(long)} method is called.
	 */
	protected void registerProcessingCleanupTimer() throws IOException {
		if (stateCleaningEnabled) {
			long currentProcessingTime = timerService.currentProcessingTime();
			Optional<Long> currentCleanupTime = Optional.ofNullable(latestRegisteredCleanupTimer.value());

			// Only re-register when no timer is pending, or the pending one would
			// fire earlier than now + minRetentionTime (i.e. it is too close).
			if (!currentCleanupTime.isPresent()
					|| (currentProcessingTime + minRetentionTime) > currentCleanupTime.get()) {

				updateCleanupTimer(currentProcessingTime, currentCleanupTime);
			}
		}
	}

	// Replaces any pending clean-up timer with one at now + maxRetentionTime,
	// keeping the invariant of at most one timer per key.
	private void updateCleanupTimer(long currentProcessingTime, Optional<Long> currentCleanupTime) throws IOException {
		currentCleanupTime.ifPresent(aLong -> timerService.deleteProcessingTimeTimer(aLong));

		long newCleanupTime = currentProcessingTime + maxRetentionTime;
		timerService.registerProcessingTimeTimer(newCleanupTime);
		latestRegisteredCleanupTimer.update(newCleanupTime);
	}

	// Removes the pending clean-up timer (if any) and clears the bookkeeping state.
	protected void cleanupLastTimer() throws IOException {
		if (stateCleaningEnabled) {
			Optional<Long> currentCleanupTime = Optional.ofNullable(latestRegisteredCleanupTimer.value());
			if (currentCleanupTime.isPresent()) {
				latestRegisteredCleanupTimer.clear();
				timerService.deleteProcessingTimeTimer(currentCleanupTime.get());
			}
		}
	}

	/**
	 * The users of this class are not allowed to use processing time timers.
	 * See class javadoc.
	 */
	@Override
	public final void onProcessingTime(InternalTimer<Object, VoidNamespace> timer) throws Exception {
		if (stateCleaningEnabled) {
			long timerTime = timer.getTimestamp();
			Long cleanupTime = latestRegisteredCleanupTimer.value();

			// Only the most recently registered timer triggers the clean-up;
			// any other firing timestamp is ignored.
			if (cleanupTime != null && cleanupTime == timerTime) {
				cleanupState(cleanupTime);
				latestRegisteredCleanupTimer.clear();
			}
		}
	}

	// ----------------- Abstract Methods -----------------

	/**
	 * The method to be called when a cleanup timer fires.
	 * @param time The timestamp of the fired timer.
	 */
	public abstract void cleanupState(long time);
}
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalProcessTimeJoinOperator.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalProcessTimeJoinOperator.java
new file mode 100644
index 0000000..c67b480
--- /dev/null
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalProcessTimeJoinOperator.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join.temporal;
+
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.streaming.api.operators.InternalTimer;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.JoinedRow;
+import org.apache.flink.table.dataformat.util.BaseRowUtil;
+import org.apache.flink.table.generated.GeneratedJoinCondition;
+import org.apache.flink.table.generated.JoinCondition;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+
+/**
+ * The operator to temporal join a stream on processing time.
+ */
+public class TemporalProcessTimeJoinOperator
+ extends BaseTwoInputStreamOperatorWithStateRetention {
+
+ private static final long serialVersionUID = -5182289624027523612L;
+
+ private final BaseRowTypeInfo rightType;
+ private final GeneratedJoinCondition generatedJoinCondition;
+
+ private transient ValueState<BaseRow> rightState;
+ private transient JoinCondition joinCondition;
+
+ private transient JoinedRow outRow;
+ private transient TimestampedCollector<BaseRow> collector;
+
+ public TemporalProcessTimeJoinOperator(
+ BaseRowTypeInfo rightType,
+ GeneratedJoinCondition generatedJoinCondition,
+ long minRetentionTime,
+ long maxRetentionTime) {
+ super(minRetentionTime, maxRetentionTime);
+ this.rightType = rightType;
+ this.generatedJoinCondition = generatedJoinCondition;
+ }
+
+ @Override
+ public void open() throws Exception {
+ this.joinCondition =
generatedJoinCondition.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ FunctionUtils.setFunctionRuntimeContext(joinCondition,
getRuntimeContext());
+ FunctionUtils.openFunction(joinCondition, new Configuration());
+
+ ValueStateDescriptor<BaseRow> rightStateDesc = new
ValueStateDescriptor<>("right", rightType);
+ this.rightState = getRuntimeContext().getState(rightStateDesc);
+ this.collector = new TimestampedCollector<>(output);
+ this.outRow = new JoinedRow();
+ // consider watermark from left stream only.
+ super.processWatermark2(Watermark.MAX_WATERMARK);
+ }
+
+ @Override
+ public void processElement1(StreamRecord<BaseRow> element) throws
Exception {
+ BaseRow rightSideRow = rightState.value();
+ if (rightSideRow == null) {
+ return;
+ }
+
+ BaseRow leftSideRow = element.getValue();
+ if (joinCondition.apply(leftSideRow, rightSideRow)) {
+ outRow.setHeader(leftSideRow.getHeader());
+ outRow.replace(leftSideRow, rightSideRow);
+ collector.collect(outRow);
+ }
+ registerProcessingCleanupTimer();
+ }
+
+ @Override
+ public void processElement2(StreamRecord<BaseRow> element) throws
Exception {
+ if (BaseRowUtil.isAccumulateMsg(element.getValue())) {
+ rightState.update(element.getValue());
+ registerProcessingCleanupTimer();
+ } else {
+ rightState.clear();
+ cleanupLastTimer();
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ FunctionUtils.closeFunction(joinCondition);
+ }
+
+ /**
+ * The method to be called when a cleanup timer fires.
+ *
+ * @param time The timestamp of the fired timer.
+ */
+ @Override
+ public void cleanupState(long time) {
+ rightState.clear();
+ }
+
+ /**
+ * Invoked when an event-time timer fires.
+ */
+ @Override
+ public void onEventTime(InternalTimer<Object, VoidNamespace> timer)
throws Exception {
+ }
+}
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalRowTimeJoinOperator.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalRowTimeJoinOperator.java
new file mode 100644
index 0000000..56d6e10
--- /dev/null
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/join/temporal/TemporalRowTimeJoinOperator.java
@@ -0,0 +1,400 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.join.temporal;
+
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.streaming.api.operators.InternalTimer;
+import org.apache.flink.streaming.api.operators.InternalTimerService;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.dataformat.BaseRow;
+import org.apache.flink.table.dataformat.JoinedRow;
+import org.apache.flink.table.dataformat.util.BaseRowUtil;
+import org.apache.flink.table.generated.GeneratedJoinCondition;
+import org.apache.flink.table.generated.JoinCondition;
+import org.apache.flink.table.typeutils.BaseRowTypeInfo;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * This operator works by keeping on the state collection of probe and build
records to process
+ * on next watermark. The idea is that between watermarks we are collecting
those elements
+ * and once we are sure that there will be no updates we emit the correct
result and clean up the
+ * state.
+ *
+ * <p>Cleaning up the state drops all of the "old" values from the probe side,
where "old" is defined
+ * as older then the current watermark. Build side is also cleaned up in the
similar fashion,
+ * however we always keep at least one record - the latest one - even if it's
past the last
+ * watermark.
+ *
+ * <p>One more trick is how the emitting results and cleaning up is triggered.
It is achieved
+ * by registering timers for the keys. We could register a timer for every
probe and build
+ * side element's event time (when watermark exceeds this timer, that's when
we are emitting and/or
+ * cleaning up the state). However this would cause huge number of registered
timers. For example
+ * with following evenTimes of probe records accumulated: {1, 2, 5, 8, 9}, if
we
+ * had received Watermark(10), it would trigger 5 separate timers for the same
key. To avoid that
+ * we always keep only one single registered timer for any given key,
registered for the minimal
+ * value. Upon triggering it, we process all records with event times older
then or equal to
+ * currentWatermark.
+ */
+public class TemporalRowTimeJoinOperator
+ extends BaseTwoInputStreamOperatorWithStateRetention {
+
+ private static final long serialVersionUID = 6642514795175288193L;
+
+ private static final String NEXT_LEFT_INDEX_STATE_NAME = "next-index";
+ private static final String LEFT_STATE_NAME = "left";
+ private static final String RIGHT_STATE_NAME = "right";
+ private static final String REGISTERED_TIMER_STATE_NAME = "timer";
+ private static final String TIMERS_STATE_NAME = "timers";
+
+ private final BaseRowTypeInfo leftType;
+ private final BaseRowTypeInfo rightType;
+ private final GeneratedJoinCondition generatedJoinCondition;
+ private final int leftTimeAttribute;
+ private final int rightTimeAttribute;
+
+ private final RowtimeComparator rightRowtimeComparator;
+
+ /**
+ * Incremental index generator for {@link #leftState}'s keys.
+ */
+ private transient ValueState<Long> nextLeftIndex;
+
+ /**
+ * Mapping from artificial row index (generated by `nextLeftIndex`)
into the left side `Row`.
+ * We can not use List to accumulate Rows, because we need efficient
deletes of the oldest rows.
+ *
+ * <p>TODO: this could be OrderedMultiMap[Jlong, Row] indexed by row's
timestamp, to avoid
+ * full map traversals (if we have lots of rows on the state that
exceed `currentWatermark`).
+ */
+ private transient MapState<Long, BaseRow> leftState;
+
+ /**
+ * Mapping from timestamp to right side `Row`.
+ *
+ * <p>TODO: having `rightState` as an OrderedMapState would allow us to
avoid sorting cost
+ * once per watermark
+ */
+ private transient MapState<Long, BaseRow> rightState;
+
+ // Long for correct handling of default null
+ private transient ValueState<Long> registeredTimer;
+ private transient TimestampedCollector<BaseRow> collector;
+ private transient InternalTimerService<VoidNamespace> timerService;
+
+ private transient JoinCondition joinCondition;
+ private transient JoinedRow outRow;
+
+ public TemporalRowTimeJoinOperator(
+ BaseRowTypeInfo leftType,
+ BaseRowTypeInfo rightType,
+ GeneratedJoinCondition generatedJoinCondition,
+ int leftTimeAttribute,
+ int rightTimeAttribute,
+ long minRetentionTime,
+ long maxRetentionTime) {
+ super(minRetentionTime, maxRetentionTime);
+ this.leftType = leftType;
+ this.rightType = rightType;
+ this.generatedJoinCondition = generatedJoinCondition;
+ this.leftTimeAttribute = leftTimeAttribute;
+ this.rightTimeAttribute = rightTimeAttribute;
+ this.rightRowtimeComparator = new
RowtimeComparator(rightTimeAttribute);
+ }
+
+ @Override
+ public void open() throws Exception {
+ joinCondition =
generatedJoinCondition.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ joinCondition.setRuntimeContext(getRuntimeContext());
+ joinCondition.open(new Configuration());
+
+ nextLeftIndex = getRuntimeContext().getState(
+ new ValueStateDescriptor<>(NEXT_LEFT_INDEX_STATE_NAME,
Types.LONG));
+ leftState = getRuntimeContext().getMapState(
+ new MapStateDescriptor<>(LEFT_STATE_NAME, Types.LONG,
leftType));
+ rightState = getRuntimeContext().getMapState(
+ new MapStateDescriptor<>(RIGHT_STATE_NAME, Types.LONG,
rightType));
+ registeredTimer = getRuntimeContext().getState(
+ new ValueStateDescriptor<>(REGISTERED_TIMER_STATE_NAME,
Types.LONG));
+
+ timerService = getInternalTimerService(
+ TIMERS_STATE_NAME, VoidNamespaceSerializer.INSTANCE,
this);
+ collector = new TimestampedCollector<>(output);
+ outRow = new JoinedRow();
+ outRow.setHeader(BaseRowUtil.ACCUMULATE_MSG);
+ }
+
+ @Override
+ public void processElement1(StreamRecord<BaseRow> element) throws
Exception {
+ BaseRow row = element.getValue();
+ checkNotRetraction(row);
+
+ leftState.put(getNextLeftIndex(), row);
+ registerSmallestTimer(getLeftTime(row)); // Timer to emit and
clean up the state
+
+ registerProcessingCleanupTimer();
+ }
+
+ @Override
+ public void processElement2(StreamRecord<BaseRow> element) throws
Exception {
+ BaseRow row = element.getValue();
+ checkNotRetraction(row);
+
+ long rowTime = getRightTime(row);
+ rightState.put(rowTime, row);
+ registerSmallestTimer(rowTime); // Timer to clean up the state
+
+ registerProcessingCleanupTimer();
+ }
+
+ @Override
+ public void onEventTime(InternalTimer<Object, VoidNamespace> timer)
throws Exception {
+ registeredTimer.clear();
+ long lastUnprocessedTime =
emitResultAndCleanUpState(timerService.currentWatermark());
+ if (lastUnprocessedTime < Long.MAX_VALUE) {
+ registerTimer(lastUnprocessedTime);
+ }
+
+ // if we have more state at any side, then update the timer,
else clean it up.
+ if (stateCleaningEnabled) {
+ if (lastUnprocessedTime < Long.MAX_VALUE ||
rightState.iterator().hasNext()) {
+ registerProcessingCleanupTimer();
+ } else {
+ cleanupLastTimer();
+ }
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ if (joinCondition != null) {
+ joinCondition.close();
+ }
+ }
+
+ /**
+ * @return a row time of the oldest unprocessed probe record or
Long.MaxValue, if all records
+ * have been processed.
+ */
+ private long emitResultAndCleanUpState(long timerTimestamp) throws
Exception {
+ List<BaseRow> rightRowsSorted =
getRightRowSorted(rightRowtimeComparator);
+ long lastUnprocessedTime = Long.MAX_VALUE;
+
+ Iterator<Map.Entry<Long, BaseRow>> leftIterator =
leftState.entries().iterator();
+ while (leftIterator.hasNext()) {
+ Map.Entry<Long, BaseRow> entry = leftIterator.next();
+ BaseRow leftRow = entry.getValue();
+ long leftTime = getLeftTime(leftRow);
+
+ if (leftTime <= timerTimestamp) {
+ Optional<BaseRow> rightRow =
latestRightRowToJoin(rightRowsSorted, leftTime);
+ if (rightRow.isPresent()) {
+ if (joinCondition.apply(leftRow,
rightRow.get())) {
+ outRow.replace(leftRow,
rightRow.get());
+ collector.collect(outRow);
+ }
+ }
+ leftIterator.remove();
+ } else {
+ lastUnprocessedTime =
Math.min(lastUnprocessedTime, leftTime);
+ }
+ }
+
+ cleanupState(timerTimestamp, rightRowsSorted);
+ return lastUnprocessedTime;
+ }
+
+ /**
+ * Removes all right entries older then the watermark, except the
latest one. For example with:
+ * rightState = [1, 5, 9]
+ * and
+ * watermark = 6
+ * we can not remove "5" from rightState, because left elements with
rowtime of 7 or 8 could
+ * be joined with it later
+ */
+ private void cleanupState(long timerTimestamp, List<BaseRow>
rightRowsSorted) throws Exception {
+ int i = 0;
+ int indexToKeep = firstIndexToKeep(timerTimestamp,
rightRowsSorted);
+ while (i < indexToKeep) {
+ long rightTime = getRightTime(rightRowsSorted.get(i));
+ rightState.remove(rightTime);
+ i += 1;
+ }
+ }
+
+ /**
+ * The method to be called when a cleanup timer fires.
+ *
+ * @param time The timestamp of the fired timer.
+ */
+ @Override
+ public void cleanupState(long time) {
+ leftState.clear();
+ rightState.clear();
+ }
+
+ private int firstIndexToKeep(long timerTimestamp, List<BaseRow>
rightRowsSorted) {
+ int firstIndexNewerThenTimer =
+ indexOfFirstElementNewerThanTimer(timerTimestamp,
rightRowsSorted);
+
+ if (firstIndexNewerThenTimer < 0) {
+ return rightRowsSorted.size() - 1;
+ }
+ else {
+ return firstIndexNewerThenTimer - 1;
+ }
+ }
+
+ private int indexOfFirstElementNewerThanTimer(long timerTimestamp,
List<BaseRow> list) {
+ ListIterator<BaseRow> iter = list.listIterator();
+ while (iter.hasNext()) {
+ if (getRightTime(iter.next()) > timerTimestamp) {
+ return iter.previousIndex();
+ }
+ }
+ return -1;
+ }
+
+ /**
+ * Binary search {@code rightRowsSorted} to find the latest right row
to join with {@code leftTime}.
+ * Latest means a right row with largest time that is still smaller or
equal to {@code leftTime}.
+ *
+ * @return found element or {@code Optional.empty} If such row was not
found (either {@code rightRowsSorted}
+ * is empty or all {@code rightRowsSorted} are are newer).
+ */
+ private Optional<BaseRow> latestRightRowToJoin(List<BaseRow>
rightRowsSorted, long leftTime) {
+ return latestRightRowToJoin(rightRowsSorted, 0,
rightRowsSorted.size() - 1, leftTime);
+ }
+
+ private Optional<BaseRow> latestRightRowToJoin(
+ List<BaseRow> rightRowsSorted,
+ int low,
+ int high,
+ long leftTime) {
+ if (low > high) {
+ // exact value not found, we are returning largest from
the values smaller then leftTime
+ if (low - 1 < 0) {
+ return Optional.empty();
+ }
+ else {
+ return Optional.of(rightRowsSorted.get(low -
1));
+ }
+ } else {
+ int mid = (low + high) >>> 1;
+ BaseRow midRow = rightRowsSorted.get(mid);
+ long midTime = getRightTime(midRow);
+ int cmp = Long.compare(midTime, leftTime);
+ if (cmp < 0) {
+ return latestRightRowToJoin(rightRowsSorted,
mid + 1, high, leftTime);
+ }
+ else if (cmp > 0) {
+ return latestRightRowToJoin(rightRowsSorted,
low, mid - 1, leftTime);
+ }
+ else {
+ return Optional.of(midRow);
+ }
+ }
+ }
+
+ private void registerSmallestTimer(long timestamp) throws IOException {
+ Long currentRegisteredTimer = registeredTimer.value();
+ if (currentRegisteredTimer == null) {
+ registerTimer(timestamp);
+ } else if (currentRegisteredTimer > timestamp) {
+
timerService.deleteEventTimeTimer(VoidNamespace.INSTANCE,
currentRegisteredTimer);
+ registerTimer(timestamp);
+ }
+ }
+
+ private void registerTimer(long timestamp) throws IOException {
+ registeredTimer.update(timestamp);
+ timerService.registerEventTimeTimer(VoidNamespace.INSTANCE,
timestamp);
+ }
+
+ private List<BaseRow> getRightRowSorted(RowtimeComparator
rowtimeComparator) throws Exception {
+ List<BaseRow> rightRows = new ArrayList<>();
+ for (BaseRow row : rightState.values()) {
+ rightRows.add(row);
+ }
+ rightRows.sort(rowtimeComparator);
+ return rightRows;
+ }
+
+ private long getNextLeftIndex() throws IOException {
+ Long index = nextLeftIndex.value();
+ if (index == null) {
+ index = 0L;
+ }
+ nextLeftIndex.update(index + 1);
+ return index;
+ }
+
+ private long getLeftTime(BaseRow leftRow) {
+ return leftRow.getLong(leftTimeAttribute);
+ }
+
+ private long getRightTime(BaseRow rightRow) {
+ return rightRow.getLong(rightTimeAttribute);
+ }
+
+ private void checkNotRetraction(BaseRow row) {
+ if (BaseRowUtil.isRetractMsg(row)) {
+ String className = getClass().getSimpleName();
+ throw new IllegalStateException(
+ "Retractions are not supported by " + className
+
+ ". If this can happen it should be
validated during planning!");
+ }
+ }
+
+ //
------------------------------------------------------------------------------------------
+
+ private static class RowtimeComparator implements Comparator<BaseRow>,
Serializable {
+
+ private static final long serialVersionUID =
8160134014590716914L;
+
+ private final int timeAttribute;
+
+ private RowtimeComparator(int timeAttribute) {
+ this.timeAttribute = timeAttribute;
+ }
+
+ @Override
+ public int compare(BaseRow o1, BaseRow o2) {
+ long o1Time = o1.getLong(timeAttribute);
+ long o2Time = o2.getLong(timeAttribute);
+ return Long.compare(o1Time, o2Time);
+ }
+ }
+}