This is an automated email from the ASF dual-hosted git repository.
rubenql pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new ac7d25c74f [CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner
ac7d25c74f is described below
commit ac7d25c74f5ba335c965b3d0492e712d43e1657e
Author: Ruben Quesada Lopez <[email protected]>
AuthorDate: Fri Mar 7 20:09:06 2025 +0000
[CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner
---
.../calcite/rel/rules/LoptOptimizeJoinRule.java | 63 ++++++++++++---------
.../org/apache/calcite/test/RelOptRulesTest.java | 66 ++++++++++++++++++++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 56 ++++++++++++++++++
3 files changed, 157 insertions(+), 28 deletions(-)
diff --git
a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
index 90a0c9be82..16cb367ac2 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java
@@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory
joinFactory,
findRemovableSelfJoins(mq, multiJoin);
- findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
+ findBestOrderings(call, multiJoin, semiJoinOpt);
}
/**
@@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique(
* @param semiJoinOpt optimal semijoins for each factor
* @param call RelOptRuleCall associated with this rule
*/
- private static void findBestOrderings(
- RelMetadataQuery mq,
- RelBuilder relBuilder,
+ private void findBestOrderings(
+ RelOptRuleCall call,
LoptMultiJoin multiJoin,
- LoptSemiJoinOptimizer semiJoinOpt,
- RelOptRuleCall call) {
+ LoptSemiJoinOptimizer semiJoinOpt) {
final List<RelNode> plans = new ArrayList<>();
final List<String> fieldNames =
@@ -461,8 +459,7 @@ private static void findBestOrderings(
}
LoptJoinTree joinTree =
createOrdering(
- mq,
- relBuilder,
+ call,
multiJoin,
semiJoinOpt,
i);
@@ -679,9 +676,8 @@ private static void setFactorJoinKeys(
* @return constructed join tree or null if it is not possible for
* firstFactor to appear as the first factor in the join
*/
- private static @Nullable LoptJoinTree createOrdering(
- RelMetadataQuery mq,
- RelBuilder relBuilder,
+ private @Nullable LoptJoinTree createOrdering(
+ RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
int firstFactor) {
@@ -712,7 +708,7 @@ private static void setFactorJoinKeys(
} else {
nextFactor =
getBestNextFactor(
- mq,
+ call.getMetadataQuery(),
multiJoin,
factorsToAdd,
factorsAdded,
@@ -733,8 +729,7 @@ private static void setFactorJoinKeys(
factorsNeeded.and(factorsAdded);
joinTree =
addFactorToTree(
- mq,
- relBuilder,
+ call,
multiJoin,
semiJoinOpt,
joinTree,
@@ -878,9 +873,8 @@ private static boolean isJoinTree(RelNode rel) {
* @return optimal join tree with the new factor added if it is possible to
* add the factor; otherwise, null is returned
*/
- private static @Nullable LoptJoinTree addFactorToTree(
- RelMetadataQuery mq,
- RelBuilder relBuilder,
+ private @Nullable LoptJoinTree addFactorToTree(
+ RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
@Nullable LoptJoinTree joinTree,
@@ -888,6 +882,8 @@ private static boolean isJoinTree(RelNode rel) {
BitSet factorsNeeded,
List<RexNode> filtersToAdd,
boolean selfJoin) {
+ final RelMetadataQuery mq = call.getMetadataQuery();
+ final RelBuilder relBuilder = call.builder();
// if the factor corresponds to the null generating factor in an outer
// join that can be removed, then create a replacement join
@@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) {
selfJoin);
LoptJoinTree pushDownTree =
pushDownFactor(
- mq,
- relBuilder,
+ call,
multiJoin,
semiJoinOpt,
joinTree,
@@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) {
RelOptCost costPushDown = null;
RelOptCost costTop = null;
if (pushDownTree != null) {
- costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree());
+ costPushDown = config.costFunction().getCost(call,
pushDownTree.getJoinTree());
}
if (topTree != null) {
- costTop = mq.getCumulativeCost(topTree.getJoinTree());
+ costTop = config.costFunction().getCost(call, topTree.getJoinTree());
}
if (pushDownTree == null) {
@@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) {
* join tree if it is possible to do the pushdown; otherwise, null is
* returned
*/
- private static @Nullable LoptJoinTree pushDownFactor(
- RelMetadataQuery mq,
- RelBuilder relBuilder,
+ private @Nullable LoptJoinTree pushDownFactor(
+ RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
LoptJoinTree joinTree,
@@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) {
LoptJoinTree subTree = (childNo == 0) ? left : right;
subTree =
addFactorToTree(
- mq,
- relBuilder,
+ call,
multiJoin,
semiJoinOpt,
subTree,
@@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) {
// create the new join tree with the factor pushed down
return createJoinSubtree(
- mq,
- relBuilder,
+ call.getMetadataQuery(),
+ call.builder(),
multiJoin,
left,
right,
@@ -2089,12 +2082,26 @@ private static boolean
areSelfJoinKeysUnique(RelMetadataQuery mq,
joinInfo.leftSet());
}
+ /** Function that computes the cost of a relational expression in the context of a rule call. */
+ @FunctionalInterface
+ public interface CostFunction {
+ @Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode);
+ }
+
/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
.withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());
+ /** Function that computes the cost of the intermediate join trees compared by the rule. */
+ @Value.Default default CostFunction costFunction() {
+ return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel);
+ }
+
+ /** Sets {@link #costFunction()}. */
+ Config withCostFunction(CostFunction function);
+
@Override default LoptOptimizeJoinRule toRule() {
return new LoptOptimizeJoinRule(this);
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index efb177aad1..cc2b121ce8 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -26,6 +26,7 @@
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptCostImpl;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
@@ -53,6 +54,7 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
@@ -62,6 +64,7 @@
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableModify;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
@@ -78,6 +81,7 @@
import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
import org.apache.calcite.rel.rules.JoinAssociateRule;
import org.apache.calcite.rel.rules.JoinCommuteRule;
+import org.apache.calcite.rel.rules.LoptOptimizeJoinRule;
import org.apache.calcite.rel.rules.MeasureRules;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule;
@@ -9733,6 +9737,68 @@ private void
checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway
.check();
}
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
+ * LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
+ @Test void testLoptOptimizeJoinRuleWithDefaultCost() {
+ // Use the default rule
+ checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE);
+ }
+
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
+ * LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
+ @Test void testLoptOptimizeJoinRuleWithSpecialCost() {
+ // Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost
+ checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT
+ .withCostFunction((c, r) -> c.getPlanner().getCost(r,
c.getMetadataQuery()))
+ .toRule());
+ }
+
+ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
+ final HepProgram preProgram = new HepProgramBuilder()
+ .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
+ .build();
+
+ final HepProgram program = HepProgram.builder()
+ .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .addRuleInstance(rule)
+ .build();
+
+ // Special planner that artificially favors joins on the same table
+ final HepPlanner planner = new HepPlanner(program) {
+ @Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) {
+ if (rel instanceof Join
+ && rel.getInput(0).stripped() instanceof TableScan
+ && rel.getInput(1).stripped() instanceof TableScan) {
+ TableScan left = (TableScan) rel.getInput(0).stripped();
+ TableScan right = (TableScan) rel.getInput(1).stripped();
+ if (left.getTable().equals(right.getTable())) {
+ // Tiny cost for self-joins
+ return getCostFactory().makeTinyCost();
+ }
+ }
+
+ // General case: just define a kind of cumulative cost based on the rowCount (to avoid
+ // the infinite costs from the Logical operators)
+ RelOptCost cost = new RelOptCostImpl(mq.getRowCount(rel));
+ for (RelNode input : rel.getInputs()) {
+ cost = cost.plus(getCost(input, mq));
+ }
+ return cost;
+ }
+ };
+
+ sql("select e1.empno from emp e1"
+ + " inner join dept d1 on d1.deptno = e1.deptno"
+ + " inner join emp e2 on e1.ename = e2.ename"
+ + " inner join dept d2 on d2.deptno = e1.deptno")
+ .withPre(preProgram)
+ .withPlanner(planner)
+ .check();
+ }
+
/**
* Test case for
* <a
href="https://issues.apache.org/jira/browse/CALCITE-6874">[CALCITE-6874]
diff --git
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index ea2e73bd05..4f63a49369 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -7436,6 +7436,62 @@ LogicalProject(EMPNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testLoptOptimizeJoinRuleWithDefaultCost">
+ <Resource name="sql">
+ <![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno =
e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on
d2.deptno = e1.deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(EMPNO=[$0])
+ MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))],
isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]],
outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL,
ALL]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(EMPNO=[$0])
+ LogicalProject(EMPNO=[$9], ENAME=[$10], JOB=[$11], MGR=[$12],
HIREDATE=[$13], SAL=[$14], COMM=[$15], DEPTNO=[$16], SLACKER=[$17],
DEPTNO0=[$18], NAME=[$19], EMPNO0=[$0], ENAME0=[$1], JOB0=[$2], MGR0=[$3],
HIREDATE0=[$4], SAL0=[$5], COMM0=[$6], DEPTNO1=[$7], SLACKER0=[$8],
DEPTNO2=[$20], NAME0=[$21])
+ LogicalJoin(condition=[=($10, $1)], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalJoin(condition=[=($11, $7)], joinType=[inner])
+ LogicalJoin(condition=[=($9, $7)], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testLoptOptimizeJoinRuleWithSpecialCost">
+ <Resource name="sql">
+ <![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno =
e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on
d2.deptno = e1.deptno]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(EMPNO=[$0])
+ MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))],
isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]],
outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL,
ALL]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(EMPNO=[$0])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],
SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$18], NAME=[$19],
EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14],
COMM0=[$15], DEPTNO1=[$16], SLACKER0=[$17], DEPTNO2=[$20], NAME0=[$21])
+ LogicalJoin(condition=[=($20, $7)], joinType=[inner])
+ LogicalJoin(condition=[=($18, $7)], joinType=[inner])
+ LogicalJoin(condition=[=($1, $10)], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>