This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d7ea117  [FLINK-20624][table-planner-blink] Refactor StreamExecJoinRule, StreamExecIntervalJoinRule and StreamExecTemporalJoinRule
d7ea117 is described below

commit d7ea11733f8283ffdc629d7e5fcdc3a5c34c9d37
Author: Jerry Wang <[email protected]>
AuthorDate: Tue Dec 22 11:34:13 2020 +0800

    [FLINK-20624][table-planner-blink] Refactor StreamExecJoinRule, StreamExecIntervalJoinRule and StreamExecTemporalJoinRule
    
    This closes #14404
---
 .../stream/StreamExecIntervalJoinRule.scala        | 81 ++++--------------
 .../rules/physical/stream/StreamExecJoinRule.scala | 64 +++-----------
 .../physical/stream/StreamExecJoinRuleBase.scala   | 99 ++++++++++++++++++++++
 .../stream/StreamExecTemporalJoinRule.scala        | 68 ++++-----------
 4 files changed, 144 insertions(+), 168 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecIntervalJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecIntervalJoinRule.scala
index a52291d..632c0d4 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecIntervalJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecIntervalJoinRule.scala
@@ -18,18 +18,12 @@
 
 package org.apache.flink.table.planner.plan.rules.physical.stream
 
-import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
-import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
-import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecIntervalJoin
-import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, IntervalJoinUtil}
-
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
 import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.convert.ConverterRule
-
-import java.util
 
 import scala.collection.JavaConversions._
 
@@ -38,11 +32,7 @@ import scala.collection.JavaConversions._
   * to [[StreamExecIntervalJoin]].
   */
 class StreamExecIntervalJoinRule
