This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.13 by this push:
     new 3b4f230  [FLINK-22901][table-planner-blink] Introduce getUpsertKeys in 
FlinkRelMetadataQuery
3b4f230 is described below

commit 3b4f2301c41e9d67af03bdae5b459575b5224ff8
Author: Jingsong Lee <[email protected]>
AuthorDate: Fri Sep 10 09:35:00 2021 +0800

    [FLINK-22901][table-planner-blink] Introduce getUpsertKeys in 
FlinkRelMetadataQuery
    
    This closes #17207
---
 .../table/planner/plan/metadata/FlinkMetadata.java |  16 +
 .../plan/metadata/FlinkRelMetadataQuery.java       |  52 +++
 .../metadata/FlinkDefaultRelMetadataProvider.scala |   1 +
 .../plan/metadata/FlinkRelMdUniqueKeys.scala       |  85 ++---
 .../plan/metadata/FlinkRelMdUpsertKeys.scala       | 301 +++++++++++++++++
 .../stream/StreamPhysicalOverAggregate.scala       |   2 +-
 .../StreamCommonSubGraphBasedOptimizer.scala       |   1 +
 .../planner/plan/schema/IntermediateRelTable.scala |   7 +-
 .../plan/metadata/FlinkRelMdHandlerTestBase.scala  | 128 +++++--
 .../plan/metadata/FlinkRelMdUpsertKeysTest.scala   | 372 +++++++++++++++++++++
 10 files changed, 893 insertions(+), 72 deletions(-)

diff --git 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkMetadata.java
 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkMetadata.java
index f3750e7..d7cf6cf 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkMetadata.java
+++ 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkMetadata.java
@@ -32,6 +32,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.util.ImmutableBitSet;
 
 import java.lang.reflect.Method;
+import java.util.Set;
 
 /** Contains the interfaces for several specified metadata of flink. */
 public abstract class FlinkMetadata {
@@ -238,4 +239,19 @@ public abstract class FlinkMetadata {
             RelWindowProperties getWindowProperties(RelNode r, 
RelMetadataQuery mq);
         }
     }
+
+    /** Metadata about which combinations of columns are upsert identifiers. */
+    public interface UpsertKeys extends Metadata {
+        Method METHOD = Types.lookupMethod(UpsertKeys.class, "getUpsertKeys");
+
+        MetadataDef<UpsertKeys> DEF =
+                MetadataDef.of(UpsertKeys.class, UpsertKeys.Handler.class, 
METHOD);
+
+        Set<ImmutableBitSet> getUpsertKeys();
+
+        /** Handler API. */
+        interface Handler extends MetadataHandler<UpsertKeys> {
+            Set<ImmutableBitSet> getUpsertKeys(RelNode r, RelMetadataQuery mq);
+        }
+    }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMetadataQuery.java
 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMetadataQuery.java
index b5e37af..5c1cf8d 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMetadataQuery.java
+++ 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/metadata/FlinkRelMetadataQuery.java
@@ -25,10 +25,14 @@ import 
org.apache.flink.table.planner.plan.trait.RelWindowProperties;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Exchange;
 import org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.util.ImmutableBitSet;
 
+import java.util.Arrays;
+import java.util.Set;
+
 /**
  * A RelMetadataQuery that defines extended metadata handler in Flink, e.g 
ColumnInterval,
  * ColumnNullCount.
@@ -45,6 +49,7 @@ public class FlinkRelMetadataQuery extends RelMetadataQuery {
     private FlinkMetadata.FlinkDistribution.Handler distributionHandler;
     private FlinkMetadata.ModifiedMonotonicity.Handler 
modifiedMonotonicityHandler;
     private FlinkMetadata.WindowProperties.Handler windowPropertiesHandler;
+    private FlinkMetadata.UpsertKeys.Handler upsertKeysHandler;
 
     /**
      * Returns an instance of FlinkRelMetadataQuery. It ensures that cycles do 
not occur while
@@ -79,6 +84,7 @@ public class FlinkRelMetadataQuery extends RelMetadataQuery {
         this.distributionHandler = HANDLERS.distributionHandler;
         this.modifiedMonotonicityHandler = 
HANDLERS.modifiedMonotonicityHandler;
         this.windowPropertiesHandler = HANDLERS.windowPropertiesHandler;
+        this.upsertKeysHandler = HANDLERS.upsertKeysHandler;
     }
 
     /** Extended handlers. */
