This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f148c7256 [CALCITE-7266] Optimize the "well-known count bug" correction
0f148c7256 is described below

commit 0f148c7256e2e5aec0dfc790af40b9f4bbcc5244
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Tue Nov 4 10:02:56 2025 +0000

    [CALCITE-7266] Optimize the "well-known count bug" correction
---
 .../apache/calcite/sql2rel/RelDecorrelator.java    | 126 ++++++++++++---------
 .../calcite/sql2rel/RelDecorrelatorTest.java       |  98 ++++++++++++----
 .../apache/calcite/test/SqlToRelConverterTest.xml  |  26 ++---
 core/src/test/resources/sql/sub-query.iq           |  16 +++
 4 files changed, 171 insertions(+), 95 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 42e3a84f96..e9d7f498f9 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -221,14 +221,21 @@ public static RelNode decorrelateQuery(RelNode rootRel,
     return decorrelateQuery(rootRel, relBuilder, null);
   }
 
+  public static RelNode decorrelateQuery(RelNode rootRel,
+      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules) {
+    return decorrelateQuery(rootRel, relBuilder, decorrelationRules, null);
+  }
+
   /**
    * Decorrelates a query specifying a set of rules to be used in the
    * "remove correlation via rules" pre-processing.
    *
    * @param rootRel           Root node of the query
    * @param relBuilder        Builder for relational expressions
-   * @param decorrelationRules  Rules to be used in the decorrelation, if <code>null</code>
-   *                            a default rule set will be used
+   * @param decorrelationRules  Rules to attempt some initial rule-based-decorrelation conversions,
+   *                            if <code>null</code> a default rule set will be used
+   * @param preDecorrelateRules Pre-process rules to be used before the main decorrelation
+   *                            procedure, if <code>null</code> a default rule set will be used
    *
    * @return Equivalent query with all
    * {@link org.apache.calcite.rel.core.Correlate} instances removed
@@ -236,7 +243,8 @@ public static RelNode decorrelateQuery(RelNode rootRel,
    * @see #removeCorrelationViaRule(RelNode, RuleSet)
    */
   public static RelNode decorrelateQuery(RelNode rootRel,
-      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules) {
+      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules,
+      @Nullable RuleSet preDecorrelateRules) {
     final CorelMap corelMap = new CorelMapBuilder().build(rootRel);
     if (!corelMap.hasCorrelation()) {
       return rootRel;
@@ -258,7 +266,7 @@ public static RelNode decorrelateQuery(RelNode rootRel,
     }
 
     if (!decorrelator.cm.mapCorToCorRel.isEmpty()) {
-      newRootRel = decorrelator.decorrelate(newRootRel);
+      newRootRel = decorrelator.decorrelate(newRootRel, preDecorrelateRules);
     }
     Litmus.THROW.check(
         rootRel.getRowType().equalsSansFieldNames(newRootRel.getRowType()),
@@ -282,49 +290,56 @@ protected RelBuilderFactory relBuilderFactory() {
   }
 
   protected RelNode decorrelate(RelNode root) {
-    // first adjust count() expression if any
-    final RelBuilderFactory f = relBuilderFactory();
-    HepProgram program = HepProgram.builder()
-        .addRuleInstance(
-            AdjustProjectForCountAggregateRule.DEFAULT_WITHOUT_FAVLOR
-                .withRelBuilderFactory(f).toRule())
-        .addRuleInstance(
-            AdjustProjectForCountAggregateRule.DEFAULT_WITH_FAVLOR
-                .withRelBuilderFactory(f).toRule())
-        .addRuleInstance(
-            FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT
-                .withRelBuilderFactory(f)
-                .withOperandSupplier(b0 ->
-                    b0.operand(Filter.class).oneInput(b1 ->
-                        b1.operand(Join.class).anyInputs()))
-                .withDescription("FilterJoinRule:filter")
-                
.as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
-                .withSmart(true)
-                .withPredicate((join, joinType, exp) -> true)
-                
.as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
-                .toRule())
-        .addRuleInstance(
-            CoreRules.FILTER_PROJECT_TRANSPOSE.config
-                .withRelBuilderFactory(f)
-                .as(FilterProjectTransposeRule.Config.class)
-                .withOperandFor(Filter.class, filter ->
-                        !RexUtil.containsCorrelation(filter.getCondition()),
-                    Project.class, project -> true)
-                .withCopyFilter(true)
-                .withCopyProject(true)
-                .toRule())
-        .addRuleInstance(FilterCorrelateRule.Config.DEFAULT
-            .withRelBuilderFactory(f)
-            .toRule())
-        .addRuleInstance(FilterFlattenCorrelatedConditionRule.Config.DEFAULT
-            .withRelBuilderFactory(f)
-            .toRule())
-        .build();
+    return decorrelate(root, null);
+  }
 
-    HepPlanner planner = createPlanner(program);
+  protected RelNode decorrelate(RelNode root, @Nullable RuleSet 
preDecorrelateRules) {
+    final RelBuilderFactory f = relBuilderFactory();
+    final HepProgram program;
+    if (preDecorrelateRules != null) {
+      program = ruleSetToHepProgram(preDecorrelateRules);
+    } else {
+      // Use a default set of pre-decorrelate rules:
+      // adjust count() expression if any, and do some filter-related transformations
+      program = HepProgram.builder()
+          .addRuleInstance(
+              AdjustProjectForCountAggregateRule.DEFAULT_WITHOUT_FAVLOR
+                  .withRelBuilderFactory(f).toRule())
+          .addRuleInstance(
+              AdjustProjectForCountAggregateRule.DEFAULT_WITH_FAVLOR
+                  .withRelBuilderFactory(f).toRule())
+          .addRuleInstance(
+              
FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT
+                  .withRelBuilderFactory(f)
+                  .withOperandSupplier(b0 ->
+                      b0.operand(Filter.class).oneInput(b1 ->
+                          b1.operand(Join.class).anyInputs()))
+                  .withDescription("FilterJoinRule:filter")
+                  
.as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
+                  .withSmart(true)
+                  .withPredicate((join, joinType, exp) -> true)
+                  
.as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
+                  .toRule())
+          .addRuleInstance(
+              CoreRules.FILTER_PROJECT_TRANSPOSE.config
+                  .withRelBuilderFactory(f)
+                  .as(FilterProjectTransposeRule.Config.class)
+                  .withOperandFor(Filter.class, filter ->
+                          !RexUtil.containsCorrelation(filter.getCondition()),
+                      Project.class, project -> true)
+                  .withCopyFilter(true)
+                  .withCopyProject(true)
+                  .toRule())
+          .addRuleInstance(FilterCorrelateRule.Config.DEFAULT
+              .withRelBuilderFactory(f)
+              .toRule())
+          .addRuleInstance(FilterFlattenCorrelatedConditionRule.Config.DEFAULT
+              .withRelBuilderFactory(f)
+              .toRule())
+          .build();
+    }
 
-    planner.setRoot(root);
-    root = planner.findBestExp();
+    root = applyHepProgram(root, program);
     if (SQL2REL_LOGGER.isDebugEnabled()) {
       SQL2REL_LOGGER.debug("Plan before extracting correlated computations:\n"
           + RelOptUtil.toString(root));
@@ -374,11 +389,7 @@ protected RelNode decorrelate(RelNode root) {
         builder.addRuleCollection(getPostDecorrelateRules());
       }
       final HepProgram program2 = builder.build();
-
-      final HepPlanner planner2 = createPlanner(program2);
-      final RelNode newRoot = result;
-      planner2.setRoot(newRoot);
-      return planner2.findBestExp();
+      return applyHepProgram(result, program2);
     }
 
     return root;
@@ -434,7 +445,7 @@ public RelNode removeCorrelationViaRule(RelNode root) {
         .addRuleInstance(
             
RemoveCorrelationForScalarAggregateRule.DEFAULT.withRelBuilderFactory(f).toRule())
         .build();
-    return removeCorrelationViaRule(root, program);
+    return applyHepProgram(root, program);
   }
 
   /**
@@ -443,6 +454,10 @@ public RelNode removeCorrelationViaRule(RelNode root) {
    * {@link org.apache.calcite.rel.core.Correlate}s might be removable in such way).
    */
   public RelNode removeCorrelationViaRule(RelNode root, RuleSet ruleSet) {
+    return applyHepProgram(root, ruleSetToHepProgram(ruleSet));
+  }
+
+  private HepProgram ruleSetToHepProgram(RuleSet ruleSet) {
     final RelBuilderFactory f = relBuilderFactory();
     final HepProgramBuilder builder = HepProgram.builder();
     for (RelOptRule rule : ruleSet) {
@@ -451,11 +466,10 @@ public RelNode removeCorrelationViaRule(RelNode root, RuleSet ruleSet) {
       }
       builder.addRuleInstance(rule);
     }
-    final HepProgram program = builder.build();
-    return removeCorrelationViaRule(root, program);
+    return builder.build();
   }
 
-  private RelNode removeCorrelationViaRule(RelNode root, HepProgram program) {
+  private RelNode applyHepProgram(RelNode root, HepProgram program) {
     HepPlanner planner = createPlanner(program);
     planner.setRoot(root);
     return planner.findBestExp();
@@ -1637,7 +1651,9 @@ private static boolean isWidening(RelDataType type, RelDataType type1) {
     }
 
     frameStack.push(Pair.of(rel.getCorrelationId(), leftFrame));
-    final Frame rightFrame = getInvoke(oldRight, true, rel, 
parentPropagatesNullValues);
+    final Frame rightFrame =
+        getInvoke(oldRight, true, rel,
+            rel.getJoinType() == JoinRelType.LEFT || 
parentPropagatesNullValues);
     frameStack.pop();
 
     if (rightFrame == null || rightFrame.corDefOutputs.isEmpty()) {
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
index 40b6212778..c079af7740 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
@@ -208,6 +208,70 @@ public static Frameworks.ConfigBuilder config() {
     assertThat(after, hasTree(planAfter));
   }
 
+  @Test void testDecorrelateCountBug() {
+    final FrameworkConfig frameworkConfig = config().build();
+    final RelBuilder builder = RelBuilder.create(frameworkConfig);
+    final RelOptCluster cluster = builder.getCluster();
+    final Planner planner = Frameworks.getPlanner(frameworkConfig);
+    final String sql = "SELECT deptno, "
+        + "(SELECT CASE WHEN SUM(sal) > 10 then 'VIP' else 'Regular' END expr "
+        + " FROM emp e WHERE d.deptno = e.deptno) a "
+        + "FROM dept d";
+    final RelNode originalRel;
+    try {
+      final SqlNode parse = planner.parse(sql);
+      final SqlNode validate = planner.validate(parse);
+      originalRel = planner.rel(validate).rel;
+    } catch (Exception e) {
+      throw TestUtil.rethrow(e);
+    }
+
+    final HepProgram hepProgram = HepProgram.builder()
+        .addRuleCollection(
+            ImmutableList.of(
+                // SubQuery program rules
+                CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
+                CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
+                CoreRules.JOIN_SUB_QUERY_TO_CORRELATE))
+        .build();
+    final Program program =
+        Programs.of(hepProgram, true,
+            requireNonNull(cluster.getMetadataProvider()));
+    final RelNode before =
+        program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
+            Collections.emptyList(), Collections.emptyList());
+    final String planBefore = ""
+        + "LogicalProject(DEPTNO=[$0], A=[$3])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], 
requiredColumns=[{0}])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "    LogicalProject(EXPR=[CASE(>($0, 10.00), 'VIP    ', 
'Regular')])\n"
+        + "      LogicalAggregate(group=[{}], agg#0=[SUM($0)])\n"
+        + "        LogicalProject(SAL=[$5])\n"
+        + "          LogicalFilter(condition=[=($cor0.DEPTNO, $7)])\n"
+        + "            LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(before, hasTree(planBefore));
+
+    // Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
+    final RelNode after =
+        RelDecorrelator.decorrelateQuery(before, builder, 
RuleSets.ofList(Collections.emptyList()),
+            RuleSets.ofList(Collections.emptyList()));
+
+    // Verify plan
+    final String planAfter = ""
+        + "LogicalProject(DEPTNO=[$0], A=[$3])\n"
+        + "  LogicalJoin(condition=[=($0, $4)], joinType=[left])\n"
+        + "    LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "    LogicalProject(EXPR=[CASE(>($2, 10.00), 'VIP    ', 'Regular')], 
DEPTNO=[$0])\n"
+        + "      LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])\n"
+        + "        LogicalProject(DEPTNO=[$0])\n"
+        + "          LogicalTableScan(table=[[scott, DEPT]])\n"
+        + "        LogicalAggregate(group=[{0}], agg#0=[SUM($1)])\n"
+        + "          LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "            LogicalFilter(condition=[IS NOT NULL($7)])\n"
+        + "              LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(after, hasTree(planAfter));
+  }
+
   /**
    * Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-6468">[CALCITE-6468] RelDecorrelator
@@ -269,23 +333,17 @@ public static Frameworks.ConfigBuilder config() {
     // Verify plan
     final String planAfter = ""
         + "LogicalProject(EXPR$0=[1])\n"
-        + "  LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), >($1, 
$3))], joinType=[inner])\n"
+        + "  LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], 
joinType=[inner])\n"
         + "    LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
         + "      LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
         + "        LogicalTableScan(table=[[scott, EMP]])\n"
-        + "    LogicalProject(DEPTNO=[$0], EXPR$0=[$2])\n"
-        + "      LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])\n"
-        + "        LogicalProject(DEPTNO=[$0])\n"
-        + "          LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
-        + "            LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
-        + "              LogicalTableScan(table=[[scott, EMP]])\n"
-        + "        LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
-        + "          LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
-        + "            LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
-        + "              LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
-        + "                LogicalFilter(condition=[IS NOT NULL($0)])\n"
-        + "                  LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
-        + "                    LogicalTableScan(table=[[scott, EMP]])\n";
+        + "    LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+        + "      LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+        + "        LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+        + "          LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+        + "            LogicalFilter(condition=[IS NOT NULL($0)])\n"
+        + "              LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+        + "                LogicalTableScan(table=[[scott, EMP]])\n";
     assertThat(after, hasTree(planAfter));
   }
 
@@ -366,15 +424,11 @@ public static Frameworks.ConfigBuilder config() {
         RelDecorrelator.decorrelateQuery(original, builder, noRules);
     final String planDecorrelatedNoRules = ""
         + "LogicalProject(EXPR$0=[ROW($9, $1)])\n"
-        + "  LogicalJoin(condition=[IS NOT DISTINCT FROM($7, $8)], 
joinType=[left])\n"
+        + "  LogicalJoin(condition=[=($7, $8)], joinType=[left])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n"
-        + "    LogicalProject(DEPTNO1=[$0], $f1=[$2])\n"
-        + "      LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])\n"
-        + "        LogicalAggregate(group=[{7}])\n"
-        + "          LogicalTableScan(table=[[scott, EMP]])\n"
-        + "        LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
-        + "          LogicalProject(DEPTNO1=[$0], DEPTNO=[$0])\n"
-        + "            LogicalTableScan(table=[[scott, DEPT]])\n";
+        + "    LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
+        + "      LogicalProject(DEPTNO1=[$0], DEPTNO=[$0])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n";
     assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
   }
 }
diff --git a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
index b640d7ddde..45351cfb6f 100644
--- a/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/SqlToRelConverterTest.xml
@@ -5523,16 +5523,11 @@ LogicalProject(D2=[$0], D3=[$1])
               LogicalJoin(condition=[=($0, $1)], joinType=[left])
                 LogicalProject(D1=[+($0, 1)])
                   LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-                LogicalProject(D4=[$0], D6=[$2], $f2=[$3])
-                  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])
-                    LogicalAggregate(group=[{0}])
-                      LogicalProject(D1=[+($0, 1)])
+                LogicalAggregate(group=[{0, 1}], agg#0=[MIN($2)])
+                  LogicalProject(D4=[$0], D6=[$2], $f0=[true])
+                    LogicalFilter(condition=[=($1, $0)])
+                      LogicalProject(D4=[+($0, 4)], D5=[+($0, 5)], D6=[+($0, 
6)])
                         LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-                    LogicalAggregate(group=[{0, 1}], agg#0=[MIN($2)])
-                      LogicalProject(D4=[$0], D6=[$2], $f0=[true])
-                        LogicalFilter(condition=[=($1, $0)])
-                          LogicalProject(D4=[+($0, 4)], D5=[+($0, 5)], 
D6=[+($0, 6)])
-                            LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
     </Resource>
   </TestCase>
@@ -5558,16 +5553,11 @@ LogicalProject(D2=[$0], D3=[$1])
               LogicalJoin(condition=[=($0, $1)], joinType=[left])
                 LogicalProject(D1=[+($0, 1)])
                   LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-                LogicalProject(D4=[$0], D6=[$2], $f2=[$3])
-                  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], 
joinType=[left])
-                    LogicalAggregate(group=[{0}])
-                      LogicalProject(D1=[+($0, 1)])
+                LogicalAggregate(group=[{0, 1}], agg#0=[MIN($2)])
+                  LogicalProject(D4=[$0], D6=[$2], $f0=[true])
+                    LogicalFilter(condition=[=($1, $0)])
+                      LogicalProject(D4=[+($0, 4)], D5=[+($0, 5)], D6=[+($0, 
6)])
                         LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-                    LogicalAggregate(group=[{0, 1}], agg#0=[MIN($2)])
-                      LogicalProject(D4=[$0], D6=[$2], $f0=[true])
-                        LogicalFilter(condition=[=($1, $0)])
-                          LogicalProject(D4=[+($0, 4)], D5=[+($0, 5)], 
D6=[+($0, 6)])
-                            LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
     </Resource>
   </TestCase>
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 4a9a30fa5a..015a94f69e 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -4396,6 +4396,22 @@ WHERE 'Regular' IN (
 
 !ok
 
+SELECT deptno, (SELECT CASE WHEN SUM(sal) > 10 then 'VIP' else 'Regular' END 
expr
+                   FROM emp e
+                   WHERE d.deptno = e.deptno) a
+FROM dept d;
++--------+---------+
+| DEPTNO | A       |
++--------+---------+
+|     10 | VIP     |
+|     20 | VIP     |
+|     30 | VIP     |
+|     40 | Regular |
++--------+---------+
+(4 rows)
+
+!ok
+
 # Test case for [CALCITE-5789]
 select deptno from dept d1 where exists (
  select 1 from dept d2 where d2.deptno = d1.deptno and exists (

Reply via email to