This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new edd6d11abb [CALCITE-6674] Make RelDecorrelator rules configurable
edd6d11abb is described below

commit edd6d11abb9dee7b41834108da96ed6aef091531
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Nov 8 12:05:40 2024 +0000

    [CALCITE-6674] Make RelDecorrelator rules configurable
---
 .../apache/calcite/plan/AbstractRelOptPlanner.java |  14 ++
 .../org/apache/calcite/plan/RelOptPlanner.java     |  10 +
 .../apache/calcite/sql2rel/RelDecorrelator.java    | 221 ++++++++++++---------
 .../calcite/sql2rel/RelDecorrelatorTest.java       |  87 ++++++++
 4 files changed, 235 insertions(+), 97 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/AbstractRelOptPlanner.java b/core/src/main/java/org/apache/calcite/plan/AbstractRelOptPlanner.java
index d2ed0a57a7..6cb7c4fed4 100644
--- a/core/src/main/java/org/apache/calcite/plan/AbstractRelOptPlanner.java
+++ b/core/src/main/java/org/apache/calcite/plan/AbstractRelOptPlanner.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.metadata.RelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rex.RexExecutor;
+import org.apache.calcite.sql2rel.RelDecorrelator;
 import org.apache.calcite.util.CancelFlag;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.Util;