@@ -99,6 +105,8 @@ public class FlinkRelMetadataQuery extends RelMetadataQuery {
                 
initialHandler(FlinkMetadata.ModifiedMonotonicity.Handler.class);
         private FlinkMetadata.WindowProperties.Handler windowPropertiesHandler 
=
                 initialHandler(FlinkMetadata.WindowProperties.Handler.class);
+        private FlinkMetadata.UpsertKeys.Handler upsertKeysHandler =
+                initialHandler(FlinkMetadata.UpsertKeys.Handler.class);
     }
 
     /**
@@ -256,4 +264,48 @@ public class FlinkRelMetadataQuery extends 
RelMetadataQuery {
             }
         }
     }
+
+    /**
+     * Determines the set of upsert minimal keys for this expression. A key is 
represented as an
+     * {@link org.apache.calcite.util.ImmutableBitSet}, where each bit 
position represents a 0-based
+     * output column ordinal.
+     *
+     * <p>Different from the unique keys: In distributed streaming computing, 
one record may be
+     * divided into RowKind.UPDATE_BEFORE and RowKind.UPDATE_AFTER. If a key 
changing join is
+     * connected downstream, the two records will be divided into different 
tasks, resulting in
+     * disorder. In this case, the downstream cannot rely on the order of the 
original key. So in
+     * this case, it has unique keys in the traditional sense, but it doesn't 
have upsert keys.
+     *
+     * @return set of keys, or null if this information cannot be determined 
(whereas empty set
+     *     indicates definitely no keys at all)
+     */
+    public Set<ImmutableBitSet> getUpsertKeys(RelNode rel) {
+        for (; ; ) {
+            try {
+                return upsertKeysHandler.getUpsertKeys(rel, this);
+            } catch (JaninoRelMetadataProvider.NoHandler e) {
+                upsertKeysHandler = revise(e.relClass, 
FlinkMetadata.UpsertKeys.DEF);
+            }
+        }
+    }
+
+    /**
+     * Determines the set of upsert minimal keys in a single key group range, 
which means can ignore
+     * exchange by partition keys.
+     *
+     * <p>Some optimizations can rely on this ability to do upsert in a single 
key group range.
+     */
+    public Set<ImmutableBitSet> getUpsertKeysInKeyGroupRange(RelNode rel, 
int[] partitionKeys) {
+        if (rel instanceof Exchange) {
+            Exchange exchange = (Exchange) rel;
+            if (Arrays.equals(
+                    exchange.getDistribution().getKeys().stream()
+                            .mapToInt(Integer::intValue)
+                            .toArray(),
+                    partitionKeys)) {
+                rel = exchange.getInput();
+            }
+        }
+        return getUpsertKeys(rel);
+    }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala
index f9d864f..75ccb2d 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkDefaultRelMetadataProvider.scala
@@ -40,6 +40,7 @@ object FlinkDefaultRelMetadataProvider {
       FlinkRelMdPopulationSize.SOURCE,
       FlinkRelMdColumnUniqueness.SOURCE,
       FlinkRelMdUniqueKeys.SOURCE,
+      FlinkRelMdUpsertKeys.SOURCE,
       FlinkRelMdUniqueGroups.SOURCE,
       FlinkRelMdModifiedMonotonicity.SOURCE,
       RelMdColumnOrigins.SOURCE,
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
index 8ebb7fa..4a76839 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
@@ -35,7 +35,7 @@ import com.google.common.collect.ImmutableSet
 import org.apache.calcite.plan.RelOptTable
 import org.apache.calcite.plan.hep.HepRelVertex
 import org.apache.calcite.plan.volcano.RelSubset
-import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
 import org.apache.calcite.rel.core._
 import org.apache.calcite.rel.metadata._
 import org.apache.calcite.rel.{RelNode, SingleRel}
@@ -45,6 +45,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.util.{Bug, BuiltInMethod, ImmutableBitSet, Util}
 
 import java.util
+import java.util.Set
 
 import scala.collection.JavaConversions._
 
@@ -117,6 +118,18 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       input: RelNode,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
+    getProjectUniqueKeys(
+      projects,
+      input.getCluster.getTypeFactory,
+      () => mq.getUniqueKeys(input, ignoreNulls),
+      ignoreNulls)
+  }
+
+  def getProjectUniqueKeys(
+      projects: JList[RexNode],
+      typeFactory: RelDataTypeFactory,
+      getInputUniqueKeys :() => util.Set[ImmutableBitSet],
+      ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
     // LogicalProject maps a set of rows to a different set;
     // Without knowledge of the mapping function(whether it
     // preserves uniqueness), it is only safe to derive uniqueness
@@ -144,7 +157,6 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
             val castOperand = a.getOperands.get(0)
             castOperand match {
               case castRef: RexInputRef =>
-                val typeFactory = input.getCluster.getTypeFactory
                 val castType = 
typeFactory.createTypeWithNullability(projExpr.getType, true)
                 val origType = 
typeFactory.createTypeWithNullability(castOperand.getType, true)
                 if (castType == origType) {
@@ -165,7 +177,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       return projUniqueKeySet
     }
 
-    val childUniqueKeySet = mq.getUniqueKeys(input, ignoreNulls)
+    val childUniqueKeySet = getInputUniqueKeys()
     if (childUniqueKeySet != null) {
       // Now add to the projUniqueKeySet the child keys that are fully
       // projected.
@@ -206,6 +218,11 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       rel: Expand,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
+    getExpandUniqueKeys(rel, () => mq.getUniqueKeys(rel.getInput, ignoreNulls))
+  }
+
+  def getExpandUniqueKeys(
+      rel: Expand, getInputUniqueKeys :() => util.Set[ImmutableBitSet]): 
JSet[ImmutableBitSet] = {
     // mapping input column index to output index for non-null value columns
     val mapInputToOutput = new JHashMap[Int, Int]()
     (0 until rel.getRowType.getFieldCount).filter(_ != 
rel.expandIdIndex).foreach { column =>
@@ -219,7 +236,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       return null
     }
 
-    val inputUniqueKeys = mq.getUniqueKeys(rel.getInput, ignoreNulls)
+    val inputUniqueKeys = getInputUniqueKeys()
     if (inputUniqueKeys == null || inputUniqueKeys.isEmpty) {
       return inputUniqueKeys
     }
@@ -256,7 +273,10 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       rel: Rank,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
-    val inputUniqueKeys = mq.getUniqueKeys(rel.getInput, ignoreNulls)
+    getRankUniqueKeys(rel, mq.getUniqueKeys(rel.getInput, ignoreNulls))
+  }
+
+  def getRankUniqueKeys(rel: Rank, inputKeys: JSet[ImmutableBitSet]): 
JSet[ImmutableBitSet] = {
     val rankFunColumnIndex = 
RankUtil.getRankNumberColumnIndex(rel).getOrElse(-1)
     // for Rank node that can convert to Deduplicate, unique key is partition 
key
     val canConvertToDeduplicate: Boolean = {
@@ -264,7 +284,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       val isRowNumberType = rel.rankType == RankType.ROW_NUMBER
       val isLimit1 = rankRange match {
         case rankRange: ConstantRankRange =>
-          rankRange.getRankStart() == 1 && rankRange.getRankEnd() == 1
+          rankRange.getRankStart == 1 && rankRange.getRankEnd == 1
         case _ => false
       }
       isRowNumberType && isLimit1
@@ -276,7 +296,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       retSet
     }
     else if (rankFunColumnIndex < 0) {
-      inputUniqueKeys
+      inputKeys
     } else {
       val retSet = new JHashSet[ImmutableBitSet]
       rel.rankType match {
@@ -284,8 +304,8 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
           
retSet.add(rel.partitionKey.union(ImmutableBitSet.of(rankFunColumnIndex)))
         case _ => // do nothing
       }
-      if (inputUniqueKeys != null && inputUniqueKeys.nonEmpty) {
-        inputUniqueKeys.foreach {
+      if (inputKeys != null && inputKeys.nonEmpty) {
+        inputKeys.foreach {
           uniqueKey => retSet.add(uniqueKey)
         }
       }
@@ -323,7 +343,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       rel: Aggregate,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
-    getUniqueKeysOnAggregate(rel.getGroupSet.toArray, mq, ignoreNulls)
+    getUniqueKeysOnAggregate(rel.getGroupSet.toArray)
   }
 
   def getUniqueKeys(
@@ -331,7 +351,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
     if (rel.isFinal) {
-      getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
+      getUniqueKeysOnAggregate(rel.grouping)
     } else {
       null
     }
@@ -341,7 +361,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       rel: StreamPhysicalGroupAggregate,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
-    getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
+    getUniqueKeysOnAggregate(rel.grouping)
   }
 
   def getUniqueKeys(
@@ -353,13 +373,10 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       rel: StreamPhysicalGlobalGroupAggregate,
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): JSet[ImmutableBitSet] = {
-    getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls)
+    getUniqueKeysOnAggregate(rel.grouping)
   }
 
-  def getUniqueKeysOnAggregate(
-      grouping: Array[Int],
-      mq: RelMetadataQuery,
-      ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+  def getUniqueKeysOnAggregate(grouping: Array[Int]): 
util.Set[ImmutableBitSet] = {
     // group by keys form a unique key
     ImmutableSet.of(ImmutableBitSet.of(grouping.indices: _*))
   }
@@ -371,9 +388,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
     getUniqueKeysOnWindowAgg(
       rel.getRowType.getFieldCount,
       rel.getNamedProperties,
-      rel.getGroupSet.toArray,
-      mq,
-      ignoreNulls)
+      rel.getGroupSet.toArray)
   }
 
   def getUniqueKeys(
@@ -384,9 +399,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       getUniqueKeysOnWindowAgg(
         rel.getRowType.getFieldCount,
         rel.namedWindowProperties,
-        rel.grouping,
-        mq,
-        ignoreNulls)
+        rel.grouping)
     } else {
       null
     }
@@ -397,15 +410,13 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
       mq: RelMetadataQuery,
       ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
     getUniqueKeysOnWindowAgg(
-      rel.getRowType.getFieldCount, rel.namedWindowProperties, rel.grouping, 
mq, ignoreNulls)
+      rel.getRowType.getFieldCount, rel.namedWindowProperties, rel.grouping)
   }
 
