This is an automated email from the ASF dual-hosted git repository.

rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new ac7d25c74f [CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner
ac7d25c74f is described below

commit ac7d25c74f5ba335c965b3d0492e712d43e1657e
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Mar 7 20:09:06 2025 +0000

    [CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner
---
 .../calcite/rel/rules/LoptOptimizeJoinRule.java    | 63 ++++++++++++---------
 .../org/apache/calcite/test/RelOptRulesTest.java   | 66 ++++++++++++++++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 56 ++++++++++++++++++
 3 files changed, 157 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
index 90a0c9be82..16cb367ac2 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
@@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory,
 
     findRemovableSelfJoins(mq, multiJoin);
 
-    findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
+    findBestOrderings(call, multiJoin, semiJoinOpt);
   }
 
   /**
@@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique(
    * @param semiJoinOpt optimal semijoins for each factor
    * @param call RelOptRuleCall associated with this rule
    */
-  private static void findBestOrderings(
-      RelMetadataQuery mq,
-      RelBuilder relBuilder,
+  private void findBestOrderings(
+      RelOptRuleCall call,
       LoptMultiJoin multiJoin,
-      LoptSemiJoinOptimizer semiJoinOpt,
-      RelOptRuleCall call) {
+      LoptSemiJoinOptimizer semiJoinOpt) {
     final List<RelNode> plans = new ArrayList<>();
 
     final List<String> fieldNames =
@@ -461,8 +459,7 @@ private static void findBestOrderings(
       }
       LoptJoinTree joinTree =
           createOrdering(
-              mq,
-              relBuilder,
+              call,
               multiJoin,
               semiJoinOpt,
               i);
@@ -679,9 +676,8 @@ private static void setFactorJoinKeys(
    * @return constructed join tree or null if it is not possible for
    * firstFactor to appear as the first factor in the join
    */
-  private static @Nullable LoptJoinTree createOrdering(
-      RelMetadataQuery mq,
-      RelBuilder relBuilder,
+  private @Nullable LoptJoinTree createOrdering(
+      RelOptRuleCall call,
       LoptMultiJoin multiJoin,
       LoptSemiJoinOptimizer semiJoinOpt,
       int firstFactor) {
@@ -712,7 +708,7 @@ private static void setFactorJoinKeys(
         } else {
           nextFactor =
               getBestNextFactor(
-                  mq,
+                  call.getMetadataQuery(),
                   multiJoin,
                   factorsToAdd,
                   factorsAdded,
@@ -733,8 +729,7 @@ private static void setFactorJoinKeys(
       factorsNeeded.and(factorsAdded);
       joinTree =
           addFactorToTree(
-              mq,
-              relBuilder,
+              call,
               multiJoin,
               semiJoinOpt,
               joinTree,
@@ -878,9 +873,8 @@ private static boolean isJoinTree(RelNode rel) {
    * @return optimal join tree with the new factor added if it is possible to
    * add the factor; otherwise, null is returned
    */
-  private static @Nullable LoptJoinTree addFactorToTree(
-      RelMetadataQuery mq,
-      RelBuilder relBuilder,
+  private @Nullable LoptJoinTree addFactorToTree(
+      RelOptRuleCall call,
       LoptMultiJoin multiJoin,
       LoptSemiJoinOptimizer semiJoinOpt,
       @Nullable LoptJoinTree joinTree,
@@ -888,6 +882,8 @@ private static boolean isJoinTree(RelNode rel) {
       BitSet factorsNeeded,
       List<RexNode> filtersToAdd,
       boolean selfJoin) {
+    final RelMetadataQuery mq = call.getMetadataQuery();
+    final RelBuilder relBuilder = call.builder();
 
     // if the factor corresponds to the null generating factor in an outer
     // join that can be removed, then create a replacement join
@@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) {
             selfJoin);
     LoptJoinTree pushDownTree =
         pushDownFactor(
-            mq,
-            relBuilder,
+            call,
             multiJoin,
             semiJoinOpt,
             joinTree,
@@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) {
     RelOptCost costPushDown = null;
     RelOptCost costTop = null;
     if (pushDownTree != null) {
-      costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree());
+      costPushDown = config.costFunction().getCost(call, pushDownTree.getJoinTree());
     }
     if (topTree != null) {
-      costTop = mq.getCumulativeCost(topTree.getJoinTree());
+      costTop = config.costFunction().getCost(call, topTree.getJoinTree());
     }
 
     if (pushDownTree == null) {
@@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) {
    * join tree if it is possible to do the pushdown; otherwise, null is
    * returned
    */
-  private static @Nullable LoptJoinTree pushDownFactor(
-      RelMetadataQuery mq,
-      RelBuilder relBuilder,
+  private @Nullable LoptJoinTree pushDownFactor(
+      RelOptRuleCall call,
       LoptMultiJoin multiJoin,
       LoptSemiJoinOptimizer semiJoinOpt,
       LoptJoinTree joinTree,
@@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) {
     LoptJoinTree subTree = (childNo == 0) ? left : right;
     subTree =
         addFactorToTree(
-            mq,
-            relBuilder,
+            call,
             multiJoin,
             semiJoinOpt,
             subTree,
@@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) {
 
     // create the new join tree with the factor pushed down
     return createJoinSubtree(
-        mq,
-        relBuilder,
+        call.getMetadataQuery(),
+        call.builder(),
         multiJoin,
         left,
         right,
@@ -2089,12 +2082,26 @@ private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq,
         joinInfo.leftSet());
   }
 
+  /** Function to compute cost. */
+  @FunctionalInterface
+  public interface CostFunction {
+    @Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode);
+  }
+
   /** Rule configuration. */
   @Value.Immutable
   public interface Config extends RelRule.Config {
     Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
         .withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());
 
+    /** Function to calculate intermediate cost computations. */
+    @Value.Default default CostFunction costFunction() {
+      return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel);
+    }
+
+    /** Sets {@link #costFunction()}. */
+    Config withCostFunction(CostFunction function);
+
     @Override default LoptOptimizeJoinRule toRule() {
       return new LoptOptimizeJoinRule(this);
     }
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index efb177aad1..cc2b121ce8 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -26,6 +26,7 @@
 import org.apache.calcite.config.CalciteConnectionConfig;
 import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptCost;
 import org.apache.calcite.plan.RelOptCostImpl;
 import org.apache.calcite.plan.RelOptPlanner;
 import org.apache.calcite.plan.RelOptRule;
@@ -53,6 +54,7 @@
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.Minus;
 import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rel.core.Union;
 import org.apache.calcite.rel.hint.HintPredicates;
 import org.apache.calcite.rel.hint.HintStrategyTable;
@@ -62,6 +64,7 @@
 import org.apache.calcite.rel.logical.LogicalFilter;
 import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.logical.LogicalTableModify;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
 import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
 import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
@@ -78,6 +81,7 @@
 import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
 import org.apache.calcite.rel.rules.JoinAssociateRule;
 import org.apache.calcite.rel.rules.JoinCommuteRule;
+import org.apache.calcite.rel.rules.LoptOptimizeJoinRule;
 import org.apache.calcite.rel.rules.MeasureRules;
 import org.apache.calcite.rel.rules.MultiJoin;
 import org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule;
@@ -9733,6 +9737,68 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway
         .check();
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
+   * LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
+  @Test void testLoptOptimizeJoinRuleWithDefaultCost() {
+    // Use the default rule
+    checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE);
+  }
+
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
+   * LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
+  @Test void testLoptOptimizeJoinRuleWithSpecialCost() {
+    // Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost
+    checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT
+        .withCostFunction((c, r) -> c.getPlanner().getCost(r, c.getMetadataQuery()))
+        .toRule());
+  }
+
+  private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
+    final HepProgram preProgram = new HepProgramBuilder()
+        .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+        .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
+        .build();
+
+    final HepProgram program = HepProgram.builder()
+        .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+        .addRuleInstance(rule)
+        .build();
+
+    // Special planner that artificially favors joins on the same table
+    final HepPlanner planner = new HepPlanner(program) {
+      @Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) {
+        if (rel instanceof Join
+            && rel.getInput(0).stripped() instanceof TableScan
+            && rel.getInput(1).stripped() instanceof TableScan) {
+          TableScan left = (TableScan) rel.getInput(0).stripped();
+          TableScan right = (TableScan) rel.getInput(1).stripped();
+          if (left.getTable().equals(right.getTable())) {
+            // Tiny cost for self-joins
+            return getCostFactory().makeTinyCost();
+          }
+        }
+
+        // General case: just define a kind of cumulative cost based on the rowCount (to avoid
+        // the infinite costs from the Logical operators)
+        RelOptCost cost = new RelOptCostImpl(mq.getRowCount(rel));
+        for (RelNode input : rel.getInputs()) {
+          cost = cost.plus(getCost(input, mq));
+        }
+        return cost;
+      }
+    };
+
+    sql("select e1.empno from emp e1"
+        + " inner join dept d1 on d1.deptno = e1.deptno"
+        + " inner join emp e2 on e1.ename = e2.ename"
+        + " inner join dept d2 on d2.deptno = e1.deptno")
+        .withPre(preProgram)
+        .withPlanner(planner)
+        .check();
+  }
+
   /**
    * Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-6874">[CALCITE-6874]
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index ea2e73bd05..4f63a49369 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -7436,6 +7436,62 @@ LogicalProject(EMPNO=[$0])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
       LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testLoptOptimizeJoinRuleWithDefaultCost">
+    <Resource name="sql">
+      <![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EMPNO=[$0])
+  MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(EMPNO=[$0])
+  LogicalProject(EMPNO=[$9], ENAME=[$10], JOB=[$11], MGR=[$12], HIREDATE=[$13], SAL=[$14], COMM=[$15], DEPTNO=[$16], SLACKER=[$17], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$0], ENAME0=[$1], JOB0=[$2], MGR0=[$3], HIREDATE0=[$4], SAL0=[$5], COMM0=[$6], DEPTNO1=[$7], SLACKER0=[$8], DEPTNO2=[$20], NAME0=[$21])
+    LogicalJoin(condition=[=($10, $1)], joinType=[inner])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalJoin(condition=[=($11, $7)], joinType=[inner])
+        LogicalJoin(condition=[=($9, $7)], joinType=[inner])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+          LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testLoptOptimizeJoinRuleWithSpecialCost">
+    <Resource name="sql">
+      <![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalProject(EMPNO=[$0])
+  MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalProject(EMPNO=[$0])
+  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO1=[$16], SLACKER0=[$17], DEPTNO2=[$20], NAME0=[$21])
+    LogicalJoin(condition=[=($20, $7)], joinType=[inner])
+      LogicalJoin(condition=[=($18, $7)], joinType=[inner])
+        LogicalJoin(condition=[=($1, $10)], joinType=[inner])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
     </Resource>
   </TestCase>

Reply via email to