This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new c686e2e412 [CALCITE-6749] RelMdUtil#setAggChildKeys may return an incorrect result
c686e2e412 is described below

commit c686e2e4124056ce23d6f1b97469e715d62d596e
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Dec 27 14:55:33 2024 +0000

    [CALCITE-6749] RelMdUtil#setAggChildKeys may return an incorrect result
---
 .../org/apache/calcite/rel/metadata/RelMdUtil.java | 10 ++--
 .../apache/calcite/rel/metadata/RelMdUtilTest.java | 42 +++++++++++++++++
 core/src/test/resources/sql/agg.iq                 |  6 +--
 core/src/test/resources/sql/sub-query.iq           | 55 ++++++++++++----------
 4 files changed, 79 insertions(+), 34 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
index 29236bf8bc..6cb45aded5 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
@@ -568,15 +568,15 @@ public class RelMdUtil {
       ImmutableBitSet groupKey,
       Aggregate aggRel,
       ImmutableBitSet.Builder childKey) {
-    List<AggregateCall> aggCalls = aggRel.getAggCallList();
+    final List<AggregateCall> aggCallList = aggRel.getAggCallList();
+    final List<Integer> groupList = aggRel.getGroupSet().asList();
     for (int bit : groupKey) {
       if (bit < aggRel.getGroupCount()) {
         // group by column
-        childKey.set(bit);
+        childKey.set(groupList.get(bit));
       } else {
-        // aggregate column -- set a bit for each argument being
-        // aggregated
-        AggregateCall agg = aggCalls.get(bit - aggRel.getGroupCount());
+        // aggregate column -- set a bit for each argument being aggregated
+        final AggregateCall agg = aggCallList.get(bit - aggRel.getGroupCount());
         for (Integer arg : agg.getArgList()) {
           childKey.set(arg);
         }
diff --git a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
index 1e14216bdc..f5df2d3bf4 100644
--- a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
@@ -16,11 +16,17 @@
  */
 package org.apache.calcite.rel.metadata;
 
+import org.apache.calcite.plan.hep.HepPlanner;
+import org.apache.calcite.plan.hep.HepProgram;
+import org.apache.calcite.plan.hep.HepProgramBuilder;
 import org.apache.calcite.rel.RelCollations;
 import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.rules.CoreRules;
 import org.apache.calcite.test.RelMetadataFixture;
 import org.apache.calcite.tools.Frameworks;
+import org.apache.calcite.util.ImmutableBitSet;
 
 import org.junit.jupiter.api.Test;
 
@@ -30,6 +36,7 @@ import static org.hamcrest.CoreMatchers.not;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 
 /**
@@ -110,4 +117,39 @@ public class RelMdUtilTest {
     });
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6749">[CALCITE-6749]
+   * RelMdUtil#setAggChildKeys may return an incorrect result</a>. */
+  @Test void testSetAggChildKeys() {
+    Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
+      RelNode rel = sql("select d.deptno, count(distinct e.job)\n"
+          + "from sales.emp e\n"
+          + "right outer join sales.dept d on e.deptno = d.deptno\n"
+          + "group by d.deptno")
+          .withRelTransform(relNode -> {
+            final HepProgramBuilder builder = HepProgram.builder();
+            builder.addRuleInstance(CoreRules.AGGREGATE_PROJECT_MERGE);
+            final HepPlanner prePlanner = new HepPlanner(builder.build());
+            prePlanner.setRoot(relNode);
+            return prePlanner.findBestExp();
+          }).toRel();
+      final Aggregate agg = (Aggregate) rel;
+      // We should get an Aggregate(group=[{9}], EXPR$1=[COUNT(DISTINCT $2)])
+      assertEquals(1, agg.getGroupCount());
+      assertEquals(9, agg.getGroupSet().asList().get(0));
+      assertEquals(1, agg.getAggCallList().size());
+      assertEquals(1, agg.getAggCallList().get(0).getArgList().size());
+      assertEquals(2, agg.getAggCallList().get(0).getArgList().get(0));
+      // The childKey corresponding to 0 (group key) must be 9
+      final ImmutableBitSet.Builder builder1 = ImmutableBitSet.builder();
+      RelMdUtil.setAggChildKeys(ImmutableBitSet.of(0), agg, builder1);
+      assertEquals(ImmutableBitSet.of(9), builder1.build());
+      // The childKey corresponding to 1 (count aggCall) must be 2
+      final ImmutableBitSet.Builder builder2 = ImmutableBitSet.builder();
+      RelMdUtil.setAggChildKeys(ImmutableBitSet.of(1), agg, builder2);
+      assertEquals(ImmutableBitSet.of(2), builder2.build());
+      return null;
+    });
+  }
+
 }
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index dfd56acd35..7322b84da2 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2862,9 +2862,9 @@ select MGR, count(distinct DEPTNO, JOB), MIN(SAL), MAX(SAL) from "scott".emp gro
 
 !ok
 