-  private def getUniqueKeysOnWindowAgg(
+  def getUniqueKeysOnWindowAgg(
       fieldCount: Int,
       namedProperties: Seq[PlannerNamedWindowProperty],
-      grouping: Array[Int],
-      mq: RelMetadataQuery,
-      ignoreNulls: Boolean): util.Set[ImmutableBitSet] = {
+      grouping: Array[Int]): util.Set[ImmutableBitSet] = {
     if (namedProperties.nonEmpty) {
       val begin = fieldCount - namedProperties.size
       val end = fieldCount - 1
@@ -478,11 +489,10 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
     val leftUniqueKeys = mq.getUniqueKeys(left, ignoreNulls)
     val leftType = left.getRowType
     getJoinUniqueKeys(
-      join.joinInfo, join.joinType, leftType, leftUniqueKeys, null,
+      join.joinType, leftType, leftUniqueKeys, null,
       mq.areColumnsUnique(left, join.joinInfo.leftSet, ignoreNulls),
       // TODO get uniqueKeys from TableSchema of TableSource
-      null,
-      mq)
+      null)
   }
 
   private def getJoinUniqueKeys(
@@ -495,21 +505,18 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
     val leftUniqueKeys = mq.getUniqueKeys(left, ignoreNulls)
     val rightUniqueKeys = mq.getUniqueKeys(right, ignoreNulls)
     getJoinUniqueKeys(
-      joinInfo, joinRelType, left.getRowType, leftUniqueKeys, rightUniqueKeys,
+      joinRelType, left.getRowType, leftUniqueKeys, rightUniqueKeys,
       mq.areColumnsUnique(left, joinInfo.leftSet, ignoreNulls),
-      mq.areColumnsUnique(right, joinInfo.rightSet, ignoreNulls),
-      mq)
+      mq.areColumnsUnique(right, joinInfo.rightSet, ignoreNulls))
   }
 
