This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 11b6bcda50 [multistage] Fix Predicate Pushdown by Using Rule 
Collection (#10409)
11b6bcda50 is described below

commit 11b6bcda50469898e96a03aa6a0c92bbcbe6463c
Author: Ankit Sultana <[email protected]>
AuthorDate: Wed Mar 15 23:02:43 2023 +0530

    [multistage] Fix Predicate Pushdown by Using Rule Collection (#10409)
    
    * [multistage] Fix Predicate Pushdown by Using Rule Collection
    
    * Fix tests && minor refactors
    * Add tests for IN/Not-In
    * Address feedback
---
 .../calcite/rel/rules/PinotQueryRuleSets.java      | 66 ++++++++-------
 .../org/apache/pinot/query/QueryEnvironment.java   | 25 ++++--
 .../src/test/resources/queries/AggregatePlans.json | 60 +++++++-------
 .../src/test/resources/queries/GroupByPlans.json   | 96 ++++++++++------------
 .../src/test/resources/queries/JoinPlans.json      | 74 +++++++++++++++++
 .../src/test/resources/queries/Case.json           |  6 +-
 6 files changed, 204 insertions(+), 123 deletions(-)

diff --git 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 06b70d4e04..b393f98417 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -18,6 +18,7 @@
  */
 package org.apache.calcite.rel.rules;
 
+import com.google.common.collect.ImmutableList;
 import java.util.Arrays;
 import java.util.Collection;
 import org.apache.calcite.adapter.enumerable.EnumerableRules;
@@ -32,14 +33,11 @@ public class PinotQueryRuleSets {
     // do not instantiate.
   }
 
-  public static final Collection<RelOptRule> LOGICAL_OPT_RULES =
+  public static final Collection<RelOptRule> BASIC_RULES =
       Arrays.asList(EnumerableRules.ENUMERABLE_FILTER_RULE, 
EnumerableRules.ENUMERABLE_JOIN_RULE,
           EnumerableRules.ENUMERABLE_PROJECT_RULE, 
EnumerableRules.ENUMERABLE_WINDOW_RULE,
           EnumerableRules.ENUMERABLE_SORT_RULE, 
EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE,
 
-          // ------------------------------------------------------------------
-          // Calcite core rules
-
           // push a filter into a join
           CoreRules.FILTER_INTO_JOIN,
           // push filter through an aggregation
@@ -90,29 +88,39 @@ public class PinotQueryRuleSets {
           PinotReduceAggregateFunctionsRule.INSTANCE,
           CoreRules.AGGREGATE_REDUCE_FUNCTIONS,
 
-          // remove unnecessary sort rule
-          CoreRules.SORT_REMOVE,
-
-          // prune empty results rules
-          PruneEmptyRules.AGGREGATE_INSTANCE, PruneEmptyRules.FILTER_INSTANCE, 
PruneEmptyRules.JOIN_LEFT_INSTANCE,
-          PruneEmptyRules.JOIN_RIGHT_INSTANCE, 
PruneEmptyRules.PROJECT_INSTANCE, PruneEmptyRules.SORT_INSTANCE,
-          PruneEmptyRules.UNION_INSTANCE,
-
-          // ------------------------------------------------------------------
-          // Pinot specific rules
-          // ------------------------------------------------------------------
-
-          // ---- rules apply before exchange insertion.
-          PinotFilterExpandSearchRule.INSTANCE,
-
-          // ---- rules that insert exchange.
-          // add an extra exchange for sort
-          PinotSortExchangeNodeInsertRule.INSTANCE,
-          // copy exchanges down, this must be done after 
SortExchangeNodeInsertRule
-          PinotSortExchangeCopyRule.SORT_EXCHANGE_COPY,
-
-          PinotJoinExchangeNodeInsertRule.INSTANCE,
-          PinotAggregateExchangeNodeInsertRule.INSTANCE,
-          PinotWindowExchangeNodeInsertRule.INSTANCE
-      );
+          // Expand all SEARCH nodes to simplified filter nodes. SEARCH nodes 
get created for queries with range
+          // predicates, in-clauses, etc.
+          PinotFilterExpandSearchRule.INSTANCE
+          );
+
+  // Filter pushdown rules run using a RuleCollection since we want to push 
down a filter as much as possible in a
+  // single HepInstruction.
+  public static final Collection<RelOptRule> FILTER_PUSHDOWN_RULES = 
ImmutableList.of(
+      CoreRules.FILTER_INTO_JOIN,
+      CoreRules.FILTER_AGGREGATE_TRANSPOSE,
+      CoreRules.FILTER_SET_OP_TRANSPOSE,
+      CoreRules.FILTER_PROJECT_TRANSPOSE
+  );
+
+  // The pruner rules run together as one rule collection, currently with DEPTH_FIRST match order. TOP_DOWN, which 
makes Calcite restart from the root node after applying a transformation, is left as a TODO.
+  public static final Collection<RelOptRule> PRUNE_RULES = ImmutableList.of(
+      CoreRules.PROJECT_MERGE,
+      CoreRules.AGGREGATE_REMOVE,
+      CoreRules.SORT_REMOVE,
+      PruneEmptyRules.AGGREGATE_INSTANCE, PruneEmptyRules.FILTER_INSTANCE, 
PruneEmptyRules.JOIN_LEFT_INSTANCE,
+      PruneEmptyRules.JOIN_RIGHT_INSTANCE, PruneEmptyRules.PROJECT_INSTANCE, 
PruneEmptyRules.SORT_INSTANCE,
+      PruneEmptyRules.UNION_INSTANCE
+  );
+
+  // Pinot specific rules that should be run after all other rules
+  public static final Collection<RelOptRule> PINOT_POST_RULES = 
ImmutableList.of(
+      // add an extra exchange for sort
+      PinotSortExchangeNodeInsertRule.INSTANCE,
+      // copy exchanges down, this must be done after 
SortExchangeNodeInsertRule
+      PinotSortExchangeCopyRule.SORT_EXCHANGE_COPY,
+
+      PinotJoinExchangeNodeInsertRule.INSTANCE,
+      PinotAggregateExchangeNodeInsertRule.INSTANCE,
+      PinotWindowExchangeNodeInsertRule.INSTANCE
+  );
 }
diff --git 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
index 936c25460d..472295a51c 100644
--- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
+++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java
@@ -20,7 +20,6 @@ package org.apache.pinot.query;
 
 import com.google.common.annotations.VisibleForTesting;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Properties;
 import org.apache.calcite.config.CalciteConnectionConfigImpl;
 import org.apache.calcite.config.CalciteConnectionProperty;
@@ -28,6 +27,7 @@ import org.apache.calcite.jdbc.CalciteSchema;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.hep.HepMatchOrder;
 import org.apache.calcite.plan.hep.HepProgram;
 import org.apache.calcite.plan.hep.HepProgramBuilder;
 import org.apache.calcite.prepare.PinotCalciteCatalogReader;
@@ -79,7 +79,6 @@ public class QueryEnvironment {
   private final HepProgram _hepProgram;
 
   // Pinot extensions
-  private final Collection<RelOptRule> _logicalRuleSet;
   private final WorkerManager _workerManager;
   private final TableCache _tableCache;
 
@@ -110,12 +109,24 @@ public class QueryEnvironment {
             .addRelBuilderConfigTransform(c -> c.withAggregateUnique(true)))
         .build();
 
-    // optimizer rules
-    _logicalRuleSet = PinotQueryRuleSets.LOGICAL_OPT_RULES;
-
-    // optimizer
     HepProgramBuilder hepProgramBuilder = new HepProgramBuilder();
-    for (RelOptRule relOptRule : _logicalRuleSet) {
+    // Set the match order as DEPTH_FIRST. The default is arbitrary which 
works the same as DEPTH_FIRST, but it's
+    // best to be explicit.
+    hepProgramBuilder.addMatchOrder(HepMatchOrder.DEPTH_FIRST);
+    // First run the basic rules using 1 HepInstruction per rule. We use 1 
HepInstruction per rule for simplicity:
+    // the rules used here can rest assured that they are the only ones 
evaluated in a dedicated graph-traversal.
+    for (RelOptRule relOptRule : PinotQueryRuleSets.BASIC_RULES) {
+      hepProgramBuilder.addRuleInstance(relOptRule);
+    }
+    // Pushdown filters using a single HepInstruction.
+    
hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.FILTER_PUSHDOWN_RULES);
+
+    // Prune duplicate/unnecessary nodes using a single HepInstruction.
+    // TODO: We can consider using HepMatchOrder.TOP_DOWN if we find cases 
where it would help.
+    hepProgramBuilder.addRuleCollection(PinotQueryRuleSets.PRUNE_RULES);
+
+    // Run pinot specific rules that should run after all other rules, using 1 
HepInstruction per rule.
+    for (RelOptRule relOptRule : PinotQueryRuleSets.PINOT_POST_RULES) {
       hepProgramBuilder.addRuleInstance(relOptRule);
     }
     _hepProgram = hepProgramBuilder.build();
diff --git a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json 
b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
index 3e7b7168dd..6171c79be6 100644
--- a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
@@ -6,14 +6,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg FROM a WHERE a.col3 
>= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[/($0, $1)])",
-          "\n  LogicalProject($f0=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], 
$f1=[$1])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
-          "\n      LogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[COUNT()])",
-          "\n          LogicalProject(col4=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), 
$1)])",
+          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
+          "\n    LogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[COUNT()])",
+          "\n        LogicalProject(col4=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -22,14 +21,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg, SUM(a.col4) as 
sum, MAX(a.col4) as max FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[/($0, $1)], sum=[CASE(=($1, 0), 
null:DECIMAL(1000, 0), $2)], max=[$3])",
-          "\n  LogicalProject($f0=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], 
$f1=[$1], sum=[$0], max=[$2])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)], max=[MAX($2)])",
-          "\n      LogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[COUNT()], max=[MAX($0)])",
-          "\n          LogicalProject(col4=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), 
$1)], sum=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], max=[$2])",
+          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)], max=[MAX($2)])",
+          "\n    LogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[COUNT()], max=[MAX($0)])",
+          "\n        LogicalProject(col4=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -38,14 +36,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col3) as avg, COUNT(*) as count 
FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[/(CAST($0):DOUBLE, $1)], count=[$1])",
-          "\n  LogicalProject($f0=[CASE(=($1, 0), null:INTEGER, $0)], 
$f1=[$1])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
-          "\n      LogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
-          "\n          LogicalProject(col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, 
$0)):DOUBLE, $1)], count=[$1])",
+          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
+          "\n    LogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -97,14 +94,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 
'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[/(CAST($0):DOUBLE, $1)], count=[$1])",
-          "\n  LogicalProject($f0=[CASE(=($1, 0), null:INTEGER, $0)], 
$f1=[$1])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
-          "\n      LogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
-          "\n          LogicalProject(col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, 
$0)):DOUBLE, $1)], count=[$1])",
+          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], 
agg#1=[$SUM0($1)])",
+          "\n    LogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], 
agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink 
floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json 
b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index 588b164b92..d1b6f5c811 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -18,12 +18,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3), AVG(a.col3), 
MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($2):DOUBLE 
NOT NULL, $3)], EXPR$3=[$4], EXPR$4=[$5])",
-          "\n  LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$1], $f3=[$2], 
EXPR$3=[$3], EXPR$4=[$4])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[$SUM0($2)], EXPR$3=[MAX($3)], EXPR$4=[MIN($4)])",
-          "\n      LogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalAggregate(group=[{3}], EXPR$1=[$SUM0($2)], 
agg#1=[COUNT()], EXPR$3=[MAX($2)], EXPR$4=[MIN($2)])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE 
NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+          "\n  LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[$SUM0($2)], EXPR$3=[MAX($3)], EXPR$4=[MIN($4)])",
+          "\n    LogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalAggregate(group=[{3}], EXPR$1=[$SUM0($2)], 
agg#1=[COUNT()], EXPR$3=[MAX($2)], EXPR$4=[MIN($2)])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -76,15 +75,13 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
-          "\n    LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], 
$f3=[$3], $f4=[$4], $f5=[$2])",
-          "\n        LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
EXPR$2=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n          LogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
-          "\n              LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
-          "\n                LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
-          "\n                  LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
EXPR$2=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
+          "\n      LogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{2}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n          LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -94,15 +91,13 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
-          "\n    LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
-          "\n      LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[$2])",
-          "\n        LogicalAggregate(group=[{0}], count=[$SUM0($1)], 
SUM=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n          LogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
-          "\n              LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
-          "\n                LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
-          "\n                  LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{0}], count=[$SUM0($1)], 
SUM=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
+          "\n      LogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{2}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n          LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -136,12 +131,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ skipLeafStageGroupByAggregation */ 
a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY 
a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($2):DOUBLE 
NOT NULL, $3)], EXPR$3=[$4], EXPR$4=[$5])",
-          "\n  LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$1], $f3=[$2], 
EXPR$3=[$3], EXPR$4=[$4])",
-          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
-          "\n      LogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalProject(col1=[$3], col3=[$2])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE 
NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+          "\n  LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], 
agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+          "\n    LogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalProject(col1=[$3], col3=[$2])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -204,14 +198,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
-          "\n    LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2], 
$f3=[$3], $f4=[$4], $f5=[$2])",
-          "\n        LogicalAggregate(group=[{2}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
-          "\n          LogicalExchange(distribution=[hash[2]])",
-          "\n            LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{2}], EXPR$1=[COUNT()], 
EXPR$2=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n      LogicalExchange(distribution=[hash[2]])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -221,14 +213,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1])",
-          "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), 
=($4, 5))])",
-          "\n    LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$2], $f3=[$3], 
$f4=[/(CAST($4):DOUBLE NOT NULL, $5)])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$1], $f2=[$2], $f3=[$3], 
$f4=[$1], $f5=[$4])",
-          "\n        LogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)], 
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
-          "\n          LogicalExchange(distribution=[hash[2]])",
-          "\n            LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), 
=(/(CAST($1):DOUBLE NOT NULL, $4), 5))])",
+          "\n    LogicalAggregate(group=[{2}], EXPR$1=[$SUM0($1)], 
agg#1=[MAX($1)], agg#2=[MIN($1)], agg#3=[COUNT()])",
+          "\n      LogicalExchange(distribution=[hash[2]])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -238,14 +228,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =($5, 5))])",
-          "\n    LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[/(CAST($5):DOUBLE NOT NULL, $1)])",
-          "\n      LogicalProject(col1=[$0], count=[$1], SUM=[$2], $f3=[$3], 
$f4=[$4], $f5=[$2])",
-          "\n        LogicalAggregate(group=[{2}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
-          "\n          LogicalExchange(distribution=[hash[2]])",
-          "\n            LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 
'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), 
<=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{2}], count=[COUNT()], 
SUM=[$SUM0($1)], agg#2=[MAX($1)], agg#3=[MIN($1)])",
+          "\n      LogicalExchange(distribution=[hash[2]])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       }
diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json 
b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
index aef530a060..ab57467c57 100644
--- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
@@ -236,6 +236,80 @@
           "\n      LogicalTableScan(table=[[b]])",
           "\n"
         ]
