This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new fcbbea2676 [CALCITE-7000] Extend IntersectToSemiJoinRule to support n-way inputs
fcbbea2676 is described below
commit fcbbea2676d485c370255325cdfcb8a93694f06e
Author: Zhen Chen <[email protected]>
AuthorDate: Wed May 7 09:47:37 2025 +0800
[CALCITE-7000] Extend IntersectToSemiJoinRule to support n-way inputs
---
.../calcite/rel/rules/IntersectToSemiJoinRule.java | 96 ++++++++++++++++------
.../org/apache/calcite/test/RelOptRulesTest.java | 13 +++
.../org/apache/calcite/test/RelOptRulesTest.xml | 47 ++++++++++-
core/src/test/resources/sql/planner.iq | 59 ++++++++++++-
4 files changed, 186 insertions(+), 29 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
index 3da2dfa97d..66afcaa88d 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
@@ -34,9 +34,51 @@
import java.util.List;
/**
- * Planner rule that translates a {@link org.apache.calcite.rel.core.Intersect}
+ * Planner rule that translates a {@link Intersect}
* to a series of {@link org.apache.calcite.rel.core.Join} that type is
- * {@link org.apache.calcite.rel.core.JoinRelType#SEMI}.
+ * {@link JoinRelType#SEMI}. This rule supports n-way Intersect conversion,
+ * as this rule can be repeatedly applied during query optimization to
+ * refine the plan.
+ *
+ * <h2>Example</h2>
+ *
+ * <p>Original SQL:
+ * <pre>{@code
+ * select ename from emp where deptno = 10
+ * intersect
+ * select deptno from emp where ename in ('a', 'b')
+ * intersect
+ * select ename from empnullables
+ * }</pre>
+ *
+ * <p>Original plan:
+ * <pre>{@code
+ * LogicalIntersect(all=[false])
+ * LogicalProject(ENAME=[$1])
+ * LogicalFilter(condition=[=($7, 10)])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
+ * LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * LogicalProject(ENAME=[$1])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+ * }</pre>
+ *
+ * <p>Plan after conversion:
+ * <pre>{@code
+ * LogicalProject(ENAME=[CAST($0):VARCHAR])
+ * LogicalAggregate(group=[{0}])
+ * LogicalJoin(condition=[<=>(CAST($0):VARCHAR, CAST($1):VARCHAR)],
joinType=[semi])
+ * LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[semi])
+ * LogicalProject(ENAME=[$1])
+ * LogicalFilter(condition=[=($7, 10)])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
+ * LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ * LogicalProject(ENAME=[$1])
+ * LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+ * }</pre>
*/
@Value.Enclosing
public class IntersectToSemiJoinRule
@@ -60,32 +102,40 @@ protected IntersectToSemiJoinRule(Config config) {
final RexBuilder rexBuilder = builder.getRexBuilder();
List<RelNode> inputs = intersect.getInputs();
- if (inputs.size() != 2) {
+ if (inputs.size() < 2) {
return;
}
- RelNode left = inputs.get(0);
- RelNode right = inputs.get(1);
-
- List<RexNode> conditions = new ArrayList<>();
- int fieldCount = left.getRowType().getFieldCount();
-
- for (int i = 0; i < fieldCount; i++) {
- RelDataType leftFieldType =
left.getRowType().getFieldList().get(i).getType();
- RelDataType rightFieldType =
right.getRowType().getFieldList().get(i).getType();
-
- conditions.add(
- builder.isNotDistinctFrom(
- rexBuilder.makeInputRef(leftFieldType, i),
- rexBuilder.makeInputRef(rightFieldType, i + fieldCount)));
+ final RelDataType leastRowType = intersect.getRowType();
+ RelNode current = inputs.get(0);
+ builder.push(current);
+
+ for (int i = 1; i < inputs.size(); i++) {
+ RelNode next = inputs.get(i);
+ List<RexNode> conditions = new ArrayList<>();
+ int fieldCount = current.getRowType().getFieldCount();
+
+ for (int j = 0; j < fieldCount; j++) {
+ RelDataType leftFieldType =
current.getRowType().getFieldList().get(j).getType();
+ RelDataType rightFieldType =
next.getRowType().getFieldList().get(j).getType();
+ RelDataType leastFieldType =
leastRowType.getFieldList().get(j).getType();
+
+ conditions.add(
+ builder.isNotDistinctFrom(
+ rexBuilder.makeCast(leastFieldType,
+ rexBuilder.makeInputRef(leftFieldType, j)),
+ rexBuilder.makeCast(leastFieldType,
+ rexBuilder.makeInputRef(rightFieldType, j + fieldCount))));
+ }
+ RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions);
+
+ builder.push(next)
+ .join(JoinRelType.SEMI, condition);
+ current = builder.peek();
}
- RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions);
-
- builder.push(left)
- .push(right)
- .join(JoinRelType.SEMI, condition)
- .distinct();
+ builder.distinct()
+ .convert(leastRowType, true);
call.transformTo(builder.build());
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 765e5ecd55..88dad59f12 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -3633,6 +3633,19 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT
.check();
}
+ /** Test case for <a
href="https://issues.apache.org/jira/browse/CALCITE-7000">[CALCITE-7000]
+ * Extend IntersectToSemiJoinRule to support n-way inputs</a>. */
+ @Test void testIntersectToSemiJoin2() {
+ final String sql = "select ename from emp where deptno = 10\n"
+ + "intersect\n"
+ + "select deptno from emp where ename in ('a', 'b')\n"
+ + "intersect\n"
+ + "select ename from empnullables\n";
+ sql(sql).withPreRule(CoreRules.INTERSECT_MERGE)
+ .withRule(CoreRules.INTERSECT_TO_SEMI_JOIN)
+ .check();
+ }
+
/** Test case for <a
href="https://issues.apache.org/jira/browse/CALCITE-6880">
* [CALCITE-6880] Implement IntersectToSemiJoinRule</a>. */
@Test void testIntersectToSemiJoinMultiCol() {
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 96fdbe60cb..2ab97cd100 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -6344,13 +6344,52 @@ LogicalIntersect(all=[false])
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0}])
- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+ LogicalJoin(condition=[=($0, $1)], joinType=[semi])
LogicalProject(ENAME=[$1])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(ENAME=[$1])
LogicalFilter(condition=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testIntersectToSemiJoin2">
+ <Resource name="sql">
+ <![CDATA[select ename from emp where deptno = 10
+intersect
+select deptno from emp where ename in ('a', 'b')
+intersect
+select ename from empnullables
+]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalIntersect(all=[false])
+ LogicalProject(ENAME=[$1])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
+ LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(ENAME=[$1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(ENAME=[CAST($0):VARCHAR])
+ LogicalAggregate(group=[{0}])
+ LogicalJoin(condition=[IS NOT DISTINCT FROM(CAST($0):VARCHAR,
CAST($1):VARCHAR)], joinType=[semi])
+ LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[semi])
+ LogicalProject(ENAME=[$1])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
+ LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(ENAME=[$1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
]]>
</Resource>
</TestCase>
@@ -6382,7 +6421,7 @@ LogicalIntersect(all=[true])
<![CDATA[
LogicalIntersect(all=[true])
LogicalAggregate(group=[{0}])
- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+ LogicalJoin(condition=[=($0, $1)], joinType=[semi])
LogicalProject(ENAME=[$1])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -6416,7 +6455,7 @@ LogicalIntersect(all=[false])
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0}])
- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+ LogicalJoin(condition=[=($0, $1)], joinType=[semi])
LogicalProject(ENAME=[$1])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -6447,7 +6486,7 @@ LogicalIntersect(all=[false])
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0, 1}])
- LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS NOT DISTINCT
FROM($1, $3))], joinType=[semi])
+ LogicalJoin(condition=[AND(=($0, $2), =($1, $3))], joinType=[semi])
LogicalProject(DEPTNO=[$7], ENAME=[$1])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
diff --git a/core/src/test/resources/sql/planner.iq b/core/src/test/resources/sql/planner.iq
index 9d6298b58c..dea07b77fa 100644
--- a/core/src/test/resources/sql/planner.iq
+++ b/core/src/test/resources/sql/planner.iq
@@ -38,7 +38,7 @@ select * from t as t2 where t2.i > 0;
!ok
-EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)],
joinType=[semi])
+EnumerableHashJoin(condition=[=($0, $1)], joinType=[semi])
EnumerableValues(tuples=[[{ 0 }, { 1 }]])
EnumerableCalc(expr#0=[{inputs}], expr#1=[0], expr#2=[>($t0, $t1)],
EXPR$0=[$t0], $condition=[$t2])
EnumerableValues(tuples=[[{ 0 }, { 1 }]])
@@ -58,7 +58,7 @@ select * from t as t2 where t2.i > 0;
!ok
-EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)],
joinType=[semi])
+EnumerableHashJoin(condition=[=($0, $1)], joinType=[semi])
EnumerableValues(tuples=[[{ 0 }, { 1 }]])
EnumerableCalc(expr#0=[{inputs}], expr#1=[0], expr#2=[>($t0, $t1)],
EXPR$0=[$t0], $condition=[$t2])
EnumerableValues(tuples=[[{ 0 }, { 1 }]])
@@ -128,9 +128,64 @@ EnumerableCalc(expr#0..2=[{inputs}], $f0=[$t1], $f1=[$t2])
EnumerableCalc(expr#0=[{inputs}], expr#1=[IS NOT NULL($t0)],
DEPTNO=[$t0], $condition=[$t1])
EnumerableValues(tuples=[[{ 10 }, { 10 }, { 20 }, { 30 }, { 30 }, {
50 }, { 50 }, { 60 }, { null }]])
!plan
+!set planner-rules original
+
+# [CALCITE-7000] Extend IntersectToSemiJoinRule to support n-way inputs
+!set planner-rules "
+-EnumerableRules.ENUMERABLE_INTERSECT_RULE,
+-CoreRules.INTERSECT_TO_DISTINCT,
++CoreRules.INTERSECT_TO_SEMI_JOIN"
+select a from (values (1.0), (2.0), (3.0), (4.0), (5.0)) as t1 (a)
+intersect
+select a from (values (1), (2)) as t2 (a)
+intersect
+select a from (values (1.0), (4.0), (null)) as t3 (a);
++-----+
+| A |
++-----+
+| 1.0 |
++-----+
+(1 row)
+!ok
+
+EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
+ EnumerableNestedLoopJoin(condition=[OR(AND(IS NULL(CAST($0):DECIMAL(11, 1)),
IS NULL(CAST($1):DECIMAL(11, 1))), =(CAST($0):DECIMAL(11, 1),
CAST($1):DECIMAL(11, 1)))], joinType=[semi])
+ EnumerableAggregate(group=[{0}])
+ EnumerableHashJoin(condition=[=($1, $3)], joinType=[semi])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT
NULL], expr#2=[CAST($t1):DECIMAL(11, 1) NOT NULL], A=[$t1], A0=[$t2])
+ EnumerableValues(tuples=[[{ 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0
}]])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT
NULL], expr#2=[CAST($t1):DECIMAL(11, 1) NOT NULL], A=[$t1], A0=[$t2])
+ EnumerableValues(tuples=[[{ 1 }, { 2 }]])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)],
A=[$t1])
+ EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
+!plan
!set planner-rules original
+# [CALCITE-7000] Extend IntersectToSemiJoinRule to support n-way inputs
+select a from (values (1.0), (2.0), (3.0), (4.0), (5.0)) as t1 (a)
+intersect
+select a from (values (1), (2)) as t2 (a)
+intersect
+select a from (values (1.0), (4.0), (null)) as t3 (a);
++-----+
+| A |
++-----+
+| 1.0 |
++-----+
+(1 row)
+
+!ok
+
+EnumerableIntersect(all=[false])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT
NULL], A=[$t1])
+ EnumerableValues(tuples=[[{ 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0 }]])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT
NULL], A=[$t1])
+ EnumerableValues(tuples=[[{ 1 }, { 2 }]])
+ EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
+ EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
+!plan
+
# Test predicate push down with/without expand disjunction.
with t1 (id1, col11, col12) as (values (1, 11, 111), (2, 12, 122), (3, 13,
133), (4, 14, 144), (5, 15, 155)),
t2 (id2, col21, col22) as (values (1, 21, 211), (2, 22, 222), (3, 23, 233),
(4, 24, 244), (5, 25, 255)),