This is an automated email from the ASF dual-hosted git repository.
rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 29c413a3c6 [CALCITE-6468] RelDecorrelator throws AssertionError if
correlated variable is used as Aggregate group key
29c413a3c6 is described below
commit 29c413a3c6933d1c871047b312b2b81235f3d1c6
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Jul 12 18:25:45 2024 +0100
[CALCITE-6468] RelDecorrelator throws AssertionError if correlated variable
is used as Aggregate group key
---
.../apache/calcite/sql2rel/RelDecorrelator.java | 16 ++--
.../calcite/sql2rel/RelDecorrelatorTest.java | 87 ++++++++++++++++++++++
core/src/test/resources/sql/sub-query.iq | 14 ++++
3 files changed, 112 insertions(+), 5 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index f084d279b7..79fde87ee2 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -560,11 +560,17 @@ public class RelDecorrelator implements ReflectiveVisitor {
// Now add the corVars from the input, starting from
// position oldGroupKeyCount.
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
- RexInputRef.add2(projects, entry.getValue(), newInputOutput);
-
- corDefOutputs.put(entry.getKey(), newPos);
- mapNewInputToProjOutputs.put(entry.getValue(), newPos);
- newPos++;
+ // Verify if the CorDef position was already added to the
mapNewInputToProjOutputs
+ // during the previous group key processing
+ final Integer pos = mapNewInputToProjOutputs.get(entry.getValue());
+ if (pos == null) {
+ RexInputRef.add2(projects, entry.getValue(), newInputOutput);
+ corDefOutputs.put(entry.getKey(), newPos);
+ mapNewInputToProjOutputs.put(entry.getValue(), newPos);
+ newPos++;
+ } else {
+ corDefOutputs.put(entry.getKey(), pos);
+ }
}
}
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
index e471ff8f74..dbbf69d872 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
@@ -16,22 +16,35 @@
*/
package org.apache.calcite.sql2rel;
+import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitDef;
+import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
+import org.apache.calcite.tools.Planner;
+import org.apache.calcite.tools.Program;
+import org.apache.calcite.tools.Programs;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Holder;
+import org.apache.calcite.util.TestUtil;
+
+import com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.jupiter.api.Test;
+import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import static org.apache.calcite.test.Matchers.hasTree;
@@ -86,4 +99,78 @@ public class RelDecorrelatorTest {
+ " LogicalTableScan(table=[[scott, DEPT]])\n";
assertThat(after, hasTree(planAfter));
}
+
+ /**
+ * Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6468">[CALCITE-6468]
+ * RelDecorrelator throws AssertionError if correlated variable is used as
+ * Aggregate group key</a>.
+ *
+ * <p>The scalar sub-query is correlated on {@code deptno}. After the
+ * sub-query-to-correlate rules and {@code FILTER_AGGREGATE_TRANSPOSE} have
+ * run, the correlated filter sits below the sub-query's inner Aggregate and
+ * compares {@code $cor0.DEPTNO} with that Aggregate's group key column.
+ * Decorrelating this plan must succeed, and the resulting plan must project
+ * the group-key column only once.
+ */
+ @Test void testCorrVarOnAggregateKey() {
+ final FrameworkConfig frameworkConfig = config().build();
+ final RelBuilder builder = RelBuilder.create(frameworkConfig);
+ final RelOptCluster cluster = builder.getCluster();
+ final Planner planner = Frameworks.getPlanner(frameworkConfig);
+ final String sql = "WITH agg_sal AS"
+ + " (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)\n"
+ + " SELECT 1 FROM agg_sal s1"
+ + " WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE
s1.deptno = s2.deptno)";
+ // Parse, validate and convert the SQL; sub-queries remain as RexSubQuery
+ // expressions until the HepProgram below rewrites them.
+ final RelNode originalRel;
+ try {
+ final SqlNode parse = planner.parse(sql);
+ final SqlNode validate = planner.validate(parse);
+ originalRel = planner.rel(validate).rel;
+ } catch (Exception e) {
+ throw TestUtil.rethrow(e);
+ }
+
+ final HepProgram hepProgram = HepProgram.builder()
+ .addRuleCollection(
+ ImmutableList.of(
+ // Rewrite sub-queries in Filter, Project and Join into Correlate
+ CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
+ CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
+ CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
+ // plus FilterAggregateTransposeRule, which pushes the correlated
+ // filter below the sub-query's Aggregate (see planBefore); this is
+ // what puts the correlated column on the Aggregate's group key
+ CoreRules.FILTER_AGGREGATE_TRANSPOSE))
+ .build();
+ final Program program =
+ Programs.of(hepProgram, true,
Objects.requireNonNull(cluster.getMetadataProvider()));
+ final RelNode before =
+ program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
+ Collections.emptyList(), Collections.emptyList());
+ // The filter =($cor0.DEPTNO, $0) is now the input of the inner
+ // LogicalAggregate(group=[{0}]), i.e. it filters on the group key column.
+ final String planBefore = ""
+ + "LogicalProject(EXPR$0=[1])\n"
+ + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ + " LogicalFilter(condition=[>($1, $2)])\n"
+ + " LogicalCorrelate(correlation=[$cor0], joinType=[left],
requiredColumns=[{0}])\n"
+ + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n"
+ + " LogicalProject(TOTAL=[$1])\n"
+ + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+ + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n";
+ assertThat(before, hasTree(planBefore));
+
+ // Decorrelate; before the CALCITE-6468 fix this threw AssertionError
+ final RelNode after = RelDecorrelator.decorrelateQuery(before, builder);
+
+ // Expect the Correlate to be replaced by an inner Join that matches DEPTNO
+ // (=($0, $2)) and applies the original filter (>($1, $3)); the sub-query's
+ // AVG is now grouped by DEPTNO, which is projected only once.
+ final String planAfter = ""
+ + "LogicalProject(EXPR$0=[1])\n"
+ + " LogicalJoin(condition=[AND(=($0, $2), >($1, $3))],
joinType=[inner])\n"
+ + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n"
+ + " LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+ + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ + " LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+ + " LogicalFilter(condition=[IS NOT NULL($0)])\n"
+ + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n";
+ assertThat(after, hasTree(planAfter));
+ }
}
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 69b05cad8d..460618045c 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -3800,4 +3800,18 @@ SELECT array(SELECT empno FROM emp WHERE empno > 7800 ORDER BY empno DESC LIMIT
!ok
+# [CALCITE-6468] RelDecorrelator throws AssertionError if correlated variable
+# is used as Aggregate group key
+# The sub-query is correlated on deptno, which is also the GROUP BY key of
+# agg_sal. Because agg_sal has one row per deptno, the correlated avg(total)
+# equals s1.total, so '>' is never true and the expected result is empty.
+WITH agg_sal AS
+ (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)
+SELECT 1 FROM agg_sal s1
+WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE s1.deptno =
s2.deptno);
++--------+
+| EXPR$0 |
++--------+
++--------+
+(0 rows)
+
+!ok
+
# End sub-query.iq