+      },
+      {
+        "description": "Semi join with multiple IN clause",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col2 = 'test' 
AND col3 IN (SELECT col3 FROM b WHERE col1='foo') AND col3 IN (SELECT col3 FROM 
b WHERE col1='bar') AND col3 IN (SELECT col3 FROM b WHERE col1='foobar')",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$2], col2=[$0])",
+          "\n  LogicalJoin(condition=[=($1, $3)], joinType=[semi])",
+          "\n    LogicalExchange(distribution=[hash[1]])",
+          "\n      LogicalJoin(condition=[=($1, $3)], joinType=[semi])",
+          "\n        LogicalExchange(distribution=[hash[1]])",
+          "\n          LogicalJoin(condition=[=($1, $3)], joinType=[semi])",
+          "\n            LogicalExchange(distribution=[hash[1]])",
+          "\n              LogicalProject(col2=[$1], col3=[$2], col1=[$3])",
+          "\n                LogicalFilter(condition=[=($1, 'test')])",
+          "\n                  LogicalTableScan(table=[[a]])",
+          "\n            LogicalExchange(distribution=[hash[0]])",
+          "\n              LogicalProject(col3=[$2], col1=[$3])",
+          "\n                LogicalFilter(condition=[=($3, 'foo')])",
+          "\n                  LogicalTableScan(table=[[b]])",
+          "\n        LogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalProject(col3=[$2], col1=[$3])",
+          "\n            LogicalFilter(condition=[=($3, 'bar')])",
+          "\n              LogicalTableScan(table=[[b]])",
+          "\n    LogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalProject(col3=[$2], col1=[$3])",
+          "\n        LogicalFilter(condition=[=($3, 'foobar')])",
+          "\n          LogicalTableScan(table=[[b]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "Semi join with multiple NOT IN clause",
+        "sql": "EXPLAIN PLAN FOR SELECT col1, col2 FROM a WHERE col2 = 'test' 
AND col3 NOT IN (SELECT col3 FROM b WHERE col1='foo') AND col3 NOT IN (SELECT 
col3 FROM b WHERE col1='bar') AND col3 NOT IN (SELECT col3 FROM b WHERE 
col1='foobar')",
+        "output": [
+          "Execution Plan",
+          "\nLogicalProject(col1=[$1], col2=[$0])",
+          "\n  LogicalFilter(condition=[IS NOT TRUE($8)])",
+          "\n    LogicalJoin(condition=[=($6, $7)], joinType=[left])",
+          "\n      LogicalExchange(distribution=[hash[6]])",
+          "\n        LogicalProject(col2=[$0], col1=[$2], col30=[$3], 
$f1=[$4], col32=[$5], $f10=[$7], col34=[$1])",
+          "\n          LogicalFilter(condition=[IS NOT TRUE($7)])",
+          "\n            LogicalJoin(condition=[=($5, $6)], joinType=[left])",
+          "\n              LogicalExchange(distribution=[hash[5]])",
+          "\n                LogicalProject(col2=[$0], col3=[$1], col1=[$2], 
col30=[$3], $f1=[$5], col32=[$1])",
+          "\n                  LogicalFilter(condition=[IS NOT TRUE($5)])",
+          "\n                    LogicalJoin(condition=[=($3, $4)], 
joinType=[left])",
+          "\n                      LogicalExchange(distribution=[hash[3]])",
+          "\n                        LogicalProject(col2=[$1], col3=[$2], 
col1=[$3], col30=[$2])",
+          "\n                          LogicalFilter(condition=[=($1, 
'test')])",
+          "\n                            LogicalTableScan(table=[[a]])",
+          "\n                      LogicalExchange(distribution=[hash[0]])",
+          "\n                        LogicalAggregate(group=[{0}], 
agg#0=[MIN($1)])",
+          "\n                          
LogicalExchange(distribution=[hash[0]])",
+          "\n                            LogicalAggregate(group=[{0}], 
agg#0=[MIN($1)])",
+          "\n                              LogicalProject(col3=[$2], 
$f1=[true])",
+          "\n                                LogicalFilter(condition=[=($3, 
'foo')])",
+          "\n                                  LogicalTableScan(table=[[b]])",
+          "\n              LogicalExchange(distribution=[hash[0]])",
+          "\n                LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n                  LogicalExchange(distribution=[hash[0]])",
+          "\n                    LogicalAggregate(group=[{0}], 
agg#0=[MIN($1)])",
+          "\n                      LogicalProject(col3=[$2], $f1=[true])",
+          "\n                        LogicalFilter(condition=[=($3, 'bar')])",
+          "\n                          LogicalTableScan(table=[[b]])",
+          "\n      LogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n          LogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n              LogicalProject(col3=[$2], $f1=[true])",
+          "\n                LogicalFilter(condition=[=($3, 'foobar')])",
+          "\n                  LogicalTableScan(table=[[b]])",
+          "\n"
+        ]
       }
     ]
   },
