leonardBang commented on a change in pull request #13299:
URL: https://github.com/apache/flink/pull/13299#discussion_r502886842



##########
File path: 
flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableRule.scala
##########
@@ -74,23 +178,150 @@ abstract class 
LogicalCorrelateToJoinFromTemporalTableRule(
     val rel = builder.build()
     call.transformTo(rel)
   }
+}
+
+
+/**
+ * General temporal table join rule to rewrite the original Correlate into a 
Join.
+ */
+abstract class LogicalCorrelateToJoinFromGeneralTemporalTableRule(
+    operand: RelOptRuleOperand,
+    description: String)
+  extends LogicalCorrelateToJoinFromTemporalTableRule(operand, description) {
+
+  protected def extractRightTimeInputRef(
+      leftInput: RelNode,
+      snapshot: LogicalSnapshot): Option[RexNode] = {
+    val rightFields = snapshot.getRowType.getFieldList.asScala
+    val timeAttributeFields = rightFields.filter(
+      f => f.getType.isInstanceOf[TimeIndicatorRelDataType])
+    val rexBuilder = snapshot.getCluster.getRexBuilder
+
+    if (timeAttributeFields != null && timeAttributeFields.length == 1) {
+      val leftFieldCnt = leftInput.getRowType.getFieldCount
+      val timeColIndex = leftFieldCnt + 
rightFields.indexOf(timeAttributeFields.get(0))
+      val timeColDataType = timeAttributeFields.get(0).getType
+      Some(rexBuilder.makeInputRef(timeColDataType, timeColIndex))
+    } else {
+      None
+    }
+  }
+
+  protected def extractSnapshotTimeInputRef(
+      leftInput: RelNode,
+      snapshot: LogicalSnapshot): Option[RexInputRef] = {
+    val leftRowType = leftInput.getRowType
+    val leftFields = leftRowType.getFieldList
+    val periodField = snapshot.getPeriod.asInstanceOf[RexFieldAccess].getField
+    if (leftFields.contains(periodField)) {
+      val index = leftRowType.getFieldList.indexOf(periodField)
+      Some(RexInputRef.of(index, leftRowType))
+    } else {
+      None
+    }
+  }
+
+  override def onMatch(call: RelOptRuleCall): Unit = {
+    val correlate: LogicalCorrelate = call.rel(0)
+    val leftInput: RelNode = call.rel(1)
+    val filterCondition = getFilterCondition(call)
+    val snapshot = getLogicalSnapshot(call)
+
+    val leftRowType = leftInput.getRowType
+    val joinCondition = filterCondition.accept(new RexShuttle() {
+      // change correlate variable expression to normal RexInputRef (which is 
from left side)
+      override def visitFieldAccess(fieldAccess: RexFieldAccess): RexNode = {
+        fieldAccess.getReferenceExpr match {
+          case corVar: RexCorrelVariable =>
+            require(correlate.getCorrelationId.equals(corVar.id))
+            val index = leftRowType.getFieldList.indexOf(fieldAccess.getField)
+            RexInputRef.of(index, leftRowType)
+          case _ => super.visitFieldAccess(fieldAccess)
+        }
+      }
 
+      // update the field index from right side
+      override def visitInputRef(inputRef: RexInputRef): RexNode = {
+        val rightIndex = leftRowType.getFieldCount + inputRef.getIndex
+        new RexInputRef(rightIndex, inputRef.getType)
+      }
+    })
+
+    validateSnapshotInCorrelate(snapshot, correlate)
+
+    val (leftJoinKey, rightJoinKey) = {
+      val rexBuilder = correlate.getCluster.getRexBuilder
+      val relBuilder = call.builder()
+      relBuilder.push(leftInput)
+      relBuilder.push(snapshot)
+      val rewriteJoin = relBuilder.join(correlate.getJoinType, 
joinCondition).build()
+      val joinInfo = rewriteJoin.asInstanceOf[LogicalJoin].analyzeCondition()
+      val leftJoinKey = joinInfo.leftKeys.map(i => 
rexBuilder.makeInputRef(leftInput, i))
+      val rightJoinKey = joinInfo.rightKeys.map(i => {
+        val leftFieldCnt = leftInput.getRowType.getFieldCount
+        val leftKeyType = snapshot.getRowType.getFieldList.get(i).getType
+        rexBuilder.makeInputRef(leftKeyType, leftFieldCnt + i)
+      })
+      (leftJoinKey, rightJoinKey)
+    }
+
+    val snapshotTimeInputRef = extractSnapshotTimeInputRef(leftInput, snapshot)
+      .getOrElse(throw new ValidationException("Temporal Table Join requires 
time attribute in the " +
+        s"left table, but no row time attribute found."))
+
+    val rexBuilder = correlate.getCluster.getRexBuilder
+    val temporalCondition = if(isRowTimeTemporalTableJoin(snapshot)) {
+      val rightTimeInputRef = extractRightTimeInputRef(leftInput, snapshot)
+      if (rightTimeInputRef.isEmpty || 
!isRowtimeIndicatorType(rightTimeInputRef.get.getType)) {
+          throw new ValidationException("Event-Time Temporal Table Join 
requires both" +
+            s" primary key and row time attribute in versioned table," +

Review comment:
       The `primary key` will be checked in 
`TemporalJoinRewriteWithUniqueKeyRule`, so there is no need to mention it in this error message.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to