This is an automated email from the ASF dual-hosted git repository.
rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new c686e2e412 [CALCITE-6749] RelMdUtil#setAggChildKeys may return an incorrect result
c686e2e412 is described below
commit c686e2e4124056ce23d6f1b97469e715d62d596e
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Dec 27 14:55:33 2024 +0000
[CALCITE-6749] RelMdUtil#setAggChildKeys may return an incorrect result
---
.../org/apache/calcite/rel/metadata/RelMdUtil.java | 10 ++--
.../apache/calcite/rel/metadata/RelMdUtilTest.java | 42 +++++++++++++++++
core/src/test/resources/sql/agg.iq | 6 +--
core/src/test/resources/sql/sub-query.iq | 55 ++++++++++++----------
4 files changed, 79 insertions(+), 34 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
index 29236bf8bc..6cb45aded5 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java
@@ -568,15 +568,15 @@ public class RelMdUtil {
ImmutableBitSet groupKey,
Aggregate aggRel,
ImmutableBitSet.Builder childKey) {
- List<AggregateCall> aggCalls = aggRel.getAggCallList();
+ final List<AggregateCall> aggCallList = aggRel.getAggCallList();
+ final List<Integer> groupList = aggRel.getGroupSet().asList();
for (int bit : groupKey) {
if (bit < aggRel.getGroupCount()) {
// group by column
- childKey.set(bit);
+ childKey.set(groupList.get(bit));
} else {
- // aggregate column -- set a bit for each argument being
- // aggregated
- AggregateCall agg = aggCalls.get(bit - aggRel.getGroupCount());
+ // aggregate column -- set a bit for each argument being aggregated
+ final AggregateCall agg = aggCallList.get(bit -
aggRel.getGroupCount());
for (Integer arg : agg.getArgList()) {
childKey.set(arg);
}
diff --git a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
index 1e14216bdc..f5df2d3bf4 100644
--- a/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/metadata/RelMdUtilTest.java
@@ -16,11 +16,17 @@
*/
package org.apache.calcite.rel.metadata;
+import org.apache.calcite.plan.hep.HepPlanner;
+import org.apache.calcite.plan.hep.HepProgram;
+import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.test.RelMetadataFixture;
import org.apache.calcite.tools.Frameworks;
+import org.apache.calcite.util.ImmutableBitSet;
import org.junit.jupiter.api.Test;
@@ -30,6 +36,7 @@ import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
/**
@@ -110,4 +117,39 @@ public class RelMdUtilTest {
});
}
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6749">[CALCITE-6749]
+ * RelMdUtil#setAggChildKeys may return an incorrect result</a>.
+ *
+ * <p>After AGGREGATE_PROJECT_MERGE, the Aggregate's group key at output
+ * ordinal 0 reads input column 9. A correct mapping must therefore report
+ * child column 9 for key 0, not the ordinal 0 itself. */
+ @Test void testSetAggChildKeys() {
+ Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
+ RelNode rel = sql("select d.deptno, count(distinct e.job)\n"
+ + "from sales.emp e\n"
+ + "right outer join sales.dept d on e.deptno = d.deptno\n"
+ + "group by d.deptno")
+ .withRelTransform(relNode -> {
+ // Merge the Project below the Aggregate into it, so that the group
+ // set refers to join-output columns instead of Project ordinals.
+ final HepProgramBuilder builder = HepProgram.builder();
+ builder.addRuleInstance(CoreRules.AGGREGATE_PROJECT_MERGE);
+ final HepPlanner prePlanner = new HepPlanner(builder.build());
+ prePlanner.setRoot(relNode);
+ return prePlanner.findBestExp();
+ }).toRel();
+ final Aggregate agg = (Aggregate) rel;
+ // We should get an Aggregate(group=[{9}], EXPR$1=[COUNT(DISTINCT $2)])
+ assertEquals(1, agg.getGroupCount());
+ assertEquals(9, agg.getGroupSet().asList().get(0));
+ assertEquals(1, agg.getAggCallList().size());
+ assertEquals(1, agg.getAggCallList().get(0).getArgList().size());
+ assertEquals(2, agg.getAggCallList().get(0).getArgList().get(0));
+ // The childKey corresponding to 0 (group key) must be 9
+ // (before the fix, the ordinal 0 itself was set instead)
+ final ImmutableBitSet.Builder builder1 = ImmutableBitSet.builder();
+ RelMdUtil.setAggChildKeys(ImmutableBitSet.of(0), agg, builder1);
+ assertEquals(ImmutableBitSet.of(9), builder1.build());
+ // The childKey corresponding to 1 (count aggCall) must be 2,
+ // i.e. the input column of the aggregate call's single argument
+ final ImmutableBitSet.Builder builder2 = ImmutableBitSet.builder();
+ RelMdUtil.setAggChildKeys(ImmutableBitSet.of(1), agg, builder2);
+ assertEquals(ImmutableBitSet.of(2), builder2.build());
+ return null;
+ });
+ }
+
}
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index dfd56acd35..7322b84da2 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2862,9 +2862,9 @@ select MGR, count(distinct DEPTNO, JOB), MIN(SAL),
MAX(SAL) from "scott".emp gro
!ok
-EnumerableAggregate(group=[{0}], EXPR$1=[COUNT($1, $2) FILTER $5],
EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
- EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)],
expr#8=[3], expr#9=[=($t5, $t8)], MGR=[$t1], DEPTNO=[$t2], JOB=[$t0],
EXPR$2=[$t3], EXPR$3=[$t4], $g_0=[$t7], $g_3=[$t9])
- EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]],
EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($3, $7, $2)])
+EnumerableAggregate(group=[{1}], EXPR$1=[COUNT($2, $0) FILTER $5],
EXPR$2=[MIN($3) FILTER $6], EXPR$3=[MIN($4) FILTER $6])
+ EnumerableCalc(expr#0..5=[{inputs}], expr#6=[0], expr#7=[=($t5, $t6)],
expr#8=[5], expr#9=[=($t5, $t8)], proj#0..4=[{exprs}], $g_0=[$t7], $g_5=[$t9])
+ EnumerableAggregate(group=[{2, 3, 7}], groups=[[{2, 3, 7}, {3}]],
EXPR$2=[MIN($5)], EXPR$3=[MAX($5)], $g=[GROUPING($2, $3, $7)])
EnumerableTableScan(table=[[scott, EMP]])
!plan
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 35ce9cbb55..6a7927c15f 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -427,35 +427,38 @@ where e.job not in (
!ok
EnumerableCalc(expr#0..9=[{inputs}], expr#10=[0], expr#11=[=($t5, $t10)],
expr#12=[IS NULL($t1)], expr#13=[IS NOT NULL($t9)], expr#14=[<($t6, $t5)],
expr#15=[OR($t12, $t13, $t14)], expr#16=[IS NOT TRUE($t15)], expr#17=[OR($t11,
$t16)], EMPNO=[$t0], $condition=[$t17])
- EnumerableHashJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
- EnumerableHashJoin(condition=[=($2, $4)], joinType=[left])
- EnumerableCalc(expr#0..3=[{inputs}], EMPNO=[$t1], JOB=[$t2],
DEPTNO=[$t3], DEPTNO0=[$t0])
- EnumerableHashJoin(condition=[=($0, $3)], joinType=[inner])
+ EnumerableMergeJoin(condition=[AND(=($1, $7), =($2, $8))], joinType=[left])
+ EnumerableSort(sort0=[$1], sort1=[$2], dir0=[ASC], dir1=[ASC])
+ EnumerableMergeJoin(condition=[=($2, $4)], joinType=[left])
+ EnumerableMergeJoin(condition=[=($2, $3)], joinType=[inner])
+ EnumerableSort(sort0=[$2], dir0=[ASC])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
+ EnumerableTableScan(table=[[scott, EMP]])
EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
EnumerableTableScan(table=[[scott, DEPT]])
- EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
- EnumerableTableScan(table=[[scott, EMP]])
- EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
- EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
- EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
- EnumerableTableScan(table=[[scott, EMP]])
- EnumerableAggregate(group=[{1}])
- EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
- EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+ EnumerableSort(sort0=[$0], dir0=[ASC])
+ EnumerableAggregate(group=[{3}], c=[COUNT()], ck=[COUNT($1)])
+ EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
EnumerableTableScan(table=[[scott, EMP]])
- EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
- EnumerableTableScan(table=[[scott, DEPT]])
- EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)],
proj#0..2=[{exprs}], $condition=[$t3])
- EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
- EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
- EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
- EnumerableTableScan(table=[[scott, EMP]])
- EnumerableAggregate(group=[{1}])
- EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
- EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
- EnumerableTableScan(table=[[scott, EMP]])
- EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
- EnumerableTableScan(table=[[scott, DEPT]])
+ EnumerableAggregate(group=[{1}])
+ EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0],
DEPTNO=[$t7])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+ EnumerableTableScan(table=[[scott, DEPT]])
+ EnumerableSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])
+ EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t0)],
proj#0..2=[{exprs}], $condition=[$t3])
+ EnumerableAggregate(group=[{1, 3}], i=[LITERAL_AGG(true)])
+ EnumerableNestedLoopJoin(condition=[>($2, $3)], joinType=[inner])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2],
DEPTNO=[$t7])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableAggregate(group=[{1}])
+ EnumerableHashJoin(condition=[=($1, $2)], joinType=[semi])
+ EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+ EnumerableTableScan(table=[[scott, DEPT]])
!plan
# Condition that returns a NULL key.