-EnumerableAggregate(group=[{0}], EXPR$1=[COUNT($1, $2) FILTER $5], EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
-  EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)], expr#8=[3], expr#9=[=($t5, $t8)], MGR=[$t1], DEPTNO=[$t2], JOB=[$t0], EXPR$2=[$t3], EXPR$3=[$t4], $g_0=[$t7], $g_3=[$t9])
-    EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]], EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($3, $7, $2)])
+EnumerableAggregate(group=[{1}], EXPR$1=[COUNT($2, $0) FILTER $5], EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
+  EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)], expr#8=[5], expr#9=[=($t5, $t8)], proj#0..4=[{exprs}], $g_0=[$t7], $g_5=[$t9])
+    EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]], EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($2, $3, $7)])
       EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 35ce9cbb55..6a7927c15f 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -427,35 +427,38 @@ where e.job not in (
 
 !ok
 EnumerableCalc(expr#0..9=[{inputs}], expr#10=[0], expr#11=[=($t5, $t10)], expr#12=[IS NULL($t1)], expr#13=[IS NOT NULL($t9)], expr#14=[<($t6, $t5)], expr#15=[OR($t12, $t13, $t14)], expr#16=[IS NOT TRUE($t15)], expr#17=[OR($t11, $t16)], EMPNO=[$t0], $condition=[$t17])
-  EnumerableHashJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
-    EnumerableHashJoin(condition=[=($2, $4)], joinType=[left])
-      EnumerableCalc(expr#0..3=[{inputs}], EMPNO=[$t1], JOB=[$t2], DEPTNO=[$t3], DEPTNO0=[$t0])
-        EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner])
+  EnumerableMergeJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
+    EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], dir1=[ASC])
+      EnumerableMergeJoin(condition=[=($2, $4)], joinType=[left])
+        EnumerableMergeJoin(condition=[=($2, $3)], joinType=[inner])
+          EnumerableSort(sort0=[$2], dir0=[ASC])
+            EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
+              EnumerableTableScan(table=[[scott, EMP]])
           EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
             EnumerableTableScan(table=[[scott, DEPT]])
-          EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
-            EnumerableTableScan(table=[[scott, EMP]])
-      EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
-        EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
-          EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
-            EnumerableTableScan(table=[[scott, EMP]])
-          EnumerableAggregate(group=[{1}])
-            EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
-              EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+        EnumerableSort(sort0=[$0], dir0=[ASC])
+          EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
+            EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
+              EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
                 EnumerableTableScan(table=[[scott, EMP]])
-              EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
-                EnumerableTableScan(table=[[scott, DEPT]])
-    EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)], proj#0..2=[{exprs}], $condition=[$t3])
-      EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
-        EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
-          EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
-            EnumerableTableScan(table=[[scott, EMP]])
-          EnumerableAggregate(group=[{1}])
-            EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
-              EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
-                EnumerableTableScan(table=[[scott, EMP]])
-              EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
-                EnumerableTableScan(table=[[scott, DEPT]])
+              EnumerableAggregate(group=[{1}])
+                EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
+                  EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+                    EnumerableTableScan(table=[[scott, EMP]])
+                  EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+                    EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])
+      EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)], proj#0..2=[{exprs}], $condition=[$t3])
+        EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
+          EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
+            EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], DEPTNO=[$t7])
+              EnumerableTableScan(table=[[scott, EMP]])
+            EnumerableAggregate(group=[{1}])
+              EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
+                EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+                  EnumerableTableScan(table=[[scott, EMP]])
+                EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+                  EnumerableTableScan(table=[[scott, DEPT]])
 !plan
 
 # Condition that returns a NULL key.

Reply via email to