This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new c50ed0db23 Remove PinotAggregateToSemiJoinRule which can mistakenly remove DISTINCT from IN clause (#14719)
c50ed0db23 is described below

commit c50ed0db23a8b920bf7cf8fcbb293477344223b3
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Dec 25 22:14:18 2024 -0800

    Remove PinotAggregateToSemiJoinRule which can mistakenly remove DISTINCT from IN clause (#14719)
---
 .../rel/rules/PinotAggregateToSemiJoinRule.java    | 132 ---------------------
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   1 -
 .../src/test/resources/queries/JoinPlans.json      |  72 ++++++++++-
 .../test/resources/queries/PinotHintablePlans.json |  52 ++++++++
 4 files changed, 122 insertions(+), 135 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java
deleted file mode 100644
index 327921df71..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java
+++ /dev/null
@@ -1,132 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.calcite.rel.rules;
-
-import java.util.ArrayList;
-import java.util.List;
-import javax.annotation.Nullable;
-import org.apache.calcite.plan.RelOptCluster;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.RelOptUtil;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.Join;
-import org.apache.calcite.rel.core.JoinInfo;
-import org.apache.calcite.rel.rules.CoreRules;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.tools.RelBuilder;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.ImmutableIntList;
-
-
-/**
- * SemiJoinRule that matches an Aggregate on top of a Join with an Aggregate 
as its right child.
- *
- * @see CoreRules#PROJECT_TO_SEMI_JOIN
- */
-public class PinotAggregateToSemiJoinRule extends RelOptRule {
-  public static final PinotAggregateToSemiJoinRule INSTANCE =
-      new PinotAggregateToSemiJoinRule(PinotRuleUtils.PINOT_REL_FACTORY);
-
-  public PinotAggregateToSemiJoinRule(RelBuilderFactory factory) {
-    super(operand(Aggregate.class,
-            some(operand(Join.class, some(operand(RelNode.class, any()), 
operand(Aggregate.class, any()))))), factory,
-        null);
-  }
-
-  @Override
-  public void onMatch(RelOptRuleCall call) {
-    final Aggregate topAgg = call.rel(0);
-    final Join join = (Join) PinotRuleUtils.unboxRel(topAgg.getInput());
-    final RelNode left = PinotRuleUtils.unboxRel(join.getInput(0));
-    final Aggregate rightAgg = (Aggregate) 
PinotRuleUtils.unboxRel(join.getInput(1));
-    perform(call, topAgg, join, left, rightAgg);
-  }
-
-
-  protected void perform(RelOptRuleCall call, @Nullable Aggregate topAgg,
-      Join join, RelNode left, Aggregate rightAgg) {
-    final RelOptCluster cluster = join.getCluster();
-    final RexBuilder rexBuilder = cluster.getRexBuilder();
-    if (topAgg != null) {
-      final ImmutableBitSet aggBits = 
ImmutableBitSet.of(RelOptUtil.getAllFields(topAgg));
-      final ImmutableBitSet rightBits =
-          ImmutableBitSet.range(left.getRowType().getFieldCount(),
-              join.getRowType().getFieldCount());
-      if (aggBits.intersects(rightBits)) {
-        return;
-      }
-    } else {
-      if (join.getJoinType().projectsRight()
-          && !isEmptyAggregate(rightAgg)) {
-        return;
-      }
-    }
-    final JoinInfo joinInfo = join.analyzeCondition();
-    if (!joinInfo.rightSet().equals(
-        ImmutableBitSet.range(rightAgg.getGroupCount()))) {
-      // Rule requires that aggregate key to be the same as the join key.
-      // By the way, neither a super-set nor a sub-set would work.
-      return;
-    }
-    if (!joinInfo.isEqui()) {
-      return;
-    }
-    final RelBuilder relBuilder = call.builder();
-    relBuilder.push(left);
-    switch (join.getJoinType()) {
-      case SEMI:
-      case INNER:
-        final List<Integer> newRightKeyBuilder = new ArrayList<>();
-        final List<Integer> aggregateKeys = rightAgg.getGroupSet().asList();
-        for (int key : joinInfo.rightKeys) {
-          newRightKeyBuilder.add(aggregateKeys.get(key));
-        }
-        final ImmutableIntList newRightKeys = 
ImmutableIntList.copyOf(newRightKeyBuilder);
-        relBuilder.push(rightAgg.getInput());
-        final RexNode newCondition =
-            RelOptUtil.createEquiJoinCondition(relBuilder.peek(2, 0),
-                joinInfo.leftKeys, relBuilder.peek(2, 1), newRightKeys,
-                rexBuilder);
-        relBuilder.semiJoin(newCondition).hints(join.getHints());
-        break;
-
-      case LEFT:
-        // The right-hand side produces no more than 1 row (because of the
-        // Aggregate) and no fewer than 1 row (because of LEFT), and therefore
-        // we can eliminate the semi-join.
-        break;
-
-      default:
-        throw new AssertionError(join.getJoinType());
-    }
-    if (topAgg != null) {
-      relBuilder.aggregate(relBuilder.groupKey(topAgg.getGroupSet()), 
topAgg.getAggCallList());
-    }
-    final RelNode relNode = relBuilder.build();
-    call.transformTo(relNode);
-  }
-
-  private static boolean isEmptyAggregate(Aggregate aggregate) {
-    return aggregate.getRowType().getFieldCount() == 0;
-  }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
index fdb75ee78f..e831e7460a 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
@@ -73,7 +73,6 @@ public class PinotQueryRuleSets {
 
       // join and semi-join rules
       CoreRules.PROJECT_TO_SEMI_JOIN,
-      PinotAggregateToSemiJoinRule.INSTANCE,
 
       // convert non-all union into all-union + distinct
       CoreRules.UNION_TO_DISTINCT,
diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
index fb63399fac..d48795dc30 100644
--- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
@@ -111,7 +111,7 @@
       },
       {
         "description": "Inner join with group by",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON 
a.col1 = b.col2  WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY 
a.col1",
+        "sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON 
a.col1 = b.col2 WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY 
a.col1",
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, 
$2)])",
@@ -222,6 +222,21 @@
       },
       {
         "description": "Semi join with IN clause",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col3 IN 
(SELECT col3 FROM b)",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], col2=[$1])",
+          "\n  LogicalJoin(condition=[=($2, $3)], joinType=[semi])",
+          "\n    LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n      LogicalTableScan(table=[[default, a]])",
+          "\n    PinotLogicalExchange(distribution=[broadcast], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n      LogicalProject(col3=[$2])",
+          "\n        LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Semi join with IN clause and join strategy override",
         "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 
'hash') */ col1, col2 FROM a WHERE col3 IN (SELECT col3 FROM b)",
         "output": [
           "Execution Plan",
@@ -237,7 +252,60 @@
         ]
       },
       {
-        "description": "Semi join with multiple IN clause",
+        "description": "Semi join with IN clause on distinct values",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col3 IN 
(SELECT DISTINCT col3 FROM b)",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$0], col2=[$1])",
+          "\n  LogicalJoin(condition=[=($2, $3)], joinType=[semi])",
+          "\n    LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n      LogicalTableScan(table=[[default, a]])",
+          "\n    PinotLogicalExchange(distribution=[broadcast], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n      PinotLogicalAggregate(group=[{0}], aggType=[FINAL])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          PinotLogicalAggregate(group=[{2}], aggType=[LEAF])",
+          "\n            LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Semi join with IN clause then aggregate with group by",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, SUM(col6) FROM a WHERE col3 IN 
(SELECT col3 FROM b) GROUP BY col1",
+        "output": [
+          "Execution Plan",
+          "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[FINAL])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], 
aggType=[LEAF])",
+          "\n      LogicalJoin(condition=[=($1, $3)], joinType=[semi])",
+          "\n        LogicalProject(col1=[$0], col3=[$2], col6=[$5])",
+          "\n          LogicalTableScan(table=[[default, a]])",
+          "\n        PinotLogicalExchange(distribution=[broadcast], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n          LogicalProject(col3=[$2])",
+          "\n            LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Semi join with IN clause of distinct values then 
aggregate with group by",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, SUM(col6) FROM a WHERE col3 IN 
(SELECT DISTINCT col3 FROM b) GROUP BY col1",
+        "output": [
+          "Execution Plan",
+          "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[FINAL])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], 
aggType=[LEAF])",
+          "\n      LogicalJoin(condition=[=($1, $3)], joinType=[semi])",
+          "\n        LogicalProject(col1=[$0], col3=[$2], col6=[$5])",
+          "\n          LogicalTableScan(table=[[default, a]])",
+          "\n        PinotLogicalExchange(distribution=[broadcast], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n          PinotLogicalAggregate(group=[{0}], aggType=[FINAL])",
+          "\n            PinotLogicalExchange(distribution=[hash[0]])",
+          "\n              PinotLogicalAggregate(group=[{2}], aggType=[LEAF])",
+          "\n                LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Semi join with multiple IN clause and join strategy 
override",
         "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(join_strategy = 
'hash') */ col1, col2 FROM a WHERE col2 = 'test' AND col3 IN (SELECT col3 FROM 
b WHERE col1='foo') AND col3 IN (SELECT col3 FROM b WHERE col1='bar') AND col3 
IN (SELECT col3 FROM b WHERE col1='foobar')",
         "output": [
           "Execution Plan",
diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
index f26a133016..998bf05606 100644
--- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
@@ -293,6 +293,58 @@
           "\n"
         ]
       },
