mihaibudiu commented on code in PR #3854:
URL: https://github.com/apache/calcite/pull/3854#discussion_r1676240133


##########
core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java:
##########
@@ -560,11 +560,15 @@ protected RexNode removeCorrelationExpr(
       // Now add the corVars from the input, starting from
       // position oldGroupKeyCount.
       for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
-        RexInputRef.add2(projects, entry.getValue(), newInputOutput);
-
-        corDefOutputs.put(entry.getKey(), newPos);
-        mapNewInputToProjOutputs.put(entry.getValue(), newPos);
-        newPos++;
+        final Integer pos = mapNewInputToProjOutputs.get(entry.getValue());

Review Comment:
   I think a comment would help maintainers understand why this can happen, i.e. why `mapNewInputToProjOutputs` may already contain an entry for this input position.



##########
core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java:
##########
@@ -86,4 +99,75 @@ public static Frameworks.ConfigBuilder config() {
         + "      LogicalTableScan(table=[[scott, DEPT]])\n";
     assertThat(after, hasTree(planAfter));
   }
+
+  /** Test case for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-2744";>[CALCITE-2744]
+   * RelDecorrelator use wrong output map for LogicalAggregate 
decorrelate</a>. */
+  @Test void testCorrVarOnAggregateKey() {
+    final FrameworkConfig frameworkConfig = config().build();
+    final RelBuilder builder = RelBuilder.create(frameworkConfig);
+    final RelOptCluster cluster = builder.getCluster();
+    final Planner planner = Frameworks.getPlanner(frameworkConfig);
+    final String sql = "WITH agg_sal AS"
+        + " (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)\n"
+        + " SELECT 1 FROM agg_sal s1"
+        + " WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE 
s1.deptno = s2.deptno)";
+    try {
+      final SqlNode parse = planner.parse(sql);
+      final SqlNode validate = planner.validate(parse);
+      final RelNode originalRel = planner.rel(validate).rel;
+      final HepProgram hepProgram = HepProgram.builder()
+          .addRuleCollection(
+              ImmutableList.of(
+                  // SubQuery program rules
+                  CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
+                  CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
+                  CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
+                  // plus FilterAggregateTransposeRule
+                  CoreRules.FILTER_AGGREGATE_TRANSPOSE))
+          .build();
+      final Program program =
+          Programs.of(hepProgram, true, 
Objects.requireNonNull(cluster.getMetadataProvider()));
+      final RelNode before =
+          program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
+              Collections.emptyList(), Collections.emptyList());
+      final String planBefore = ""
+          + "LogicalProject(EXPR$0=[1])\n"
+          + "  LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+          + "    LogicalFilter(condition=[>($1, $2)])\n"
+          + "      LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{0}])\n"
+          + "        LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+          + "          LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+          + "            LogicalTableScan(table=[[scott, EMP]])\n"
+          + "        LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n"
+          + "          LogicalProject(TOTAL=[$1])\n"
+          + "            LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+          + "              LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+          + "                LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+          + "                  LogicalTableScan(table=[[scott, EMP]])\n";
+      assertThat(before, hasTree(planBefore));
+
+      // Check decorrelation does not fail here
+      final RelNode after = RelDecorrelator.decorrelateQuery(before, builder);
+
+      // Verify plan
+      final String planAfter = ""
+          + "LogicalProject(EXPR$0=[1])\n"
+          + "  LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], 
joinType=[inner])\n"
+          + "    LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+          + "      LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+          + "        LogicalTableScan(table=[[scott, EMP]])\n"
+          + "    LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+          + "      LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+          + "        LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+          + "          LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+          + "            LogicalFilter(condition=[IS NOT NULL($0)])\n"
+          + "              LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+          + "                LogicalTableScan(table=[[scott, EMP]])\n";
+      assertThat(after, hasTree(planAfter));
+    } catch (Exception e) {
+      throw TestUtil.rethrow(e);

Review Comment:
   Why do you have to catch and rethrow? The test method could simply declare `throws Exception` instead.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to