This is an automated email from the ASF dual-hosted git repository.

zhenchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 3dc7feae24 [CALCITE-5740] Support for AggToSemiJoinRule
3dc7feae24 is described below

commit 3dc7feae24d3b2a39b3ca0a163de58df7f88c854
Author: Zhen Chen <[email protected]>
AuthorDate: Wed Jan 21 14:09:22 2026 +0800

    [CALCITE-5740] Support for AggToSemiJoinRule
---
 .../org/apache/calcite/rel/rules/CoreRules.java    |  6 ++
 .../org/apache/calcite/rel/rules/SemiJoinRule.java | 76 +++++++++++++++++++---
 .../org/apache/calcite/test/RelOptRulesTest.java   | 13 ++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 23 +++++++
 core/src/test/resources/sql/hep.iq                 | 30 +++++++++
 5 files changed, 140 insertions(+), 8 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java 
b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index fbca06e760..21d5c971d6 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -162,6 +162,12 @@ private CoreRules() {}
   public static final AggregateJoinTransposeRule 
AGGREGATE_JOIN_TRANSPOSE_EXTENDED =
       AggregateJoinTransposeRule.Config.EXTENDED.toRule();
 
+  /** Rule that creates a {@link Join#isSemiJoin semi-join} from a
+   * {@link Aggregate} on top of a {@link Join} with an {@link Aggregate} as 
its
+   * right input. */
+  public static final SemiJoinRule.AggregateToSemiJoinRule 
AGGREGATE_TO_SEMI_JOIN =
+      
SemiJoinRule.AggregateToSemiJoinRule.AggregateToSemiJoinRuleConfig.DEFAULT.toRule();
+
   /** Rule that pushes an {@link Aggregate}
    * past a non-distinct {@link Union}. */
   public static final AggregateUnionTransposeRule AGGREGATE_UNION_TRANSPOSE =
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java 
b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
index 4a10ec533a..427ea979bd 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/SemiJoinRule.java
@@ -67,13 +67,15 @@ protected SemiJoinRule(Config config) {
     super(config);
   }
 
