This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 7353fa9  [CALCITE-4208] Improve metadata row count for Join
7353fa9 is described below

commit 7353fa94bad8d5db48a217799d8c9567c9639a2f
Author: rubenada <[email protected]>
AuthorDate: Wed Sep 2 09:43:42 2020 +0100

    [CALCITE-4208] Improve metadata row count for Join
---
 .../org/apache/calcite/rel/metadata/RelMdUtil.java | 32 ++++++++++++++--------
 .../java/org/apache/calcite/util/NumberUtil.java   |  8 ++++++
 .../org/apache/calcite/test/RelMetadataTest.java   | 21 ++++++++++++--
 core/src/test/resources/sql/sub-query.iq           | 30 ++++++++++----------
 4 files changed, 63 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
index 48e0301..ae94a46 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
@@ -43,6 +43,7 @@ import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.NumberUtil;
+import org.apache.calcite.util.Util;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
@@ -713,9 +714,11 @@ public class RelMdUtil {
       // semijoin filter and pass it to getSelectivity
       RexNode semiJoinSelectivity =
           RelMdUtil.makeSemiJoinSelectivityRexNode(mq, join);
-
+      Double selectivity = mq.getSelectivity(join.getLeft(), semiJoinSelectivity);
       return NumberUtil.multiply(
-          mq.getSelectivity(join.getLeft(), semiJoinSelectivity),
+          join.getJoinType() == JoinRelType.SEMI
+              ? selectivity
+              : NumberUtil.subtract(1D, selectivity), // ANTI join
           mq.getRowCount(join.getLeft()));
     }
     // Row count estimates of 0 will be rounded up to 1.
@@ -731,19 +734,24 @@ public class RelMdUtil {
         return max;
       }
     }
-    double product = left * right;
-
-    return product * mq.getSelectivity(join, condition);
-  }
 