-  private def getJoinUniqueKeys(
-      joinInfo: JoinInfo,
+  def getJoinUniqueKeys(
       joinRelType: JoinRelType,
       leftType: RelDataType,
       leftUniqueKeys: JSet[ImmutableBitSet],
       rightUniqueKeys: JSet[ImmutableBitSet],
       isLeftUnique: JBoolean,
-      isRightUnique: JBoolean,
-      mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+      isRightUnique: JBoolean): JSet[ImmutableBitSet] = {
 
     // first add the different combinations of concatenated unique keys
     // from the left and the right, adjusting the right hand side keys to
@@ -622,7 +629,7 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
 
 object FlinkRelMdUniqueKeys {
 
-  private val INSTANCE = new FlinkRelMdUniqueKeys
+  val INSTANCE = new FlinkRelMdUniqueKeys
 
   val SOURCE: RelMetadataProvider = 
ReflectiveRelMetadataProvider.reflectiveSource(
     BuiltInMethod.UNIQUE_KEYS.method, INSTANCE)
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeys.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeys.scala
new file mode 100644
index 0000000..2844582
--- /dev/null
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeys.scala
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.metadata
+
+import org.apache.flink.table.planner._
+import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.UpsertKeys
+import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, 
WatermarkAssigner, WindowAggregate}
+import 
org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalGroupAggregateBase,
 BatchPhysicalOverAggregate, BatchPhysicalWindowAggregateBase}
+import 
org.apache.flink.table.planner.plan.nodes.physical.common.CommonPhysicalLookupJoin
+import 
org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalChangelogNormalize,
 StreamPhysicalDeduplicate, StreamPhysicalDropUpdateBefore, 
StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, 
StreamPhysicalGroupWindowAggregate, StreamPhysicalIntervalJoin, 
StreamPhysicalLocalGroupAggregate, StreamPhysicalOverAggregate}
+import org.apache.flink.table.planner.plan.schema.IntermediateRelTable
+
+import com.google.common.collect.ImmutableSet
+import org.apache.calcite.plan.hep.HepRelVertex
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.rel.core.{Aggregate, Calc, Exchange, Filter, Join, 
JoinInfo, JoinRelType, Project, SetOp, Sort, TableScan, Window}
+import org.apache.calcite.rel.metadata._
+import org.apache.calcite.rel.{RelDistribution, RelNode, SingleRel}
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.util.{Bug, ImmutableBitSet, Util}
+
+import java.util
+
+import scala.collection.JavaConversions._
+
+/**
+ * FlinkRelMdUpsertKeys supplies a default implementation of 
[[FlinkRelMetadataQuery#getUpsertKeys]]
+ * for the standard logical algebra.
+ */
+class FlinkRelMdUpsertKeys private extends MetadataHandler[UpsertKeys] {
+
+  override def getDef: MetadataDef[UpsertKeys] = UpsertKeys.DEF
+
+  def getUpsertKeys(rel: TableScan, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    rel.getTable match {
+      case t: IntermediateRelTable => t.upsertKeys
+      case _ => mq.getUniqueKeys(rel)
+    }
+  }
+
+  def getUpsertKeys(rel: Project, mq: RelMetadataQuery): JSet[ImmutableBitSet] 
=
+    getProjectUpsertKeys(rel.getProjects, rel.getInput, mq)
+
+  def getUpsertKeys(rel: Filter, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
+    FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput)
+
+  def getUpsertKeys(calc: Calc, mq: RelMetadataQuery): JSet[ImmutableBitSet] = 
{
+    val projects = 
calc.getProgram.getProjectList.map(calc.getProgram.expandLocalRef)
+    getProjectUpsertKeys(projects, calc.getInput, mq)
+  }
+
+  private def getProjectUpsertKeys(
+      projects: JList[RexNode],
+      input: RelNode,
+      mq: RelMetadataQuery): JSet[ImmutableBitSet] =
+    FlinkRelMdUniqueKeys.INSTANCE.getProjectUniqueKeys(
+      projects,
+      input.getCluster.getTypeFactory,
+      () => FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(input),
+      ignoreNulls = false)
+
+  def getUpsertKeys(rel: Expand, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
+    FlinkRelMdUniqueKeys.INSTANCE.getExpandUniqueKeys(
+      rel, () => 
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput))
+
+  def getUpsertKeys(rel: Exchange, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    val keys = 
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput)
+    rel.getDistribution.getType match {
+      case RelDistribution.Type.HASH_DISTRIBUTED =>
+        filterKeys(keys, ImmutableBitSet.of(rel.getDistribution.getKeys))
+      case RelDistribution.Type.SINGLETON => keys
+      case t => throw new UnsupportedOperationException("Unsupported 
distribution type: " + t)
+    }
+  }
+
+  def getUpsertKeys(rel: Rank, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    val inputKeys = filterKeys(FlinkRelMetadataQuery.reuseOrCreate(mq)
+        .getUpsertKeys(rel.getInput), rel.partitionKey)
+    FlinkRelMdUniqueKeys.INSTANCE.getRankUniqueKeys(rel, inputKeys)
+  }
+
+  def getUpsertKeys(rel: Sort, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
+    filterKeys(
+      FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput),
+      ImmutableBitSet.of(rel.getCollation.getKeys))
+
+  def getUpsertKeys(
+      rel: StreamPhysicalDeduplicate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    
ImmutableSet.of(ImmutableBitSet.of(rel.getUniqueKeys.map(Integer.valueOf).toList))
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalChangelogNormalize, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    
ImmutableSet.of(ImmutableBitSet.of(rel.uniqueKeys.map(Integer.valueOf).toList))
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalDropUpdateBefore, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput)
+  }
+
+  def getUpsertKeys(
+      rel: Aggregate, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    
FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnAggregate(rel.getGroupSet.toArray)
+  }
+
+  def getUpsertKeys(
+      rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    if (rel.isFinal) {
+      FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnAggregate(rel.grouping)
+    } else {
+      null
+    }
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalGroupAggregate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnAggregate(rel.grouping)
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = null
+
+  def getUpsertKeys(
+      rel: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnAggregate(rel.grouping)
+  }
+
+  def getUpsertKeys(
+      rel: WindowAggregate, mq: RelMetadataQuery): util.Set[ImmutableBitSet] = 
{
+    FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnWindowAgg(
+      rel.getRowType.getFieldCount,
+      rel.getNamedProperties,
+      rel.getGroupSet.toArray)
+  }
+
+  def getUpsertKeys(
+      rel: BatchPhysicalWindowAggregateBase, mq: RelMetadataQuery): 
util.Set[ImmutableBitSet] = {
+    if (rel.isFinal) {
+      FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnWindowAgg(
+        rel.getRowType.getFieldCount,
+        rel.namedWindowProperties,
+        rel.grouping)
+    } else {
+      null
+    }
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalGroupWindowAggregate, mq: RelMetadataQuery): 
util.Set[ImmutableBitSet] = {
+    FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeysOnWindowAgg(
+      rel.getRowType.getFieldCount, rel.namedWindowProperties, rel.grouping)
+  }
+
+  def getUpsertKeys(
+      rel: Window, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    getUpsertKeysOnOver(rel, mq, rel.groups.map(_.keys): _*)
+  }
+
+  def getUpsertKeys(
+      rel: BatchPhysicalOverAggregate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    getUpsertKeysOnOver(rel, mq, ImmutableBitSet.of(rel.partitionKeyIndices: 
_*))
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalOverAggregate, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    getUpsertKeysOnOver(rel, mq, rel.logicWindow.groups.map(_.keys): _*)
+  }
+
+  private def getUpsertKeysOnOver(
+      rel: SingleRel,
+      mq: RelMetadataQuery,
+      distributionKeys: ImmutableBitSet*): JSet[ImmutableBitSet] = {
+    var inputKeys = 
FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel.getInput)
+    for (distributionKey <- distributionKeys) {
+      inputKeys = filterKeys(inputKeys, distributionKey)
+    }
+    inputKeys
+  }
+
+  def getUpsertKeys(
+      join: Join, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    val joinInfo = join.analyzeCondition()
+    join.getJoinType match {
+      case JoinRelType.SEMI | JoinRelType.ANTI =>
+        filterKeys(
+          FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(join.getLeft),
+          joinInfo.leftSet())
+      case _ =>
+        getJoinUpsertKeys(joinInfo, join.getJoinType, join.getLeft, 
join.getRight, mq)
+    }
+  }
+
+  def getUpsertKeys(
+      rel: StreamPhysicalIntervalJoin, mq: RelMetadataQuery): 
JSet[ImmutableBitSet] = {
+    val joinInfo = JoinInfo.of(rel.getLeft, rel.getRight, 
rel.originalCondition)
+    getJoinUpsertKeys(joinInfo, rel.getJoinType, rel.getLeft, rel.getRight, mq)
+  }
+
+  def getUpsertKeys(
+      join: CommonPhysicalLookupJoin, mq: RelMetadataQuery): 
util.Set[ImmutableBitSet] = {
+    val left = join.getInput
+    val leftKeys = FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(left)
+    val leftType = left.getRowType
+    val leftJoinKeys = join.joinInfo.leftSet
+    FlinkRelMdUniqueKeys.INSTANCE.getJoinUniqueKeys(
+      join.joinType, leftType, filterKeys(leftKeys, leftJoinKeys), null,
+      areColumnsUpsertKeys(leftKeys, leftJoinKeys),
+      // TODO get uniqueKeys from TableSchema of TableSource
+      null)
+  }
+
+  private def getJoinUpsertKeys(
+      joinInfo: JoinInfo,
+      joinRelType: JoinRelType,
+      left: RelNode,
+      right: RelNode,
+      mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
+    val leftKeys = fmq.getUpsertKeys(left)
+    val rightKeys = fmq.getUpsertKeys(right)
+    FlinkRelMdUniqueKeys.INSTANCE.getJoinUniqueKeys(
+      joinRelType,
+      left.getRowType,
+      filterKeys(leftKeys, joinInfo.leftSet),
+      filterKeys(rightKeys, joinInfo.rightSet),
+      areColumnsUpsertKeys(leftKeys, joinInfo.leftSet),
+      areColumnsUpsertKeys(rightKeys, joinInfo.rightSet))
+  }
+
+  def getUpsertKeys(rel: SetOp, mq: RelMetadataQuery): JSet[ImmutableBitSet] =
+    FlinkRelMdUniqueKeys.INSTANCE.getUniqueKeys(rel, mq, ignoreNulls = false)
+
+  def getUpsertKeys(
+      subset: RelSubset, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    if (!Bug.CALCITE_1048_FIXED) {
+      //if the best node is null, so we can get the uniqueKeys based original 
node, due to
+      //the original node is logically equivalent as the rel.
+      val rel = Util.first(subset.getBest, subset.getOriginal)
+      FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(rel)
+    } else {
+      throw new RuntimeException("CALCITE_1048 is fixed, so check this method 
again!")
+    }
+  }
+
+  def getUpsertKeys(
+      subset: HepRelVertex, mq: RelMetadataQuery): JSet[ImmutableBitSet] = {
+    FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(subset.getCurrentRel)
+  }
+
+  def getUpsertKeys(
+      subset: WatermarkAssigner, mq: RelMetadataQuery): JSet[ImmutableBitSet] 
= {
+    FlinkRelMetadataQuery.reuseOrCreate(mq).getUpsertKeys(subset.getInput)
+  }
+
+  private def filterKeys(
+      keys: JSet[ImmutableBitSet], distributionKey: ImmutableBitSet): 
JSet[ImmutableBitSet] = {
+    if (keys != null) {
+      keys.filter(k => k.contains(distributionKey))
+    } else {
+      null
+    }
+  }
+
+  private def areColumnsUpsertKeys(
+      keys: JSet[ImmutableBitSet], columns: ImmutableBitSet): Boolean = {
+    if (keys != null) {
+      keys.exists(columns.contains)
+    } else {
+      false
+    }
+  }
+
+  // Catch-all rule when none of the others apply.
+  def getUpsertKeys(rel: RelNode, mq: RelMetadataQuery): JSet[ImmutableBitSet] 
= null
+}
+
+object FlinkRelMdUpsertKeys {
+
+  private val INSTANCE = new FlinkRelMdUpsertKeys
+
+  val SOURCE: RelMetadataProvider = 
ReflectiveRelMetadataProvider.reflectiveSource(
+    UpsertKeys.METHOD, INSTANCE)
+
+}
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalOverAggregate.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalOverAggregate.scala
index b20547c..ff1ea76 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalOverAggregate.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalOverAggregate.scala
@@ -37,7 +37,7 @@ class StreamPhysicalOverAggregate(
     traitSet: RelTraitSet,
     inputRel: RelNode,
     outputRowType: RelDataType,
-    logicWindow: Window)
+    val logicWindow: Window)
   extends StreamPhysicalOverAggregateBase(
     cluster,
     traitSet,
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/StreamCommonSubGraphBasedOptimizer.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/StreamCommonSubGraphBasedOptimizer.scala
index 97a99e2..8468a4c 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/StreamCommonSubGraphBasedOptimizer.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/StreamCommonSubGraphBasedOptimizer.scala
@@ -272,6 +272,7 @@ class StreamCommonSubGraphBasedOptimizer(planner: 
StreamPlanner)
       relNode,
       modifyKindSet,
       isUpdateBeforeRequired,
+      fmq.getUpsertKeys(relNode),
       statistic)
   }
 
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/schema/IntermediateRelTable.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/schema/IntermediateRelTable.scala
index f9b2747..3cbafa5 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/schema/IntermediateRelTable.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/schema/IntermediateRelTable.scala
@@ -23,8 +23,10 @@ import 
org.apache.flink.table.planner.plan.`trait`.{ModifyKindSet, UpdateKind}
 import org.apache.flink.table.planner.plan.stats.FlinkStatistic
 
 import org.apache.calcite.rel.RelNode