+      {
+        "description": "agg + semi-join on colocated tables then group by on 
partition column with join and agg hint",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
joinOptions(is_colocated_by_join_keys='true'), 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM 
a /*+ tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b /*+ 
tableOptions(partition_function='hashcode', partition_key='col1', 
partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1",
+        "output": [
+          "Execution Plan",
+          "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[DIRECT])",
+          "\n  LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
+          "\n    LogicalProject(col2=[$1], col3=[$2])",
+          "\n      LogicalTableScan(table=[[default, a]])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n      LogicalProject(col1=[$0])",
+          "\n        LogicalFilter(condition=[>($2, 0)])",
+          "\n          LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "agg + semi-join with distinct values on colocated 
tables then group by on partition column",
+        "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ 
tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT DISTINCT col1 FROM b /*+ 
tableOptions(partition_function='hashcode', partition_key='col1', 
partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1",
+        "output": [
+          "Execution Plan",
+          "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[FINAL])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[LEAF])",
+          "\n      LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
+          "\n        LogicalProject(col2=[$1], col3=[$2])",
+          "\n          LogicalTableScan(table=[[default, a]])",
+          "\n        PinotLogicalExchange(distribution=[broadcast], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n          PinotLogicalAggregate(group=[{0}], aggType=[FINAL])",
+          "\n            PinotLogicalExchange(distribution=[hash[0]])",
+          "\n              PinotLogicalAggregate(group=[{0}], aggType=[LEAF])",
+          "\n                LogicalFilter(condition=[>($2, 0)])",
+          "\n                  LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "agg + semi-join with distinct values on colocated 
tables then group by on partition column with join and agg hint",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ 
joinOptions(is_colocated_by_join_keys='true'), 
aggOptions(is_partitioned_by_group_by_keys='true') */ a.col2, SUM(a.col3) FROM 
a /*+ tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT DISTINCT col1 FROM b /*+ 
tableOptions(partition_function='hashcode', partition_key='col1', 
partition_size='4') */ WHERE b.col3 > 0) GROUP BY 1",
+        "output": [
+          "Execution Plan",
+          "\nPinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], 
aggType=[DIRECT])",
+          "\n  LogicalJoin(condition=[=($0, $2)], joinType=[semi])",
+          "\n    LogicalProject(col2=[$1], col3=[$2])",
+          "\n      LogicalTableScan(table=[[default, a]])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]], 
relExchangeType=[PIPELINE_BREAKER])",
+          "\n      PinotLogicalAggregate(group=[{0}], aggType=[DIRECT])",
+          "\n        LogicalFilter(condition=[>($2, 0)])",
+          "\n          LogicalTableScan(table=[[default, b]])",
+          "\n"
+        ]
+      },
       {
         "description": "agg + semi-join on pre-partitioned main tables then 
group by on partition column",
         "sql": "EXPLAIN PLAN FOR SELECT a.col2, SUM(a.col3) FROM a /*+ 
tableOptions(partition_function='hashcode', partition_key='col2', 
partition_size='4') */ WHERE a.col2 IN (SELECT col1 FROM b WHERE b.col3 > 0) 
GROUP BY 1",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to