-  protected void perform(RelOptRuleCall call, @Nullable Project project,
+  protected void perform(RelOptRuleCall call, @Nullable RelNode topRel,
       Join join, RelNode left, Aggregate aggregate) {
     final RelOptCluster cluster = join.getCluster();
     final RexBuilder rexBuilder = cluster.getRexBuilder();
-    if (project != null) {
-      final ImmutableBitSet bits =
-          RelOptUtil.InputFinder.bits(project.getProjects(), null);
+    if (topRel != null) {
+      final ImmutableBitSet bits = getUsedFields(topRel);
+      if (bits.isEmpty()) {
+        return;
+      }
       final ImmutableBitSet rightBits =
           ImmutableBitSet.range(left.getRowType().getFieldCount(),
               join.getRowType().getFieldCount());
@@ -123,13 +125,72 @@ protected void perform(RelOptRuleCall call, @Nullable 
Project project,
     default:
       throw new AssertionError(join.getJoinType());
     }
-    if (project != null) {
-      relBuilder.project(project.getProjects(), 
project.getRowType().getFieldNames());
+    if (topRel != null) {
+      if (topRel instanceof Project) {
+        Project topProject = (Project) topRel;
+        relBuilder.project(topProject.getProjects(), 
topProject.getRowType().getFieldNames());
+      } else if (topRel instanceof Aggregate) {
+        Aggregate topAgg = (Aggregate) topRel;
+        relBuilder.aggregate(
+            relBuilder.groupKey(topAgg.getGroupSet(), topAgg.getGroupSets()),
+            topAgg.getAggCallList());
+      }
     }
     final RelNode relNode = relBuilder.build();
     call.transformTo(relNode);
   }
 
+  /** Returns a bit set of the input fields used by a relational expression. */
+  private static ImmutableBitSet getUsedFields(RelNode rel) {
+    final RelMetadataQuery mq = rel.getCluster().getMetadataQuery();
+    return ImmutableBitSet.union(mq.getInputFieldsUsed(rel));
+  }
+
+  /** SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate
+   * as its right child.
+   *
+   * @see CoreRules#AGGREGATE_TO_SEMI_JOIN */
+  public static class AggregateToSemiJoinRule extends SemiJoinRule {
+    /** Creates an AggregateToSemiJoinRule. */
+    protected AggregateToSemiJoinRule(AggregateToSemiJoinRuleConfig config) {
+      super(config);
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final Aggregate topAgg = call.rel(0);
+      final Join join = call.rel(1);
+      final RelNode left = call.rel(2);
+      final Aggregate rightAgg = call.rel(3);
+      perform(call, topAgg, join, left, rightAgg);
+    }
+
+    /** Rule configuration. */
+    @Value.Immutable
+    public interface AggregateToSemiJoinRuleConfig extends SemiJoinRule.Config 
{
+      AggregateToSemiJoinRuleConfig DEFAULT = 
ImmutableAggregateToSemiJoinRuleConfig.of()
+          .withDescription("SemiJoinRule:aggregate")
+          .withOperandFor(Aggregate.class, Join.class, Aggregate.class);
+
+      @Override default AggregateToSemiJoinRule toRule() {
+        return new AggregateToSemiJoinRule(this);
+      }
+
+      /** Defines an operand tree for the given classes. */
+      default AggregateToSemiJoinRuleConfig withOperandFor(
+          Class<? extends Aggregate> topAggClass,
+          Class<? extends Join> joinClass,
+          Class<? extends Aggregate> rightAggClass) {
+        return withOperandSupplier(b ->
+            b.operand(topAggClass).oneInput(b2 ->
+                b2.operand(joinClass)
+                    .predicate(SemiJoinRule::isJoinTypeSupported).inputs(
+                        b3 -> b3.operand(RelNode.class).anyInputs(),
+                        b4 -> b4.operand(rightAggClass).anyInputs())))
+            .as(AggregateToSemiJoinRuleConfig.class);
+      }
+    }
+  }
+
   /** SemiJoinRule that matches a Project on top of a Join with an Aggregate
    * as its right child.
    *
@@ -251,8 +312,7 @@ protected 
JoinOnUniqueToSemiJoinRule(JoinOnUniqueToSemiJoinRuleConfig config) {
       final Join join = call.rel(1);
       final RelNode left = call.rel(2);
 
-      final ImmutableBitSet bits =
-          RelOptUtil.InputFinder.bits(project.getProjects(), null);
+      final ImmutableBitSet bits = getUsedFields(project);
       final ImmutableBitSet rightBits =
           ImmutableBitSet.range(left.getRowType().getFieldCount(),
               join.getRowType().getFieldCount());
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java 
b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index f46030dd8b..ff0e5ac9d2 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -2098,6 +2098,19 @@ private void 
checkSemiOrAntiJoinProjectTranspose(JoinRelType type) {
         .check();
   }
 
+  /** Test case for
+   * <a 
href="https://issues.apache.org/jira/browse/CALCITE-5740">[CALCITE-5740]
+   * Support for AggToSemiJoinRule</a>. */
+  @Test void testAggregateToSemiJoinRule() {
+    final String sql = "select distinct emp.deptno from emp\n"
+        + "join (select distinct mgr from emp) d on emp.deptno = d.mgr";
+    sql(sql)
+        .withDecorrelate(true)
+        .withPreRule(CoreRules.AGGREGATE_PROJECT_MERGE)
+        .withRule(CoreRules.AGGREGATE_TO_SEMI_JOIN)
+        .check();
+  }
+
   /** Test case for
    * <a 
href="https://issues.apache.org/jira/browse/CALCITE-1495">[CALCITE-1495]
    * SemiJoinRule should not apply to RIGHT and FULL JOIN</a>. */
diff --git 
a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c7c39f82a8..d33273b466 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -1130,6 +1130,29 @@ LogicalProject(MGR=[$0], SUM_SAL=[$2])
     LogicalAggregate(group=[{0, 1}], SUM_SAL=[SUM($2)])
       LogicalProject(MGR=[$3], DEPTNO=[$7], SAL=[$5])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testAggregateToSemiJoinRule">
+    <Resource name="sql">
+      <![CDATA[select distinct emp.deptno from emp
+join (select distinct mgr from emp) d on emp.deptno = d.mgr]]>
+    </Resource>
+    <Resource name="planBefore">
+      <![CDATA[
+LogicalAggregate(group=[{7}])
+  LogicalJoin(condition=[=($7, $9)], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{3}])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+    </Resource>
+    <Resource name="planAfter">
+      <![CDATA[
+LogicalAggregate(group=[{7}])
+  LogicalJoin(condition=[=($7, $12)], joinType=[semi])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
     </Resource>
   </TestCase>
diff --git a/core/src/test/resources/sql/hep.iq 
b/core/src/test/resources/sql/hep.iq
index 8dd5302347..6e1146c90d 100644
--- a/core/src/test/resources/sql/hep.iq
+++ b/core/src/test/resources/sql/hep.iq
@@ -238,4 +238,34 @@ EnumerableHashJoin(condition=[AND(=($0, $6), OR(AND(>($1, 
11), <=($7, 32)), AND(
 !}
 !set hep-rules original
 
+# [CALCITE-5740] Support for AggToSemiJoinRule
+!set hep-rules "
++CoreRules.AGGREGATE_PROJECT_MERGE,
++CoreRules.AGGREGATE_TO_SEMI_JOIN"
+
+select dept.deptno, count(*)
+from dept join (
+  select distinct deptno from emp
+  where sal > 100) using (deptno)
+group by dept.deptno;
++--------+--------+
+| DEPTNO | EXPR$1 |
++--------+--------+
+|     10 |      1 |
+|     20 |      1 |
+|     30 |      1 |
++--------+--------+
+(3 rows)
+
+!ok
+EnumerableAggregate(group=[{0}], EXPR$1=[COUNT()])
+  EnumerableHashJoin(condition=[=($0, $3)], joinType=[semi])
+    EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+      EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], 
expr#9=[100.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], SAL=[$t5], 
DEPTNO=[$t7], $condition=[$t10])
+      EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+!set hep-rules original
+
 # End hep.iq

Reply via email to