diff --git a/pinot-query-runtime/src/test/resources/queries/Case.json 
b/pinot-query-runtime/src/test/resources/queries/Case.json
index 449f602f71..82b7612c4f 100644
--- a/pinot-query-runtime/src/test/resources/queries/Case.json
+++ b/pinot-query-runtime/src/test/resources/queries/Case.json
@@ -29,7 +29,11 @@
       { "sql": "SELECT intCol, CASE WHEN intCol % 2 = 0 THEN intCol ELSE 
intCol * 2 END AS intVal, strCol FROM {tbl1}"},
       { "sql": "SELECT intCol, CASE WHEN floatCol > 4.0 THEN floatCol ELSE 
floatCol / 2.0 END AS floatVal, strCol FROM {tbl1}"},
       { "sql": "SELECT intCol, CASE WHEN doubleCol > 6.0 THEN doubleCol ELSE 
doubleCol / 2.0 END AS doubleVal, strCol FROM {tbl1}"},
-      { "sql": "SELECT intCol, CASE WHEN (SELECT SUM(floatCol) FROM {tbl1}) > 
16.0 THEN 'Large sum' ELSE 'Small sum' END AS aggVal, strCol FROM {tbl1}"}
+      {
+        "ignored": true,
+        "comment": "See https://github.com/apache/pinot/issues/10415 for more 
details",
+        "sql": "SELECT intCol, CASE WHEN (SELECT SUM(floatCol) FROM {tbl1}) > 
16.0 THEN 'Large sum' ELSE 'Small sum' END AS aggVal, strCol FROM {tbl1}"
+      }
     ]
   },
   "nested_case_when_test": {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to