-  extends ConverterRule(
-    classOf[FlinkLogicalJoin],
-    FlinkConventions.LOGICAL,
-    FlinkConventions.STREAM_PHYSICAL,
-    "StreamExecIntervalJoinRule") {
+  extends StreamExecJoinRuleBase("StreamExecIntervalJoinRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val join: FlinkLogicalJoin = call.rel(0)
@@ -53,13 +43,7 @@ class StreamExecIntervalJoinRule
       return false
     }
 
-    val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(join)
-    val (windowBounds, _) = IntervalJoinUtil.extractWindowBoundsFromPredicate(
-      join.getCondition,
-      join.getLeft.getRowType.getFieldCount,
-      joinRowType,
-      join.getCluster.getRexBuilder,
-      tableConfig)
+    val (windowBounds, _) = extractWindowBounds(join)
 
     if (windowBounds.isDefined) {
       if (windowBounds.get.isEventTime) {
@@ -76,55 +60,22 @@ class StreamExecIntervalJoinRule
     }
   }
 
-  override def convert(rel: RelNode): RelNode = {
-    val join: FlinkLogicalJoin = rel.asInstanceOf[FlinkLogicalJoin]
-    val joinRowType = join.getRowType
-    val left = join.getLeft
-    val right = join.getRight
-
-    def toHashTraitByColumns(
-        columns: util.Collection[_ <: Number],
-        inputTraitSet: RelTraitSet): RelTraitSet = {
-      val distribution = if (columns.size() == 0) {
-        FlinkRelDistribution.SINGLETON
-      } else {
-        FlinkRelDistribution.hash(columns)
-      }
-      inputTraitSet
-        .replace(FlinkConventions.STREAM_PHYSICAL)
-        .replace(distribution)
-    }
-
-    val joinInfo = join.analyzeCondition
-    val (leftRequiredTrait, rightRequiredTrait) = (
-      toHashTraitByColumns(joinInfo.leftKeys, left.getTraitSet),
-      toHashTraitByColumns(joinInfo.rightKeys, right.getTraitSet))
-
-    val newLeft = RelOptRule.convert(left, leftRequiredTrait)
-    val newRight = RelOptRule.convert(right, rightRequiredTrait)
-    val providedTraitSet = join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
-
-    val tableConfig = rel
-      .getCluster
-      .getPlanner
-      .getContext
-      .unwrap(classOf[FlinkContext])
-      .getTableConfig
-    val (windowBounds, remainCondition) = IntervalJoinUtil.extractWindowBoundsFromPredicate(
-      join.getCondition,
-      left.getRowType.getFieldCount,
-      joinRowType,
-      join.getCluster.getRexBuilder,
-      tableConfig)
-
+  override protected def transform(
+      join: FlinkLogicalJoin,
+      leftInput: FlinkRelNode,
+      leftConversion: RelNode => RelNode,
+      rightInput: FlinkRelNode,
+      rightConversion: RelNode => RelNode,
+      providedTraitSet: RelTraitSet): FlinkRelNode = {
+    val (windowBounds, remainCondition) = extractWindowBounds(join)
     new StreamExecIntervalJoin(
-      rel.getCluster,
+      join.getCluster,
       providedTraitSet,
-      newLeft,
-      newRight,
+      leftConversion(leftInput),
+      rightConversion(rightInput),
       join.getCondition,
       join.getJoinType,
-      joinRowType,
+      join.getRowType,
       windowBounds.get.isEventTime,
       windowBounds.get.leftLowerBound,
       windowBounds.get.leftUpperBound,
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRule.scala
index b4467ca..9d397ac 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRule.scala
@@ -19,16 +19,13 @@
 package org.apache.flink.table.planner.plan.rules.physical.stream
 
 import org.apache.flink.table.api.TableException
-import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
-import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
-import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
 import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalJoin, FlinkLogicalRel, FlinkLogicalSnapshot}
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecJoin
 import org.apache.flink.table.planner.plan.utils.{IntervalJoinUtil, TemporalJoinUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
 import org.apache.calcite.rel.RelNode
-import java.util
 
 import scala.collection.JavaConversions._
 
@@ -37,11 +34,7 @@ import scala.collection.JavaConversions._
   * to [[StreamExecJoin]].
   */
 class StreamExecJoinRule
-  extends RelOptRule(
-    operand(classOf[FlinkLogicalJoin],
-      operand(classOf[FlinkLogicalRel], any()),
-      operand(classOf[FlinkLogicalRel], any())),
-    "StreamExecJoinRule") {
+  extends StreamExecJoinRuleBase("StreamExecJoinRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val join: FlinkLogicalJoin = call.rel(0)
@@ -51,7 +44,6 @@ class StreamExecJoinRule
     }
     val left: FlinkLogicalRel = call.rel(1).asInstanceOf[FlinkLogicalRel]
     val right: FlinkLogicalRel = call.rel(2).asInstanceOf[FlinkLogicalRel]
-    val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
     val joinRowType = join.getRowType
 
     if (left.isInstanceOf[FlinkLogicalSnapshot]) {
@@ -65,13 +57,7 @@ class StreamExecJoinRule
       return false
     }
 
-    val (windowBounds, remainingPreds) = IntervalJoinUtil.extractWindowBoundsFromPredicate(
-      join.getCondition,
-      join.getLeft.getRowType.getFieldCount,
-      joinRowType,
-      join.getCluster.getRexBuilder,
-      tableConfig)
-
+    val (windowBounds, remainingPreds) = extractWindowBounds(join)
     if (windowBounds.isDefined) {
       return false
     }
@@ -95,42 +81,20 @@ class StreamExecJoinRule
     !remainingPredsAccessTime
   }
 
-  override def onMatch(call: RelOptRuleCall): Unit = {
-    val join: FlinkLogicalJoin = call.rel(0)
-    val left = join.getLeft
-    val right = join.getRight
-
-    def toHashTraitByColumns(
-        columns: util.Collection[_ <: Number],
-        inputTraitSets: RelTraitSet): RelTraitSet = {
-      val distribution = if (columns.isEmpty) {
-        FlinkRelDistribution.SINGLETON
-      } else {
-        FlinkRelDistribution.hash(columns)
-      }
-      inputTraitSets
-        .replace(FlinkConventions.STREAM_PHYSICAL)
-        .replace(distribution)
-    }
-
-    val joinInfo = join.analyzeCondition()
-    val (leftRequiredTrait, rightRequiredTrait) = (
-      toHashTraitByColumns(joinInfo.leftKeys, left.getTraitSet),
-      toHashTraitByColumns(joinInfo.rightKeys, right.getTraitSet))
-
-    val providedTraitSet = join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
-
-    val newLeft: RelNode = RelOptRule.convert(left, leftRequiredTrait)
-    val newRight: RelNode = RelOptRule.convert(right, rightRequiredTrait)
-
-    val newJoin = new StreamExecJoin(
+  override protected def transform(
+      join: FlinkLogicalJoin,
+      leftInput: FlinkRelNode,
+      leftConversion: RelNode => RelNode,
+      rightInput: FlinkRelNode,
+      rightConversion: RelNode => RelNode,
+      providedTraitSet: RelTraitSet): FlinkRelNode = {
+    new StreamExecJoin(
       join.getCluster,
       providedTraitSet,
-      newLeft,
-      newRight,
+      leftConversion(leftInput),
+      rightConversion(rightInput),
       join.getCondition,
       join.getJoinType)
-    call.transformTo(newJoin)
   }
 }
 
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRuleBase.scala
new file mode 100644
index 0000000..0336334
--- /dev/null
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecJoinRuleBase.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.stream
+
+import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
+import org.apache.flink.table.planner.plan.nodes.{FlinkConventions, FlinkRelNode}
+import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalJoin, FlinkLogicalRel, FlinkLogicalSnapshot}
+import org.apache.flink.table.planner.plan.utils.IntervalJoinUtil.WindowBounds
+import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, IntervalJoinUtil}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rex.RexNode
+
+import java.util
+
+/**
+ * Base implementation for rules that match stream-stream joins, including
+ * regular stream joins, interval joins and temporal joins.
+ */
+abstract class StreamExecJoinRuleBase(description: String)
+  extends RelOptRule(
+    operand(classOf[FlinkLogicalJoin],
+      operand(classOf[FlinkLogicalRel], any()),
+      operand(classOf[FlinkLogicalRel], any())),
+    description) {
+
+  protected def extractWindowBounds(join: FlinkLogicalJoin):
+    (Option[WindowBounds], Option[RexNode]) = {
+    val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(join)
+    IntervalJoinUtil.extractWindowBoundsFromPredicate(
+      join.getCondition,
+      join.getLeft.getRowType.getFieldCount,
+      join.getRowType,
+      join.getCluster.getRexBuilder,
+      tableConfig)
+  }
+
+  override def onMatch(call: RelOptRuleCall): Unit = {
+    val join = call.rel[FlinkLogicalJoin](0)
+    val left = call.rel[FlinkLogicalRel](1)
+    val right = call.rel[FlinkLogicalRel](2)
+
+    def toHashTraitByColumns(
+        columns: util.Collection[_ <: Number],
+        inputTraitSet: RelTraitSet): RelTraitSet = {
+      val distribution = if (columns.size() == 0) {
+        FlinkRelDistribution.SINGLETON
+      } else {
+        FlinkRelDistribution.hash(columns)
+      }
+      inputTraitSet
+        .replace(FlinkConventions.STREAM_PHYSICAL)
+        .replace(distribution)
+    }
+
+    def convertInput(input: RelNode, columns: util.Collection[_ <: Number]): RelNode = {
+      val requiredTraitSet = toHashTraitByColumns(columns, input.getTraitSet)
+      RelOptRule.convert(input, requiredTraitSet)
+    }
+
+    val joinInfo = join.analyzeCondition
+    val providedTraitSet: RelTraitSet = join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
+
+    val leftConversion: RelNode => RelNode = leftInput => {
+      convertInput(leftInput, joinInfo.leftKeys)
+    }
+    val rightConversion: RelNode => RelNode = rightInput => {
+      convertInput(rightInput, joinInfo.rightKeys)
+    }
+
+    val newJoin = transform(join, left, leftConversion, right, rightConversion, providedTraitSet)
+    call.transformTo(newJoin)
+  }
+
+  protected def transform(
+      join: FlinkLogicalJoin,
+      leftInput: FlinkRelNode,
+      leftConversion: RelNode => RelNode,
+      rightInput: FlinkRelNode,
+      rightConversion: RelNode => RelNode,
+      providedTraitSet: RelTraitSet): FlinkRelNode
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
index 9a5f1f2..2803289 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecTemporalJoinRule.scala
@@ -18,14 +18,10 @@
 
 package org.apache.flink.table.planner.plan.rules.physical.stream
 
-import java.util
-
-import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
-import org.apache.flink.table.planner.plan.nodes.FlinkConventions
+import org.apache.flink.table.planner.plan.nodes.FlinkRelNode
 import org.apache.flink.table.planner.plan.nodes.logical._
 import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecTemporalJoin
-import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, IntervalJoinUtil, TemporalJoinUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
+import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil
 import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelTraitSet}
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.core.JoinRelType
@@ -36,11 +32,7 @@ import org.apache.flink.util.Preconditions.checkState
  * the temporal join node is a [[FlinkLogicalJoin]] which contains [[TemporalJoinCondition]].
  */
 class StreamExecTemporalJoinRule
-  extends RelOptRule(
-    operand(classOf[FlinkLogicalJoin],
-      operand(classOf[FlinkLogicalRel], any()),
-      operand(classOf[FlinkLogicalRel], any())),
-    "StreamExecTemporalJoinRule") {
+  extends StreamExecJoinRuleBase("StreamExecTemporalJoinRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val join = call.rel[FlinkLogicalJoin](0)
@@ -60,59 +52,29 @@ class StreamExecTemporalJoinRule
   }
 
   private def matchesTemporalTableFunctionJoin(join: FlinkLogicalJoin): Boolean = {
-    val joinInfo = join.analyzeCondition
-    val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(join)
-    val (windowBounds, _) = IntervalJoinUtil.extractWindowBoundsFromPredicate(
-      joinInfo.getRemaining(join.getCluster.getRexBuilder),
-      join.getLeft.getRowType.getFieldCount,
-      join.getRowType,
-      join.getCluster.getRexBuilder,
-      tableConfig)
+    val (windowBounds, _) = extractWindowBounds(join)
     windowBounds.isEmpty && join.getJoinType == JoinRelType.INNER
   }
 
-  override def onMatch(call: RelOptRuleCall): Unit = {
-    val join = call.rel[FlinkLogicalJoin](0)
-    val left = call.rel[FlinkLogicalRel](1)
-    val right = call.rel[FlinkLogicalRel](2)
-
-    val newRight = right match {
+  override protected def transform(
+      join: FlinkLogicalJoin,
+      leftInput: FlinkRelNode,
+      leftConversion: RelNode => RelNode,
+      rightInput: FlinkRelNode,
+      rightConversion: RelNode => RelNode,
+      providedTraitSet: RelTraitSet): FlinkRelNode = {
+    val newRight = rightInput match {
       case snapshot: FlinkLogicalSnapshot =>
         snapshot.getInput
       case rel: FlinkLogicalRel => rel
     }
-
-    def toHashTraitByColumns(
-        columns: util.Collection[_ <: Number],
-        inputTraitSets: RelTraitSet) = {
-      val distribution = if (columns.size() == 0) {
-        FlinkRelDistribution.SINGLETON
-      } else {
-        FlinkRelDistribution.hash(columns)
-      }
-      inputTraitSets.
-        replace(FlinkConventions.STREAM_PHYSICAL).
-        replace(distribution)
-    }
-
-    val joinInfo = join.analyzeCondition
-    val (leftRequiredTrait, rightRequiredTrait) = (
-      toHashTraitByColumns(joinInfo.leftKeys, left.getTraitSet),
-      toHashTraitByColumns(joinInfo.rightKeys, newRight.getTraitSet))
-
-    val convLeft: RelNode = RelOptRule.convert(left, leftRequiredTrait)
-    val convRight: RelNode = RelOptRule.convert(newRight, rightRequiredTrait)
-    val providedTraitSet: RelTraitSet = join.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
-
-    val temporalJoin = new StreamExecTemporalJoin(
+    new StreamExecTemporalJoin(
       join.getCluster,
       providedTraitSet,
-      convLeft,
-      convRight,
+      leftConversion(leftInput),
+      rightConversion(newRight),
       join.getCondition,
       join.getJoinType)
-
-    call.transformTo(temporalJoin)
   }
 }
 

Reply via email to