-  /** Returns an estimate of the number of rows returned by a semi-join. */
-  public static Double getSemiJoinRowCount(RelMetadataQuery mq, RelNode left,
-      RelNode right, JoinRelType joinType, RexNode condition) {
-    final Double leftCount = mq.getRowCount(left);
-    if (leftCount == null) {
+    Double selectivity = mq.getSelectivity(join, condition);
+    if (selectivity == null) {
       return null;
     }
-    return leftCount * RexUtil.getSelectivity(condition);
+    double innerRowCount = left * right * selectivity;
+    switch (join.getJoinType()) {
+    case INNER:
+      return innerRowCount;
+    case LEFT:
+      return left * (1D - selectivity) + innerRowCount;
+    case RIGHT:
+      return right * (1D - selectivity) + innerRowCount;
+    case FULL:
+      return (left + right) * (1D - selectivity) + innerRowCount;
+    default:
+      throw Util.unexpected(join.getJoinType());
+    }
   }
 
   public static double estimateFilteredRows(RelNode child, RexProgram program,
diff --git a/core/src/main/java/org/apache/calcite/util/NumberUtil.java b/core/src/main/java/org/apache/calcite/util/NumberUtil.java
index b5dade9..23ef7d7 100644
--- a/core/src/main/java/org/apache/calcite/util/NumberUtil.java
+++ b/core/src/main/java/org/apache/calcite/util/NumberUtil.java
@@ -143,6 +143,14 @@ public class NumberUtil {
     return a + b;
   }
 
+  public static Double subtract(Double a, Double b) {
+    if ((a == null) || (b == null)) {
+      return null;
+    }
+
+    return a - b;
+  }
+
   public static Double divide(Double a, Double b) {
     if ((a == null) || (b == null) || (b == 0D)) {
       return null;
diff --git a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
index eba2feb..8c5bd1d 100644
--- a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
@@ -599,8 +599,8 @@ public class RelMetadataTest extends SqlToRelTestBase {
     final String sql = "select * from (select * from emp limit 0) as emp\n"
         + "right join (select * from dept limit 4) as dept\n"
         + "on emp.deptno = dept.deptno";
-    checkRowCount(sql, 1D, // 0, rounded up to row count's minimum 1
-        0D, 4D); // 1 * 4
+    checkRowCount(sql, 4D,
+        0D, 4D);
   }
 
   @Test void testRowCountJoinFiniteEmpty() {
@@ -611,6 +611,23 @@ public class RelMetadataTest extends SqlToRelTestBase {
         0D, 0D); // 7 * 0
   }
 
+  @Test void testRowCountLeftJoinFiniteEmpty() {
+    final String sql = "select * from (select * from emp limit 4) as emp\n"
+        + "left join (select * from dept limit 0) as dept\n"
+        + "on emp.deptno = dept.deptno";
+    checkRowCount(sql, 4D,
+        0D, 4D);
+  }
+
+  @Test void testRowCountRightJoinFiniteEmpty() {
+    final String sql = "select * from (select * from emp limit 4) as emp\n"
+        + "right join (select * from dept limit 0) as dept\n"
+        + "on emp.deptno = dept.deptno";
+    checkRowCount(sql, 1D, // 0, rounded up to row count's minimum 1
+        0D, 0D); // 0 * 4
+  }
+
+
   @Test void testRowCountJoinEmptyEmpty() {
     final String sql = "select * from (select * from emp limit 0) as emp\n"
         + "inner join (select * from dept limit 0) as dept\n"
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 3d1c3a9..2a133f7 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -2053,21 +2053,23 @@ where sal + 100 not in (
 !ok
 EnumerableAggregate(group=[{}], C=[COUNT()])
   EnumerableCalc(expr#0..9=[{inputs}], expr#10=[0], expr#11=[=($t4, $t10)], expr#12=[IS NULL($t2)], expr#13=[IS NOT NULL($t7)], expr#14=[<($t5, $t4)], expr#15=[OR($t12, $t13, $t14)], expr#16=[IS NOT TRUE($t15)], expr#17=[OR($t11, $t16)], proj#0..9=[{exprs}], $condition=[$t17])
-    EnumerableHashJoin(condition=[AND(=($1, $8), =($2, $9))], joinType=[left])
-      EnumerableMergeJoin(condition=[=($1, $3)], joinType=[left])
-        EnumerableSort(sort0=[$1], dir0=[ASC])
-          EnumerableCalc(expr#0..7=[{inputs}], proj#0..1=[{exprs}], SAL=[$t5])
-            EnumerableTableScan(table=[[scott, EMP]])
-        EnumerableSort(sort0=[$0], dir0=[ASC])
-          EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1:BIGINT], expr#4=[IS NOT NULL($t1)], DNAME=[$t1], $f1=[$t3], $f2=[$t3], $condition=[$t4])
-            EnumerableTableScan(table=[[scott, DEPT]])
-      EnumerableCalc(expr#0..4=[{inputs}], DEPTNO=[$t2], i=[$t3], DNAME=[$t4], SAL=[$t0])
-        EnumerableHashJoin(condition=[=($1, $2)], joinType=[inner])
-          EnumerableCalc(expr#0=[{inputs}], expr#1=[100], expr#2=[+($t0, $t1)], SAL=[$t0], $f1=[$t2])
-            EnumerableAggregate(group=[{5}])
+    EnumerableMergeJoin(condition=[AND(=($1, $8), =($2, $9))], joinType=[left])
+      EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], dir1=[ASC])
+        EnumerableMergeJoin(condition=[=($1, $3)], joinType=[left])
+          EnumerableSort(sort0=[$1], dir0=[ASC])
+            EnumerableCalc(expr#0..7=[{inputs}], proj#0..1=[{exprs}], SAL=[$t5])
               EnumerableTableScan(table=[[scott, EMP]])
-          EnumerableCalc(expr#0..2=[{inputs}], expr#3=[true], expr#4=[IS NOT NULL($t1)], DEPTNO=[$t0], i=[$t3], DNAME=[$t1], $condition=[$t4])
-            EnumerableTableScan(table=[[scott, DEPT]])
+          EnumerableSort(sort0=[$0], dir0=[ASC])
+            EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1:BIGINT], expr#4=[IS NOT NULL($t1)], DNAME=[$t1], $f1=[$t3], $f2=[$t3], $condition=[$t4])
+              EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableSort(sort0=[$2], sort1=[$3], dir0=[ASC], dir1=[ASC])
+        EnumerableCalc(expr#0..4=[{inputs}], DEPTNO=[$t2], i=[$t3], DNAME=[$t4], SAL=[$t0])
+          EnumerableHashJoin(condition=[=($1, $2)], joinType=[inner])
+            EnumerableCalc(expr#0=[{inputs}], expr#1=[100], expr#2=[+($t0, $t1)], SAL=[$t0], $f1=[$t2])
+              EnumerableAggregate(group=[{5}])
+                EnumerableTableScan(table=[[scott, EMP]])
+            EnumerableCalc(expr#0..2=[{inputs}], expr#3=[true], expr#4=[IS NOT NULL($t1)], DEPTNO=[$t0], i=[$t3], DNAME=[$t1], $condition=[$t4])
+              EnumerableTableScan(table=[[scott, DEPT]])
 !plan
 
 # Correlated ANY sub-query

Reply via email to