This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 29c413a3c6 [CALCITE-6468] RelDecorrelator throws AssertionError if correlated variable is used as Aggregate group key
29c413a3c6 is described below

commit 29c413a3c6933d1c871047b312b2b81235f3d1c6
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Jul 12 18:25:45 2024 +0100

    [CALCITE-6468] RelDecorrelator throws AssertionError if correlated variable is used as Aggregate group key
---
 .../apache/calcite/sql2rel/RelDecorrelator.java    | 16 ++--
 .../calcite/sql2rel/RelDecorrelatorTest.java       | 87 ++++++++++++++++++++++
 core/src/test/resources/sql/sub-query.iq           | 14 ++++
 3 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index f084d279b7..79fde87ee2 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -560,11 +560,17 @@ public class RelDecorrelator implements ReflectiveVisitor {
       // Now add the corVars from the input, starting from
       // position oldGroupKeyCount.
       for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
-        RexInputRef.add2(projects, entry.getValue(), newInputOutput);
-
-        corDefOutputs.put(entry.getKey(), newPos);
-        mapNewInputToProjOutputs.put(entry.getValue(), newPos);
-        newPos++;
+        // Verify if the CorDef position was already added to the mapNewInputToProjOutputs
+        // during the previous group key processing
+        final Integer pos = mapNewInputToProjOutputs.get(entry.getValue());
+        if (pos == null) {
+          RexInputRef.add2(projects, entry.getValue(), newInputOutput);
+          corDefOutputs.put(entry.getKey(), newPos);
+          mapNewInputToProjOutputs.put(entry.getValue(), newPos);
+          newPos++;
+        } else {
+          corDefOutputs.put(entry.getKey(), pos);
+        }
       }
     }
 
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
index e471ff8f74..dbbf69d872 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
@@ -16,22 +16,35 @@
  */
 package org.apache.calcite.sql2rel;
 
+import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelTraitDef;
+import org.apache.calcite.plan.hep.HepProgram;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.rules.CoreRules;
 import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.tools.FrameworkConfig;
 import org.apache.calcite.tools.Frameworks;
+import org.apache.calcite.tools.Planner;
+import org.apache.calcite.tools.Program;
+import org.apache.calcite.tools.Programs;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.Holder;
+import org.apache.calcite.util.TestUtil;
+
+import com.google.common.collect.ImmutableList;
 
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.junit.jupiter.api.Test;
 
+import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 
 import static org.apache.calcite.test.Matchers.hasTree;
 
@@ -86,4 +99,78 @@ public class RelDecorrelatorTest {
         + "      LogicalTableScan(table=[[scott, DEPT]])\n";
     assertThat(after, hasTree(planAfter));
   }
+
+  /**
+   * Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6468">[CALCITE-6468] RelDecorrelator
+   * throws AssertionError if correlated variable is used as Aggregate group key</a>.
+   */
+  @Test void testCorrVarOnAggregateKey() {
+    final FrameworkConfig frameworkConfig = config().build();
+    final RelBuilder builder = RelBuilder.create(frameworkConfig);
+    final RelOptCluster cluster = builder.getCluster();
+    final Planner planner = Frameworks.getPlanner(frameworkConfig);
+    final String sql = "WITH agg_sal AS"
+        + " (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)\n"
+        + " SELECT 1 FROM agg_sal s1"
+        + " WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE s1.deptno = s2.deptno)";
+    final RelNode originalRel;
+    try {
+      final SqlNode parse = planner.parse(sql);
+      final SqlNode validate = planner.validate(parse);
+      originalRel = planner.rel(validate).rel;
+    } catch (Exception e) {
+      throw TestUtil.rethrow(e);
+    }
+
+    final HepProgram hepProgram = HepProgram.builder()
+        .addRuleCollection(
+            ImmutableList.of(
+                // SubQuery program rules
+                CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
+                CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
+                CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
+                // plus FilterAggregateTransposeRule
+                CoreRules.FILTER_AGGREGATE_TRANSPOSE))
+        .build();
+    final Program program =
+        Programs.of(hepProgram, true, Objects.requireNonNull(cluster.getMetadataProvider()));
+    final RelNode before =
+        program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
+            Collections.emptyList(), Collections.emptyList());
+    final String planBefore = ""
+        + "LogicalProject(EXPR$0=[1])\n"
+        + "  LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+        + "    LogicalFilter(condition=[>($1, $2)])\n"
+        + "      LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+        + "        LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+        + "          LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "            LogicalTableScan(table=[[scott, EMP]])\n"
+        + "        LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n"
+        + "          LogicalProject(TOTAL=[$1])\n"
+        + "            LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+        + "              LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+        + "                LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "                  LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(before, hasTree(planBefore));
+
+    // Check decorrelation does not fail here
+    final RelNode after = RelDecorrelator.decorrelateQuery(before, builder);
+
+    // Verify plan
+    final String planAfter = ""
+        + "LogicalProject(EXPR$0=[1])\n"
+        + "  LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], joinType=[inner])\n"
+        + "    LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+        + "      LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "        LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+        + "      LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+        + "        LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+        + "          LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+        + "            LogicalFilter(condition=[IS NOT NULL($0)])\n"
+        + "              LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "                LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(after, hasTree(planAfter));
+  }
 }
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 69b05cad8d..460618045c 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -3800,4 +3800,18 @@ SELECT array(SELECT empno FROM emp WHERE empno > 7800 ORDER BY empno DESC LIMIT
 
 !ok
 
+# [CALCITE-6468] RelDecorrelator throws AssertionError if correlated variable
+# is used as Aggregate group key
+WITH agg_sal AS
+ (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)
+SELECT 1 FROM agg_sal s1
+WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE s1.deptno = s2.deptno);
++--------+
+| EXPR$0 |
++--------+
++--------+
+(0 rows)
+
+!ok
+
 # End sub-query.iq

Reply via email to