+import org.apache.calcite.util.ImmutableBitSet
 
-import java.util.{List => JList}
+import java.util
+import java.util.{Set, List => JList}
 
 /**
   * An intermediate Table to wrap a optimized RelNode inside. The input data 
of this Table is
@@ -42,10 +44,11 @@ class IntermediateRelTable(
     val relNode: RelNode,
     val modifyKindSet: ModifyKindSet,
     val isUpdateBeforeRequired: Boolean,
+    val upsertKeys: util.Set[ImmutableBitSet],
     statistic: FlinkStatistic = FlinkStatistic.UNKNOWN)
   extends FlinkPreparingTableBase(null, relNode.getRowType, names, statistic) {
 
   def this(names: JList[String], relNode: RelNode) {
-    this(names, relNode, ModifyKindSet.INSERT_ONLY, false)
+    this(names, relNode, ModifyKindSet.INSERT_ONLY, false, new 
util.HashSet[ImmutableBitSet]())
   }
 }
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index 1dbbb54..78cfd4b 100644
--- 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -39,7 +39,7 @@ import org.apache.flink.table.planner.plan.nodes.calcite._
 import org.apache.flink.table.planner.plan.nodes.logical._
 import org.apache.flink.table.planner.plan.nodes.physical.batch._
 import org.apache.flink.table.planner.plan.nodes.physical.stream._
-import org.apache.flink.table.planner.plan.schema.FlinkPreparingTableBase
+import org.apache.flink.table.planner.plan.schema.{FlinkPreparingTableBase, 
IntermediateRelTable}
 import org.apache.flink.table.planner.plan.stream.sql.join.TestTemporalTable
 import org.apache.flink.table.planner.plan.utils._
 import org.apache.flink.table.planner.utils.Top3
@@ -288,8 +288,12 @@ class FlinkRelMdHandlerTestBase {
   }
 
   // hash exchange on class
-  protected lazy val (batchExchange, streamExchange) = {
-    val hash6 = FlinkRelDistribution.hash(Array(6), requireStrict = true)
+  protected lazy val (batchExchange, streamExchange) = createExchange(6)
+
+  protected lazy val (batchExchangeById, streamExchangeById) = 
createExchange(0)
+
+  protected def createExchange(hash: Int): (RelNode, RelNode) = {
+    val hash6 = FlinkRelDistribution.hash(Array(hash), requireStrict = true)
     val batchExchange = new BatchPhysicalExchange(
       cluster,
       batchPhysicalTraits.replace(hash6),
@@ -305,13 +309,27 @@ class FlinkRelMdHandlerTestBase {
     (batchExchange, streamExchange)
   }
 
  // IntermediateRelTable wrapping the exchange hashed by id; it declares {0}
  // (the id column) as its upsert key.
  // NOTE(review): modifyKindSet is passed as null here — presumably acceptable
  // because the metadata tests never read it; confirm against IntermediateRelTable.
  protected lazy val intermediateTable = new IntermediateRelTable(
    Seq(""), streamExchangeById, null, false, Set(ImmutableBitSet.of(0)))

  // Logical scan over the intermediate table, used to verify that upsert keys
  // are taken from the underlying IntermediateRelTable.
  protected lazy val intermediateScan = new FlinkLogicalIntermediateTableScan(
    cluster, streamExchangeById.getTraitSet, intermediateTable)
+
   // equivalent SQL is
   // select * from student order by class asc, score desc
-  protected lazy val (logicalSort, flinkLogicalSort, batchSort, streamSort) = {
-    val logicalSort = relBuilder.scan("student").sort(
-      relBuilder.field("class"),
-      relBuilder.desc(relBuilder.field("score")))
-      .build.asInstanceOf[LogicalSort]
+  protected lazy val (logicalSort, flinkLogicalSort, batchSort, streamSort) =
+    createSorts(() =>
+      Seq(relBuilder.field("class"),
+      relBuilder.desc(relBuilder.field("score"))))
+
+  // equivalent SQL is
+  // select * from student order by id asc
+  protected lazy val (logicalSortById, flinkLogicalSortById, batchSortById, 
streamSortById) =
+    createSorts(() => Seq(relBuilder.field("id")))
+
+  protected def createSorts(sortKeys: () => Seq[RexNode]): (RelNode, RelNode, 
RelNode, RelNode) = {
+    val logicalSort = relBuilder.scan("student")
+        .sort(sortKeys()).build.asInstanceOf[LogicalSort]
     val collation = logicalSort.getCollation
     val flinkLogicalSort = new FlinkLogicalSort(cluster, 
flinkLogicalTraits.replace(collation),
       studentFlinkLogicalScan, collation, null, null)
@@ -372,11 +390,25 @@ class FlinkRelMdHandlerTestBase {
     batchSortLimit,
     batchLocalSortLimit,
     batchGlobalSortLimit,
-    streamSortLimit) = {
-    val logicalSortLimit = relBuilder.scan("student").sort(
-      relBuilder.field("class"),
-      relBuilder.desc(relBuilder.field("score")))
-      .limit(10, 20).build.asInstanceOf[LogicalSort]
+    streamSortLimit) = createSortLimits(() => Seq(
+    relBuilder.field("class"),
+    relBuilder.desc(relBuilder.field("score"))))
+
+  // equivalent SQL is
+  // select * from student order by id asc limit 20 offset 10
+  protected lazy val (
+      logicalSortLimitById,
+      flinkLogicalSortLimitById,
+      batchSortLimitById,
+      batchLocalSortLimitById,
+      batchGlobalSortLimitById,
+      streamSortLimitById) = createSortLimits(() => Seq(
+      relBuilder.field("id")))
+
+  protected def createSortLimits(sortKeys: () => Seq[RexNode])
+    : (RelNode, RelNode, RelNode, RelNode, RelNode, RelNode) = {
+    val logicalSortLimit = relBuilder.scan("student").sort(sortKeys())
+        .limit(10, 20).build.asInstanceOf[LogicalSort]
 
     val collection = logicalSortLimit.collation
     val offset = logicalSortLimit.offset
@@ -386,7 +418,7 @@ class FlinkRelMdHandlerTestBase {
       flinkLogicalTraits.replace(collection), studentFlinkLogicalScan, 
collection, offset, fetch)
 
     val batchSortLimit = new BatchPhysicalSortLimit(
-        cluster, batchPhysicalTraits.replace(collection),
+      cluster, batchPhysicalTraits.replace(collection),
       new BatchPhysicalExchange(
         cluster, batchPhysicalTraits.replace(FlinkRelDistribution.SINGLETON), 
studentBatchScan,
         FlinkRelDistribution.SINGLETON),
@@ -398,7 +430,7 @@ class FlinkRelMdHandlerTestBase {
       relBuilder.literal(SortUtil.getLimitEnd(offset, fetch)),
       false)
     val batchSortGlobal = new BatchPhysicalSortLimit(
-        cluster, batchPhysicalTraits.replace(collection),
+      cluster, batchPhysicalTraits.replace(collection),
       new BatchPhysicalExchange(
         cluster, batchPhysicalTraits.replace(FlinkRelDistribution.SINGLETON), 
batchSortLocalLimit,
         FlinkRelDistribution.SINGLETON),
@@ -408,7 +440,7 @@ class FlinkRelMdHandlerTestBase {
       studentStreamScan, collection, offset, fetch, 
RankProcessStrategy.UNDEFINED_STRATEGY)
 
     (logicalSortLimit, flinkLogicalSortLimit,
-      batchSortLimit, batchSortLocalLimit, batchSortGlobal, streamSort)
+        batchSortLimit, batchSortLocalLimit, batchSortGlobal, streamSort)
   }
 
   // equivalent SQL is
@@ -417,16 +449,30 @@ class FlinkRelMdHandlerTestBase {
   //  RANK() over (partition by class order by score) rk from student
   // ) t where rk <= 5
   protected lazy val (
-    logicalRank,
-    flinkLogicalRank,
-    batchLocalRank,
-    batchGlobalRank,
-    streamRank) = {
+      logicalRank,
+      flinkLogicalRank,
+      batchLocalRank,
+      batchGlobalRank,
+      streamRank) = createRanks(6)
+
+  // equivalent SQL is
+  // select * from (
+  //  select id, name, score, age, height, sex, class,
+  //  RANK() over (partition by id order by score) rk from student
+  // ) t where rk <= 5
+  protected lazy val (
+      logicalRankById,
+      flinkLogicalRankById,
+      batchLocalRankById,
+      batchGlobalRankById,
+      streamRankById) = createRanks(0)
+
+  protected def createRanks(partitionKey: Int): (RelNode, RelNode, RelNode, 
RelNode, RelNode) = {
     val logicalRank = new LogicalRank(
       cluster,
       logicalTraits,
       studentLogicalScan,
-      ImmutableBitSet.of(6),
+      ImmutableBitSet.of(partitionKey),
       RelCollations.of(2),
       RankType.RANK,
       new ConstantRankRange(1, 5),
@@ -438,7 +484,7 @@ class FlinkRelMdHandlerTestBase {
       cluster,
       flinkLogicalTraits,
       studentFlinkLogicalScan,
-      ImmutableBitSet.of(6),
+      ImmutableBitSet.of(partitionKey),
       RelCollations.of(2),
       RankType.RANK,
       new ConstantRankRange(1, 5),
@@ -450,7 +496,7 @@ class FlinkRelMdHandlerTestBase {
       cluster,
       batchPhysicalTraits,
       studentBatchScan,
-      ImmutableBitSet.of(6),
+      ImmutableBitSet.of(partitionKey),
       RelCollations.of(2),
       RankType.RANK,
       new ConstantRankRange(1, 5),
@@ -459,14 +505,14 @@ class FlinkRelMdHandlerTestBase {
       isGlobal = false
     )
 
-    val hash6 = FlinkRelDistribution.hash(Array(6), requireStrict = true)
+    val hash6 = FlinkRelDistribution.hash(Array(partitionKey), requireStrict = 
true)
     val batchExchange = new BatchPhysicalExchange(
       cluster, batchLocalRank.getTraitSet.replace(hash6), batchLocalRank, 
hash6)
     val batchGlobalRank = new BatchPhysicalRank(
       cluster,
       batchPhysicalTraits,
       batchExchange,
-      ImmutableBitSet.of(6),
+      ImmutableBitSet.of(partitionKey),
       RelCollations.of(2),
       RankType.RANK,
       new ConstantRankRange(1, 5),
@@ -481,7 +527,7 @@ class FlinkRelMdHandlerTestBase {
       cluster,
       streamPhysicalTraits,
       streamExchange,
-      ImmutableBitSet.of(6),
+      ImmutableBitSet.of(partitionKey),
       RelCollations.of(2),
       RankType.RANK,
       new ConstantRankRange(1, 5),
@@ -1975,7 +2021,29 @@ class FlinkRelMdHandlerTestBase {
   //  dense_rank() over (partition by class order by score) as drk,
   //  avg(score) over (partition by class order by score) as avg_score
   //  from student
-  protected lazy val streamOverAgg: StreamPhysicalRel = {
+  protected lazy val streamOverAgg: StreamPhysicalRel = 
createStreamOverAgg(overAggGroups.get(1), 4)
+
+  protected lazy val streamOverAggById: StreamPhysicalRel = 
createStreamOverAgg(
+    new Window.Group(
+      ImmutableBitSet.of(0),
+      true,
+      RexWindowBound.create(SqlWindow.createUnboundedPreceding(new 
SqlParserPos(0, 0)), null),
+      RexWindowBound.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 
0)), null),
+      RelCollationImpl.of(new RelFieldCollation(
+        1, RelFieldCollation.Direction.ASCENDING, 
RelFieldCollation.NullDirection.FIRST)),
+      ImmutableList.of(
+        new Window.RexWinAggCall(
+          SqlStdOperatorTable.ROW_NUMBER,
+          longType,
+          ImmutableList.of[RexNode](),
+          0,
+          false
+        )
+      )
+    ), 0
+  )
+
+  protected def createStreamOverAgg(group: Window.Group, hash: Int): 
StreamPhysicalRel = {
     val types = Map(
       "id" -> longType,
       "name" -> stringType,
@@ -2014,14 +2082,14 @@ class FlinkRelMdHandlerTestBase {
       new FlinkLogicalCalc(cluster, flinkLogicalTraits, 
studentFlinkLogicalScan, rexProgram),
       ImmutableList.of(),
       rowTypeOfWindowAgg,
-      util.Arrays.asList(overAggGroups.get(1))
+      util.Arrays.asList(group)
     )
 
     val streamScan: StreamPhysicalDataStreamScan =
       createDataStreamScan(ImmutableList.of("student"), streamPhysicalTraits)
     val calc = new StreamPhysicalCalc(
       cluster, streamPhysicalTraits, streamScan, rexProgram, rowTypeOfCalc)
-    val hash4 = FlinkRelDistribution.hash(Array(4), requireStrict = true)
+    val hash4 = FlinkRelDistribution.hash(Array(hash), requireStrict = true)
     val exchange = new StreamPhysicalExchange(cluster, 
calc.getTraitSet.replace(hash4), calc, hash4)
 
     val windowAgg = new StreamPhysicalOverAggregate(
diff --git 
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
new file mode 100644
index 0000000..77f570e
--- /dev/null
+++ 
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.metadata
+
+import org.apache.flink.table.planner.plan.nodes.calcite.LogicalExpand
+import 
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalTableSourceScan
+import org.apache.flink.table.planner.plan.schema.TableSourceTable
+import org.apache.flink.table.planner.plan.utils.ExpandUtil
+
+import com.google.common.collect.{ImmutableList, ImmutableSet}
+import org.apache.calcite.prepare.CalciteCatalogReader
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
+import org.apache.calcite.util.ImmutableBitSet
+import org.junit.Assert._
+import org.junit.Test
+
+import java.util.Collections
+
+import scala.collection.JavaConversions._
+
+class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase {
+
+  @Test
+  def testGetUpsertKeysOnTableScan(): Unit = {
+    Array(studentLogicalScan, studentBatchScan, studentStreamScan).foreach { 
scan =>
+      assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(scan).toSet)
+    }
+
+    Array(empLogicalScan, empBatchScan, empStreamScan).foreach { scan =>
+      assertNull(mq.getUpsertKeys(scan))
+    }
+
+    val table = relBuilder
+      .getRelOptSchema
+      .asInstanceOf[CalciteCatalogReader]
+      .getTable(Seq("projected_table_source_table"))
+      .asInstanceOf[TableSourceTable]
+    val tableSourceScan = new StreamPhysicalTableSourceScan(
+      cluster,
+      streamPhysicalTraits,
+      Collections.emptyList[RelHint](),
+      table)
+    assertEquals(toBitSet(Array(0, 2)), 
mq.getUpsertKeys(tableSourceScan).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnProjectedTableScanWithPartialCompositePrimaryKey(): 
Unit = {
+    val table = relBuilder
+      .getRelOptSchema
+      .asInstanceOf[CalciteCatalogReader]
+      .getTable(Seq("projected_table_source_table_with_partial_pk"))
+      .asInstanceOf[TableSourceTable]
+    val tableSourceScan = new StreamPhysicalTableSourceScan(
+      cluster,
+      streamPhysicalTraits,
+      Collections.emptyList[RelHint](),
+      table)
+    assertNull(mq.getUpsertKeys(tableSourceScan))
+  }
+
  @Test
  def testGetUpsertKeysOnValues(): Unit = {
    // Values nodes expose no upsert keys: the result is null (unknown).
    assertNull(mq.getUpsertKeys(logicalValues))
    assertNull(mq.getUpsertKeys(emptyValues))
  }
+
+  @Test
+  def testGetUpsertKeysOnProject(): Unit = {
+    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalProject).toSet)
+
+    relBuilder.push(studentLogicalScan)
+    // id=1, id, cast(id AS bigint not null), cast(id AS int), $1
+    val exprs = List(
+      relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)),
+      relBuilder.field(0),
+      rexBuilder.makeCast(longType, relBuilder.field(0)),
+      rexBuilder.makeCast(intType, relBuilder.field(0)),
+      relBuilder.field(1))
+    val project1 = relBuilder.project(exprs).build()
+    assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(project1).toSet)
+  }
+
  @Test
  def testGetUpsertKeysOnFilter(): Unit = {
    // A filter never changes row identity: the scan's key {0} (id) survives.
    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalFilter).toSet)
  }
+
  @Test
  def testGetUpsertKeysOnWatermark(): Unit = {
    // Watermark assignment delegates to its input, so the key {0} (id) is kept.
    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalWatermarkAssigner).toSet)
  }
+
+  @Test
+  def testGetUpsertKeysOnCalc(): Unit = {
+    relBuilder.push(studentLogicalScan)
+    // id < 100
+    val expr = relBuilder.call(LESS_THAN, relBuilder.field(0), 
relBuilder.literal(100))
+    val calc1 = createLogicalCalc(
+      studentLogicalScan, logicalProject.getRowType, 
logicalProject.getProjects, List(expr))
+    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalCalc).toSet)
+
+    // id=1, id, cast(id AS bigint not null), cast(id AS int), $1
+    val exprs = List(
+      relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)),
+      relBuilder.field(0),
+      rexBuilder.makeCast(longType, relBuilder.field(0)),
+      rexBuilder.makeCast(intType, relBuilder.field(0)),
+      relBuilder.field(1))
+    val rowType = relBuilder.project(exprs).build().getRowType
+    val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs, 
List(expr))
+    assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(calc2).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnExpand(): Unit = {
+    Array(logicalExpand, flinkLogicalExpand, batchExpand, 
streamExpand).foreach {
+      expand => assertEquals(toBitSet(Array(0, 7)), 
mq.getUpsertKeys(expand).toSet)
+    }
+
+    val expandProjects = ExpandUtil.createExpandProjects(
+      studentLogicalScan.getCluster.getRexBuilder,
+      studentLogicalScan.getRowType,
+      ImmutableBitSet.of(0, 1, 2, 3),
+      ImmutableList.of(
+        ImmutableBitSet.of(0),
+        ImmutableBitSet.of(1),
+        ImmutableBitSet.of(2),
+        ImmutableBitSet.of(3)),
+      Array.empty[Integer])
+    val expand = new LogicalExpand(cluster, studentLogicalScan.getTraitSet,
+      studentLogicalScan, expandProjects, 7)
+    assertNull(mq.getUpsertKeys(expand))
+  }
+
  @Test
  def testGetUpsertKeysOnExchange(): Unit = {
    // Hash distribution on column 6 does not contain the key {0},
    // so no upsert key survives: the result is a known-empty set.
    Array(batchExchange, streamExchange).foreach { exchange =>
      assertEquals(toBitSet(), mq.getUpsertKeys(exchange).toSet)
    }

    // Hash distribution on column 0 is contained in the key {0}, which is kept.
    Array(batchExchangeById, streamExchangeById).foreach { exchange =>
      assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(exchange).toSet)
    }
  }
+
+  @Test
+  def testGetUpsertKeysOnRank(): Unit = {
+    Array(logicalRank, flinkLogicalRank, batchLocalRank, batchGlobalRank, 
streamRank).foreach {
+      rank =>
+        assertEquals(toBitSet(), mq.getUpsertKeys(rank).toSet)
+    }
+
+    Array(logicalRankById, flinkLogicalRankById,
+      batchLocalRankById, batchGlobalRankById, streamRankById).foreach {
+      rank =>
+        assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rank).toSet)
+    }
+
+    Array(logicalRowNumber, flinkLogicalRowNumber, streamRowNumber)
+      .foreach { rank =>
+        assertEquals(toBitSet(Array(0), Array(7)), 
mq.getUpsertKeys(rank).toSet)
+      }
+  }
+
+  @Test
+  def testGetUpsertKeysOnSort(): Unit = {
+    def testWithoutKey(rel: RelNode): Unit = {
+      assertEquals(toBitSet(), mq.getUpsertKeys(rel).toSet)
+    }
+
+    def testWithKey(rel: RelNode): Unit = {
+      assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rel).toSet)
+    }
+
+    testWithoutKey(logicalSort)
+    testWithoutKey(flinkLogicalSort)
+    testWithoutKey(batchSort)
+    testWithoutKey(streamSort)
+    testWithoutKey(logicalSortLimit)
+    testWithoutKey(flinkLogicalSortLimit)
+    testWithoutKey(batchSortLimit)
+    testWithoutKey(streamSortLimit)
+    testWithoutKey(batchGlobalSortLimit)
+    testWithoutKey(batchLocalSortLimit)
+
+    testWithKey(logicalSortById)
+    testWithKey(flinkLogicalSortById)
+    testWithKey(batchSortById)
+    testWithKey(streamSortById)
+    testWithKey(logicalSortLimitById)
+    testWithKey(flinkLogicalSortLimitById)
+    testWithKey(batchSortLimitById)
+    testWithKey(streamSortLimitById)
+    testWithKey(batchGlobalSortLimitById)
+    testWithKey(batchLocalSortLimitById)
+
+    testWithKey(logicalLimit)
+    testWithKey(flinkLogicalLimit)
+    testWithKey(batchLimit)
+    testWithKey(streamLimit)
+  }
+
+  @Test
+  def testGetUpsertKeysOnStreamExecDeduplicate(): Unit = {
+    assertEquals(
+      toBitSet(Array(1)),
+      mq.getUpsertKeys(streamProcTimeDeduplicateFirstRow).toSet)
+    assertEquals(
+      toBitSet(Array(1, 2)),
+      mq.getUpsertKeys(streamProcTimeDeduplicateLastRow).toSet)
+    assertEquals(
+      toBitSet(Array(1)),
+      mq.getUpsertKeys(streamRowTimeDeduplicateFirstRow).toSet)
+    assertEquals(
+      toBitSet(Array(1, 2)),
+      mq.getUpsertKeys(streamRowTimeDeduplicateLastRow).toSet)
+  }
+
  @Test
  def testGetUpsertKeysOnStreamExecChangelogNormalize(): Unit = {
    // The normalize node's upsert key is the column set {0, 1}.
    assertEquals(toBitSet(Array(1, 0)), mq.getUpsertKeys(streamChangelogNormalize).toSet)
  }
+
  @Test
  def testGetUpsertKeysOnStreamExecDropUpdateBefore(): Unit = {
    // Dropping UPDATE_BEFORE rows keeps row identity: key {0} is forwarded.
    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamDropUpdateBefore).toSet)
  }
+
+  @Test
+  def testGetUpsertKeysOnAggregate(): Unit = {
+    Array(logicalAgg, flinkLogicalAgg, batchGlobalAggWithLocal, 
batchGlobalAggWithoutLocal,
+      streamGlobalAggWithLocal, streamGlobalAggWithoutLocal).foreach { agg =>
+      assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg).toSet)
+    }
+    assertNull(mq.getUpsertKeys(batchLocalAgg))
+    assertNull(mq.getUpsertKeys(streamLocalAgg))
+
+    Array(logicalAggWithAuxGroup, flinkLogicalAggWithAuxGroup, 
batchGlobalAggWithLocalWithAuxGroup,
+      batchGlobalAggWithoutLocalWithAuxGroup).foreach { agg =>
+      assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg).toSet)
+    }
+    assertNull(mq.getUpsertKeys(batchLocalAggWithAuxGroup))
+  }
+
+  @Test
+  def testGetUpsertKeysOnWindowAgg(): Unit = {
+    Array(logicalWindowAgg, flinkLogicalWindowAgg, 
batchGlobalWindowAggWithoutLocalAgg,
+      batchGlobalWindowAggWithLocalAgg).foreach { agg =>
+      assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 1, 3), 
ImmutableBitSet.of(0, 1, 4),
+        ImmutableBitSet.of(0, 1, 5), ImmutableBitSet.of(0, 1, 6)),
+        mq.getUpsertKeys(agg))
+    }
+    assertNull(mq.getUpsertKeys(batchLocalWindowAgg))
+
+    Array(logicalWindowAggWithAuxGroup, flinkLogicalWindowAggWithAuxGroup,
+      batchGlobalWindowAggWithoutLocalAggWithAuxGroup,
+      batchGlobalWindowAggWithLocalAggWithAuxGroup).foreach { agg =>
+      assertEquals(ImmutableSet.of(ImmutableBitSet.of(0, 3), 
ImmutableBitSet.of(0, 4),
+        ImmutableBitSet.of(0, 5), ImmutableBitSet.of(0, 6)),
+        mq.getUpsertKeys(agg))
+    }
+    assertNull(mq.getUpsertKeys(batchLocalWindowAggWithAuxGroup))
+  }
+
+  @Test
+  def testGetUpsertKeysOnOverAgg(): Unit = {
+    Array(flinkLogicalOverAgg, batchOverAgg, streamOverAgg).foreach { agg =>
+      assertEquals(toBitSet(), mq.getUpsertKeys(agg).toSet)
+    }
+
+    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamOverAggById).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnJoin(): Unit = {
+    assertEquals(toBitSet(Array(1), Array(5), Array(1, 5), Array(5, 6), 
Array(1, 5, 6)),
+      mq.getUpsertKeys(logicalInnerJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalInnerJoinNotOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalInnerJoinOnRHSUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalInnerJoinWithoutEquiCond).toSet)
+    assertEquals(
+      toBitSet(), 
mq.getUpsertKeys(logicalInnerJoinWithEquiAndNonEquiCond).toSet)
+
+    assertEquals(toBitSet(Array(1), Array(1, 5), Array(1, 5, 6)),
+      mq.getUpsertKeys(logicalLeftJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalLeftJoinNotOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalLeftJoinOnRHSUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalLeftJoinWithoutEquiCond).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalLeftJoinWithEquiAndNonEquiCond).toSet)
+
+    assertEquals(toBitSet(Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)),
+      mq.getUpsertKeys(logicalRightJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalRightJoinNotOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalRightJoinOnLHSUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalRightJoinWithoutEquiCond).toSet)
+    assertEquals(
+      toBitSet(), 
mq.getUpsertKeys(logicalRightJoinWithEquiAndNonEquiCond).toSet)
+
+    assertEquals(toBitSet(Array(1, 5), Array(1, 5, 6)),
+      mq.getUpsertKeys(logicalFullJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalFullJoinNotOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalFullJoinOnRHSUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalFullJoinWithoutEquiCond).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalFullJoinWithEquiAndNonEquiCond).toSet)
+
+    assertEquals(toBitSet(Array(1)), 
mq.getUpsertKeys(logicalSemiJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalSemiJoinNotOnUniqueKeys).toSet)
+    assertNull(mq.getUpsertKeys(logicalSemiJoinOnRHSUniqueKeys))
+    assertEquals(
+      toBitSet(Array(1)), 
mq.getUpsertKeys(logicalSemiJoinWithoutEquiCond).toSet)
+    assertEquals(toBitSet(Array(1)),
+      mq.getUpsertKeys(logicalSemiJoinWithEquiAndNonEquiCond).toSet)
+
+    assertEquals(toBitSet(Array(1)),
+      mq.getUpsertKeys(logicalAntiJoinOnUniqueKeys).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalAntiJoinNotOnUniqueKeys).toSet)
+    assertNull(mq.getUpsertKeys(logicalAntiJoinOnRHSUniqueKeys))
+    assertEquals(
+      toBitSet(Array(1)), 
mq.getUpsertKeys(logicalAntiJoinWithoutEquiCond).toSet)
+    assertEquals(toBitSet(), 
mq.getUpsertKeys(logicalAntiJoinWithEquiAndNonEquiCond).toSet)
+  }
+
  @Test
  def testGetUpsertKeysOnLookupJoin(): Unit = {
    // Lookup joins yield a known-empty key set (no column set qualifies),
    // as opposed to null, which would mean "unknown".
    Array(batchLookupJoin, streamLookupJoin).foreach { join =>
      assertEquals(toBitSet(), mq.getUpsertKeys(join).toSet)
    }
  }
+
+  @Test
+  def testGetUpsertKeysOnSetOp(): Unit = {
+    Array(logicalUnionAll, logicalIntersectAll, logicalMinusAll).foreach { 
setOp =>
+      assertEquals(toBitSet(), mq.getUpsertKeys(setOp).toSet)
+    }
+
+    Array(logicalUnion, logicalIntersect, logicalMinus).foreach { setOp =>
+      assertEquals(toBitSet(Array(0, 1, 2, 3, 4)), 
mq.getUpsertKeys(setOp).toSet)
+    }
+  }
+
  @Test
  def testGetUpsertKeysOnDefault(): Unit = {
    // The catch-all handler returns null (unknown) for unrecognized RelNodes.
    assertNull(mq.getUpsertKeys(testRel))
  }
+
  @Test
  def testGetUpsertKeysOnIntermediateScan(): Unit = {
    // The scan reports the upsert key {0} declared on its IntermediateRelTable.
    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(intermediateScan).toSet)
  }
+
+  private def toBitSet(keys: Array[Int]*): Set[ImmutableBitSet] = {
+    keys.map(k => ImmutableBitSet.of(k: _*)).toSet
+  }
+}

Reply via email to