@@ -86,6 +87,8 @@ public abstract class AbstractRelOptPlanner implements RelOptPlanner {
 
   private @Nullable RexExecutor executor;
 
+  private @Nullable RelDecorrelator decorrelator;
+
   //~ Constructors -----------------------------------------------------------
 
   /**
@@ -290,6 +293,17 @@ public abstract class AbstractRelOptPlanner implements RelOptPlanner {
     return executor;
   }
 
+  @Override public void setDecorrelator(@Nullable RelDecorrelator decorrelator) {
+    this.decorrelator = decorrelator;
+  }
+
+  @Override public RelDecorrelator getDecorrelator() {
+    if (decorrelator == null) {
+      throw new IllegalStateException("RelDecorrelator has not been set");
+    }
+    return decorrelator;
+  }
+
   @Override public void onCopy(RelNode rel, RelNode newRel) {
     // do nothing
   }
diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptPlanner.java b/core/src/main/java/org/apache/calcite/plan/RelOptPlanner.java
index 3f1fe01fef..5341819ce9 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptPlanner.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptPlanner.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.metadata.CachingRelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataProvider;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rex.RexExecutor;
+import org.apache.calcite.sql2rel.RelDecorrelator;
 import org.apache.calcite.util.CancelFlag;
 import org.apache.calcite.util.trace.CalciteTrace;
 
@@ -334,6 +335,15 @@ public interface RelOptPlanner {
   /** Returns the executor used to evaluate constant expressions. */
   @Nullable RexExecutor getExecutor();
 
+  /** Sets the decorrelator. */
+  void setDecorrelator(@Nullable RelDecorrelator decorrelator);
+
+  /** Returns the decorrelator used to decorrelate expressions.
+   *
+   * @throws IllegalStateException if the decorrelator has not been set
+   * */
+  RelDecorrelator getDecorrelator();
+
   /** Called when a relational expression is copied to a similar expression. */
   void onCopy(RelNode rel, RelNode newRel);
 
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index b78403a90b..6a7376f132 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -83,6 +83,7 @@ import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
+import org.apache.calcite.tools.RuleSet;
 import org.apache.calcite.util.Holder;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Litmus;
@@ -137,8 +138,6 @@ import static java.util.Objects.requireNonNull;
  *   <li>make {@link #currentRel} immutable (would require a fresh
  *      RelDecorrelator for each node being decorrelated)</li>
  *   <li>make fields of {@code CorelMap} immutable</li>
- *   <li>make sub-class rules static, and have them create their own
- *   de-correlator</li>
  * </ul>
  *
  * <p>Note: make all the members protected scope so that they can be
@@ -208,6 +207,25 @@ public class RelDecorrelator implements ReflectiveVisitor {
    */
   public static RelNode decorrelateQuery(RelNode rootRel,
       RelBuilder relBuilder) {
+    return decorrelateQuery(rootRel, relBuilder, null);
+  }
+
+  /**
+   * Decorrelates a query specifying a set of rules to be used in the
+   * "remove correlation via rules" pre-processing.
+   *
+   * @param rootRel           Root node of the query
+   * @param relBuilder        Builder for relational expressions
+   * @param decorrelationRules  Rules to be used in the decorrelation, if <code>null</code>
+   *                            a default rule set will be used
+   *
+   * @return Equivalent query with all
+   * {@link org.apache.calcite.rel.core.Correlate} instances removed
+   *
+   * @see #removeCorrelationViaRule(RelNode, RuleSet)
+   */
+  public static RelNode decorrelateQuery(RelNode rootRel,
+      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules) {
     final CorelMap corelMap = new CorelMapBuilder().build(rootRel);
     if (!corelMap.hasCorrelation()) {
       return rootRel;
@@ -218,7 +236,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
         new RelDecorrelator(corelMap,
             cluster.getPlanner().getContext(), relBuilder);
 
-    RelNode newRootRel = decorrelator.removeCorrelationViaRule(rootRel);
+    RelNode newRootRel = decorrelationRules == null
+        ? decorrelator.removeCorrelationViaRule(rootRel)
+        : decorrelator.removeCorrelationViaRule(rootRel, decorrelationRules);
 
     if (SQL2REL_LOGGER.isDebugEnabled()) {
       SQL2REL_LOGGER.debug(
@@ -252,9 +272,11 @@ public class RelDecorrelator implements ReflectiveVisitor {
     final RelBuilderFactory f = relBuilderFactory();
     HepProgram program = HepProgram.builder()
         .addRuleInstance(
-            AdjustProjectForCountAggregateRule.config(false, this, f).toRule())
+            AdjustProjectForCountAggregateRule.DEFAULT_WITHOUT_FAVLOR
+                .withRelBuilderFactory(f).toRule())
         .addRuleInstance(
-            AdjustProjectForCountAggregateRule.config(true, this, f).toRule())
+            AdjustProjectForCountAggregateRule.DEFAULT_WITH_FAVLOR
+                .withRelBuilderFactory(f).toRule())
         .addRuleInstance(
             FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT
                 .withRelBuilderFactory(f)
@@ -354,26 +376,54 @@ public class RelDecorrelator implements ReflectiveVisitor {
   private HepPlanner createPlanner(HepProgram program) {
     // Create a planner with a hook to update the mapping tables when a
     // node is copied when it is registered.
-    return new HepPlanner(
-        program,
-        context,
-        true,
-        createCopyHook(),
-        RelOptCostImpl.FACTORY);
+    HepPlanner planner =
+        new HepPlanner(
+            program,
+            context,
+            true,
+            createCopyHook(),
+            RelOptCostImpl.FACTORY);
+    planner.setDecorrelator(this);
+    return planner;
   }
 
+  /**
+   * Remove some instances of {@link org.apache.calcite.rel.core.Correlate} from a query plan
+   * by applying a default set of rules (only some of the
+   * {@link org.apache.calcite.rel.core.Correlate}s might be removable in such way).
+   */
   public RelNode removeCorrelationViaRule(RelNode root) {
     final RelBuilderFactory f = relBuilderFactory();
     HepProgram program = HepProgram.builder()
-        .addRuleInstance(RemoveSingleAggregateRule.config(f).toRule())
+        .addRuleInstance(RemoveSingleAggregateRule.DEFAULT.toRule())
         .addRuleInstance(
-            RemoveCorrelationForScalarProjectRule.config(this, f).toRule())
+            RemoveCorrelationForScalarProjectRule.DEFAULT.withRelBuilderFactory(f).toRule())
         .addRuleInstance(
-            RemoveCorrelationForScalarAggregateRule.config(this, f).toRule())
+            RemoveCorrelationForScalarAggregateRule.DEFAULT.withRelBuilderFactory(f).toRule())
         .build();
+    return removeCorrelationViaRule(root, program);
+  }
 
-    HepPlanner planner = createPlanner(program);
+  /**
+   * Remove some instances of {@link org.apache.calcite.rel.core.Correlate} from a query plan
+   * by applying a certain {@link RuleSet} (only some of the
+   * {@link org.apache.calcite.rel.core.Correlate}s might be removable in such way).
+   */
+  public RelNode removeCorrelationViaRule(RelNode root, RuleSet ruleSet) {
+    final RelBuilderFactory f = relBuilderFactory();
+    final HepProgramBuilder builder = HepProgram.builder();
+    for (RelOptRule rule : ruleSet) {
+      if (rule instanceof RelRule) {
+        rule = ((RelRule<?>) rule).config.withRelBuilderFactory(f).toRule();
+      }
+      builder.addRuleInstance(rule);
+    }
+    final HepProgram program = builder.build();
+    return removeCorrelationViaRule(root, program);
+  }
 
+  private RelNode removeCorrelationViaRule(RelNode root, HepProgram program) {
+    HepPlanner planner = createPlanner(program);
     planner.setRoot(root);
     return planner.findBestExp();
   }
@@ -1890,14 +1940,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
    */
   public static final class RemoveSingleAggregateRule
     extends RelRule<RemoveSingleAggregateRule.RemoveSingleAggregateRuleConfig> {
-    static RemoveSingleAggregateRuleConfig config(RelBuilderFactory f) {
-      return ImmutableRemoveSingleAggregateRuleConfig.builder().withRelBuilderFactory(f)
-          .withOperandSupplier(b0 ->
-              b0.operand(Aggregate.class).oneInput(b1 ->
-                  b1.operand(Project.class).oneInput(b2 ->
-                      b2.operand(Aggregate.class).anyInputs())))
-          .build();
-    }
+
+    static final RemoveSingleAggregateRuleConfig DEFAULT =
+        ImmutableRemoveSingleAggregateRuleConfig.builder()
+            .withOperandSupplier(b0 ->
+                b0.operand(Aggregate.class).oneInput(b1 ->
+                    b1.operand(Project.class).oneInput(b2 ->
+                        b2.operand(Aggregate.class).anyInputs())))
+            .build();
 
     /** Creates a RemoveSingleAggregateRule. */
     RemoveSingleAggregateRule(RemoveSingleAggregateRuleConfig config) {
@@ -1949,29 +1999,24 @@ public class RelDecorrelator implements ReflectiveVisitor {
   public static final class RemoveCorrelationForScalarProjectRule
       extends RelRule<RemoveCorrelationForScalarProjectRule
       .RemoveCorrelationForScalarProjectRuleConfig> {
-    private final RelDecorrelator d;
-
-    static RemoveCorrelationForScalarProjectRuleConfig config(RelDecorrelator decorrelator,
-        RelBuilderFactory relBuilderFactory) {
-      return ImmutableRemoveCorrelationForScalarProjectRuleConfig.builder()
-          .withRelBuilderFactory(relBuilderFactory)
-          .withOperandSupplier(b0 ->
-                  b0.operand(Correlate.class).inputs(
-                      b1 -> b1.operand(RelNode.class).anyInputs(),
-                      b2 -> b2.operand(Aggregate.class).oneInput(b3 ->
-                          b3.operand(Project.class).oneInput(b4 ->
-                              b4.operand(RelNode.class).anyInputs()))))
-          .withDecorrelator(decorrelator)
-          .build();
-    }
+
+    static final RemoveCorrelationForScalarProjectRuleConfig DEFAULT =
+        ImmutableRemoveCorrelationForScalarProjectRuleConfig.builder()
+            .withOperandSupplier(b0 ->
+                b0.operand(Correlate.class).inputs(
+                    b1 -> b1.operand(RelNode.class).anyInputs(),
+                    b2 -> b2.operand(Aggregate.class).oneInput(b3 ->
+                        b3.operand(Project.class).oneInput(b4 ->
+                            b4.operand(RelNode.class).anyInputs()))))
+            .build();
 
     /** Creates a RemoveCorrelationForScalarProjectRule. */
     RemoveCorrelationForScalarProjectRule(RemoveCorrelationForScalarProjectRuleConfig config) {
       super(config);
-      this.d = requireNonNull(config.decorrelator());
     }
 
     @Override public void onMatch(RelOptRuleCall call) {
+      final RelDecorrelator d = call.getPlanner().getDecorrelator();
       final Correlate correlate = call.rel(0);
       final RelNode left = call.rel(1);
       final Aggregate aggregate = call.rel(2);
@@ -2153,12 +2198,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
       d.removeCorVarFromTree(correlate);
     }
 
-    /** Rule configuration.
-     *
-     * <p>Extends {@link RelDecorrelator.Config} because rule needs a
-     * decorrelator instance. */
+    /** Rule configuration. */
     @Value.Immutable(singleton = false)
-    public interface RemoveCorrelationForScalarProjectRuleConfig extends RelDecorrelator.Config {
+    public interface RemoveCorrelationForScalarProjectRuleConfig extends RelRule.Config {
       @Override default RemoveCorrelationForScalarProjectRule toRule() {
         return new RemoveCorrelationForScalarProjectRule(this);
       }
@@ -2169,31 +2211,26 @@ public class RelDecorrelator implements ReflectiveVisitor {
   public static final class RemoveCorrelationForScalarAggregateRule
       extends RelRule<RemoveCorrelationForScalarAggregateRule
       .RemoveCorrelationForScalarAggregateRuleConfig> {
-    private final RelDecorrelator d;
-
-    static RemoveCorrelationForScalarAggregateRuleConfig config(RelDecorrelator d,
-        RelBuilderFactory relBuilderFactory) {
-      return ImmutableRemoveCorrelationForScalarAggregateRuleConfig.builder()
-          .withRelBuilderFactory(relBuilderFactory)
-          .withOperandSupplier(b0 ->
-              b0.operand(Correlate.class).inputs(
-                  b1 -> b1.operand(RelNode.class).anyInputs(),
-                  b2 -> b2.operand(Project.class).oneInput(b3 ->
-                      b3.operand(Aggregate.class)
-                          .predicate(Aggregate::isSimple).oneInput(b4 ->
-                          b4.operand(Project.class).oneInput(b5 ->
-                              b5.operand(RelNode.class).anyInputs())))))
-          .withDecorrelator(d)
-          .build();
-    }
+
+    static final RemoveCorrelationForScalarAggregateRuleConfig DEFAULT =
+        ImmutableRemoveCorrelationForScalarAggregateRuleConfig.builder()
+            .withOperandSupplier(b0 ->
+                b0.operand(Correlate.class).inputs(
+                    b1 -> b1.operand(RelNode.class).anyInputs(),
+                    b2 -> b2.operand(Project.class).oneInput(b3 ->
+                        b3.operand(Aggregate.class)
+                            .predicate(Aggregate::isSimple).oneInput(b4 ->
+                                b4.operand(Project.class).oneInput(b5 ->
+                                    b5.operand(RelNode.class).anyInputs())))))
+            .build();
 
     /** Creates a RemoveCorrelationForScalarAggregateRule. */
     RemoveCorrelationForScalarAggregateRule(RemoveCorrelationForScalarAggregateRuleConfig config) {
       super(config);
-      d = requireNonNull(config.decorrelator());
     }
 
     @Override public void onMatch(RelOptRuleCall call) {
+      final RelDecorrelator d = call.getPlanner().getDecorrelator();
       final Correlate correlate = call.rel(0);
       final RelNode left = call.rel(1);
       final Project aggOutputProject = call.rel(2);
@@ -2542,12 +2579,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
       d.removeCorVarFromTree(correlate);
     }
 
-    /** Rule configuration.
-     *
-     * <p>Extends {@link RelDecorrelator.Config} because rule needs a
-     * decorrelator instance. */
+    /** Rule configuration. */
     @Value.Immutable(singleton = false)
-    public interface RemoveCorrelationForScalarAggregateRuleConfig extends RelDecorrelator.Config {
+    public interface RemoveCorrelationForScalarAggregateRuleConfig extends RelRule.Config {
       @Override default RemoveCorrelationForScalarAggregateRule toRule() {
         return new RemoveCorrelationForScalarAggregateRule(this);
       }
@@ -2565,31 +2599,33 @@ public class RelDecorrelator implements ReflectiveVisitor {
   /** Planner rule that adjusts projects when counts are added. */
   public static final class AdjustProjectForCountAggregateRule
     extends RelRule<AdjustProjectForCountAggregateRule.AdjustProjectForCountAggregateRuleConfig> {
-    final RelDecorrelator d;
-
-    static AdjustProjectForCountAggregateRuleConfig config(
-        boolean flavor, RelDecorrelator decorrelator, RelBuilderFactory relBuilderFactory) {
-      return ImmutableAdjustProjectForCountAggregateRuleConfig.builder()
-          .withRelBuilderFactory(relBuilderFactory)
-          .withOperandSupplier(b0 ->
-              b0.operand(Correlate.class).inputs(
-                  b1 -> b1.operand(RelNode.class).anyInputs(),
-                  b2 -> flavor
-                      ? b2.operand(Project.class).oneInput(b3 ->
-                      b3.operand(Aggregate.class).anyInputs())
-                      : b2.operand(Aggregate.class).anyInputs()))
-          .withFlavor(flavor)
-          .withDecorrelator(decorrelator)
-          .build();
-    }
+
+    static final AdjustProjectForCountAggregateRuleConfig DEFAULT_WITH_FAVLOR =
+        ImmutableAdjustProjectForCountAggregateRuleConfig.builder()
+            .withOperandSupplier(b0 ->
+                b0.operand(Correlate.class).inputs(
+                    b1 -> b1.operand(RelNode.class).anyInputs(),
+                    b2 -> b2.operand(Project.class)
+                        .oneInput(b3 -> b3.operand(Aggregate.class).anyInputs())))
+            .withFlavor(true)
+            .build();
+
+    static final AdjustProjectForCountAggregateRuleConfig DEFAULT_WITHOUT_FAVLOR =
+        ImmutableAdjustProjectForCountAggregateRuleConfig.builder()
+            .withOperandSupplier(b0 ->
+                b0.operand(Correlate.class).inputs(
+                    b1 -> b1.operand(RelNode.class).anyInputs(),
+                    b2 -> b2.operand(Aggregate.class).anyInputs()))
+            .withFlavor(false)
+            .build();
 
     /** Creates an AdjustProjectForCountAggregateRule. */
     AdjustProjectForCountAggregateRule(AdjustProjectForCountAggregateRuleConfig config) {
       super(config);
-      this.d = requireNonNull(config.decorrelator());
     }
 
     @Override public void onMatch(RelOptRuleCall call) {
+      final RelDecorrelator d = call.getPlanner().getDecorrelator();
       final Correlate correlate = call.rel(0);
       final RelNode left = call.rel(1);
       final Project aggOutputProject;
@@ -2612,10 +2648,11 @@ public class RelDecorrelator implements ReflectiveVisitor {
             .projectNamed(projects.leftList(), projects.rightList(), true);
         aggOutputProject = (Project) relBuilder.build();
       }
-      onMatch2(call, correlate, left, aggOutputProject, aggregate);
+      onMatch2(d, call, correlate, left, aggOutputProject, aggregate);
     }
 
     private void onMatch2(
+        RelDecorrelator d,
         RelOptRuleCall call,
         Correlate correlate,
         RelNode leftInput,
@@ -2710,7 +2747,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
     /** Rule configuration. */
     @Value.Immutable(singleton = false)
-    public interface AdjustProjectForCountAggregateRuleConfig extends RelDecorrelator.Config {
+    public interface AdjustProjectForCountAggregateRuleConfig extends RelRule.Config {
       @Override default AdjustProjectForCountAggregateRule toRule() {
         return new AdjustProjectForCountAggregateRule(this);
       }
@@ -3018,16 +3055,6 @@ public class RelDecorrelator implements ReflectiveVisitor {
     }
   }
 
-  /** Base configuration for rules that are non-static in a RelDecorrelator. */
-  public interface Config extends RelRule.Config {
-    /** Returns the RelDecorrelator that will be context for the created
-     * rule instance. */
-    RelDecorrelator decorrelator();
-
-    /** Sets {@link #decorrelator}. */
-    Config withDecorrelator(RelDecorrelator decorrelator);
-  }
-
   // -------------------------------------------------------------------------
   //  Getter/Setter
   // -------------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
index 9c5a3d597c..029e185995 100644
--- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
@@ -34,6 +34,8 @@ import org.apache.calcite.tools.Planner;
 import org.apache.calcite.tools.Program;
 import org.apache.calcite.tools.Programs;
 import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RuleSet;
+import org.apache.calcite.tools.RuleSets;
 import org.apache.calcite.util.Holder;
 import org.apache.calcite.util.TestUtil;
 
@@ -175,4 +177,89 @@ public class RelDecorrelatorTest {
         + "                LogicalTableScan(table=[[scott, EMP]])\n";
     assertThat(after, hasTree(planAfter));
   }
+
+  /**
+   * Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6674">[CALCITE-6674] Make
+   * RelDecorrelator rules configurable</a>.
+   */
+  @Test void testDecorrelatorCustomizeRules() {
+    final FrameworkConfig frameworkConfig = config().build();
+    final RelBuilder builder = RelBuilder.create(frameworkConfig);
+    final RelOptCluster cluster = builder.getCluster();
+    final Planner planner = Frameworks.getPlanner(frameworkConfig);
+    final String sql = "select ROW("
+        + "(select deptno\n"
+        + "from dept\n"
+        + "where dept.deptno = emp.deptno), emp.ename)\n"
+        + "from emp";
+    final RelNode parsedRel;
+    try {
+      final SqlNode parse = planner.parse(sql);
+      final SqlNode validate = planner.validate(parse);
+      parsedRel = planner.rel(validate).rel;
+    } catch (Exception e) {
+      throw TestUtil.rethrow(e);
+    }
+
+    // Convert SubQuery into Correlate
+    final HepProgram hepProgram = HepProgram.builder()
+        .addRuleCollection(ImmutableList.of(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE))
+        .build();
+    final Program program =
+        Programs.of(hepProgram, true,
+            requireNonNull(cluster.getMetadataProvider()));
+    final RelNode original =
+        program.run(cluster.getPlanner(), parsedRel, cluster.traitSet(),
+            Collections.emptyList(), Collections.emptyList());
+    final String planOriginal = ""
+        + "LogicalProject(EXPR$0=[ROW($8, $1)])\n"
+        + "  LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
+        + "      LogicalProject(DEPTNO=[$0])\n"
+        + "        LogicalFilter(condition=[=($0, $cor0.DEPTNO)])\n"
+        + "          LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(original, hasTree(planOriginal));
+
+    // Default decorrelate
+    final RelNode decorrelatedDefault = RelDecorrelator.decorrelateQuery(original, builder);
+    final String planDecorrelatedDefault = ""
+        + "LogicalProject(EXPR$0=[ROW($8, $1)])\n"
+        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPTNO8=[$8])\n"
+        + "    LogicalJoin(condition=[=($8, $7)], joinType=[left])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "      LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(decorrelatedDefault, hasTree(planDecorrelatedDefault));
+
+    // Decorrelate using explicitly the same rules as the default ones: same result
+    final RuleSet defaultRules =
+        RuleSets.ofList(RelDecorrelator.RemoveSingleAggregateRule.DEFAULT.toRule(),
+            RelDecorrelator.RemoveCorrelationForScalarProjectRule.DEFAULT.toRule(),
+            RelDecorrelator.RemoveCorrelationForScalarAggregateRule.DEFAULT.toRule());
+    final RelNode decorrelatedDefault2 =
+        RelDecorrelator.decorrelateQuery(original, builder, defaultRules);
+    assertThat(decorrelatedDefault2, hasTree(planDecorrelatedDefault));
+
+    // Decorrelate using just the relevant rule for this query: same result
+    final RuleSet relevantRule =
+        RuleSets.ofList(RelDecorrelator.RemoveCorrelationForScalarProjectRule.DEFAULT.toRule());
+    final RelNode decorrelatedRelevantRule =
+        RelDecorrelator.decorrelateQuery(original, builder, relevantRule);
+    assertThat(decorrelatedRelevantRule, hasTree(planDecorrelatedDefault));
+
+    // Decorrelate without any pre-rules (just the "main" decorrelate program): decorrelated
+    // but aggregate is kept
+    final RuleSet noRules = RuleSets.ofList(Collections.emptyList());
+    final RelNode decorrelatedNoRules =
+        RelDecorrelator.decorrelateQuery(original, builder, noRules);
+    final String planDecorrelatedNoRules = ""
+        + "LogicalProject(EXPR$0=[ROW($9, $1)])\n"
+        + "  LogicalJoin(condition=[=($7, $8)], joinType=[left])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n"
+        + "    LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
+        + "      LogicalProject(DEPTNO1=[$0], DEPTNO=[$0])\n"
+        + "        LogicalTableScan(table=[[scott, DEPT]])\n";
+    assertThat(decorrelatedNoRules, hasTree(planDecorrelatedNoRules));
+  }
 }

Reply via email to