This is an automated email from the ASF dual-hosted git repository.
zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 3dc7feae24 [CALCITE-5740] Support for AggToSemiJoinRule
3dc7feae24 is described below
commit 3dc7feae24d3b2a39b3ca0a163de58df7f88c854
Author: Zhen Chen <[email protected]>
AuthorDate: Wed Jan 21 14:09:22 2026 +0800
[CALCITE-5740] Support for AggToSemiJoinRule
---
.../org/apache/calcite/rel/rules/CoreRules.java | 6 ++
.../org/apache/calcite/rel/rules/SemiJoinRule.java | 76 +++++++++++++++++++---
.../org/apache/calcite/test/RelOptRulesTest.java | 13 ++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 23 +++++++
core/src/test/resources/sql/hep.iq | 30 +++++++++
5 files changed, 140 insertions(+), 8 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index fbca06e760..21d5c971d6 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -162,6 +162,12 @@ private CoreRules() {}
public static final AggregateJoinTransposeRule
AGGREGATE_JOIN_TRANSPOSE_EXTENDED =
AggregateJoinTransposeRule.Config.EXTENDED.toRule();
+ /** Rule that creates a {@link Join#isSemiJoin semi-join} from a
+ * {@link Aggregate} on top of a {@link Join} with an {@link Aggregate} as its
+ * right input. */
+ public static final SemiJoinRule.AggregateToSemiJoinRule AGGREGATE_TO_SEMI_JOIN =
+ SemiJoinRule.AggregateToSemiJoinRule.AggregateToSemiJoinRuleConfig.DEFAULT.toRule();
+
/** Rule that pushes an {@link Aggregate}
* past a non-distinct {@link Union}. */
public static final AggregateUnionTransposeRule AGGREGATE_UNION_TRANSPOSE =
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
index 4a10ec533a..427ea979bd 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
@@ -67,13 +67,15 @@ protected SemiJoinRule(Config config) {
super(config);
}
- protected void perform(RelOptRuleCall call, @Nullable Project project,
+ protected void perform(RelOptRuleCall call, @Nullable RelNode topRel,
Join join, RelNode left, Aggregate aggregate) {
final RelOptCluster cluster = join.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
- if (project != null) {
- final ImmutableBitSet bits =
- RelOptUtil.InputFinder.bits(project.getProjects(), null);
+ if (topRel != null) { // topRel is the Project or Aggregate matched above the join
+ final ImmutableBitSet bits = getUsedFields(topRel); // join-output fields read by topRel
+ if (bits.isEmpty()) { // NOTE(review): new early exit, also taken on the Project path (old code had no such check); presumably a topRel reading no join fields, e.g. a constant-only Project or COUNT(*) with no GROUP BY, is now left unrewritten - confirm intended
+ return;
+ }
final ImmutableBitSet rightBits =
ImmutableBitSet.range(left.getRowType().getFieldCount(),
join.getRowType().getFieldCount());
@@ -123,13 +125,72 @@ protected void perform(RelOptRuleCall call, @Nullable Project project,
default:
throw new AssertionError(join.getJoinType());
}
- if (project != null) {
- relBuilder.project(project.getProjects(), project.getRowType().getFieldNames());
+ if (topRel != null) {
+ if (topRel instanceof Project) {
+ Project topProject = (Project) topRel;
+ relBuilder.project(topProject.getProjects(), topProject.getRowType().getFieldNames());
+ } else if (topRel instanceof Aggregate) {
+ Aggregate topAgg = (Aggregate) topRel;
+ relBuilder.aggregate(
+ relBuilder.groupKey(topAgg.getGroupSet(), topAgg.getGroupSets()),
+ topAgg.getAggCallList());
+ } // NOTE(review): no else branch, so any other topRel type would be silently dropped from the rewritten plan
}
final RelNode relNode = relBuilder.build();
call.transformTo(relNode);
}
+ /** Returns a bit set of the input fields used by a relational expression. */
+ private static ImmutableBitSet getUsedFields(RelNode rel) {
+ final RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
+ return ImmutableBitSet.union(mq.getInputFieldsUsed(rel));
+ }
+
+ /** SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate
+ * as its right child.
+ *
+ * @see CoreRules#AGGREGATE_TO_SEMI_JOIN */
+ public static class AggregateToSemiJoinRule extends SemiJoinRule {
+ /** Creates an AggregateToSemiJoinRule. */
+ protected AggregateToSemiJoinRule(AggregateToSemiJoinRuleConfig config) {
+ super(config);
+ }
+
+ @Override public void onMatch(RelOptRuleCall call) { // operands: top Aggregate, Join, any left input, right Aggregate (see withOperandFor)
+ final Aggregate topAgg = call.rel(0);
+ final Join join = call.rel(1);
+ final RelNode left = call.rel(2);
+ final Aggregate rightAgg = call.rel(3);
+ perform(call, topAgg, join, left, rightAgg);
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable
+ public interface AggregateToSemiJoinRuleConfig extends SemiJoinRule.Config {
+ AggregateToSemiJoinRuleConfig DEFAULT = ImmutableAggregateToSemiJoinRuleConfig.of()
+ .withDescription("SemiJoinRule:aggregate")
+ .withOperandFor(Aggregate.class, Join.class, Aggregate.class);
+
+ @Override default AggregateToSemiJoinRule toRule() {
+ return new AggregateToSemiJoinRule(this);
+ }
+
+ /** Defines an operand tree for the given classes. */
+ default AggregateToSemiJoinRuleConfig withOperandFor(
+ Class<? extends Aggregate> topAggClass,
+ Class<? extends Join> joinClass,
+ Class<? extends Aggregate> rightAggClass) {
+ return withOperandSupplier(b ->
+ b.operand(topAggClass).oneInput(b2 ->
+ b2.operand(joinClass)
+ .predicate(SemiJoinRule::isJoinTypeSupported).inputs(
+ b3 -> b3.operand(RelNode.class).anyInputs(),
+ b4 -> b4.operand(rightAggClass).anyInputs())))
+ .as(AggregateToSemiJoinRuleConfig.class);
+ }
+ }
+ }
+
/** SemiJoinRule that matches a Project on top of a Join with an Aggregate
* as its right child.
*
@@ -251,8 +312,7 @@ protected JoinOnUniqueToSemiJoinRule(JoinOnUniqueToSemiJoinRuleConfig config) {
final Join join = call.rel(1);
final RelNode left = call.rel(2);
- final ImmutableBitSet bits =
- RelOptUtil.InputFinder.bits(project.getProjects(), null);
+ final ImmutableBitSet bits = getUsedFields(project); // NOTE(review): perform() now returns early when this set is empty; no such check appears in the lines shown here - confirm the two paths are meant to differ
final ImmutableBitSet rightBits =
ImmutableBitSet.range(left.getRowType().getFieldCount(),
join.getRowType().getFieldCount());
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index f46030dd8b..ff0e5ac9d2 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2098,6 +2098,19 @@ private void checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
.check();
}
+ /** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-5740">[CALCITE-5740]
+ * Support for AggToSemiJoinRule</a>. */
+ @Test void testAggregateToSemiJoinRule() {
+ final String sql = "select distinct emp.deptno from emp\n"
+ + "join (select distinct mgr from emp) d on emp.deptno = d.mgr";
+ sql(sql)
+ .withDecorrelate(true)
+ .withPreRule(CoreRules.AGGREGATE_PROJECT_MERGE)
+ .withRule(CoreRules.AGGREGATE_TO_SEMI_JOIN)
+ .check();
+ }
+
/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-1495">[CALCITE-1495]
* SemiJoinRule should not apply to RIGHT and FULL JOIN</a>. */
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c7c39f82a8..d33273b466 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -1130,6 +1130,29 @@ LogicalProject(MGR=[$0], SUM_SAL=[$2])
LogicalAggregate(group=[{0, 1}], SUM_SAL=[SUM($2)])
LogicalProject(MGR=[$3], DEPTNO=[$7], SAL=[$5])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testAggregateToSemiJoinRule">
+ <Resource name="sql">
+ <![CDATA[select distinct emp.deptno from emp
+join (select distinct mgr from emp) d on emp.deptno = d.mgr]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalAggregate(group=[{7}])
+ LogicalJoin(condition=[=($7, $9)], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{3}])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalAggregate(group=[{7}])
+ LogicalJoin(condition=[=($7, $12)], joinType=[semi])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
diff --git a/core/src/test/resources/sql/hep.iq b/core/src/test/resources/sql/hep.iq
index 8dd5302347..6e1146c90d 100644
--- a/core/src/test/resources/sql/hep.iq
+++ b/core/src/test/resources/sql/hep.iq
@@ -238,4 +238,34 @@ EnumerableHashJoin(condition=[AND(=($0, $6), OR(AND(>($1, 11), <=($7, 32)), AND(
!}
!set hep-rules original
+# [CALCITE-5740] Support for AggToSemiJoinRule
+!set hep-rules "
++CoreRules.AGGREGATE_PROJECT_MERGE,
++CoreRules.AGGREGATE_TO_SEMI_JOIN"
+
+select dept.deptno, count(*)
+from dept join (
+ select distinct deptno from emp
+ where sal > 100) using (deptno)
+group by dept.deptno;
++--------+--------+
+| DEPTNO | EXPR$1 |
++--------+--------+
+| 10 | 1 |
+| 20 | 1 |
+| 30 | 1 |
++--------+--------+
+(3 rows)
+
+!ok
+EnumerableAggregate(group=[{0}], EXPR$1=[COUNT()])
+ EnumerableHashJoin(condition=[=($0, $3)], joinType=[semi])
+ EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+ EnumerableTableScan(table=[[scott, DEPT]])
+ EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[100.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7], $condition=[$t10])
+ EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+!set hep-rules original
+
# End hep.iq