This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 84c6f47e4f [Feature](Nereids) add WinMagic rule to rewrite scalar
sub-query to window function (#17968)
84c6f47e4f is described below
commit 84c6f47e4f922766cd7dfa8f8b5c78665f41d896
Author: mch_ucchi <[email protected]>
AuthorDate: Mon Mar 27 23:58:41 2023 +0800
[Feature](Nereids) add WinMagic rule to rewrite scalar sub-query to window
function (#17968)
Reference paper: WinMagic - Subquery Elimination Using Window Aggregation.
For SQL like TPC-H Q2 and Q17, which contain a correlated sub-query whose only
output is a single aggregate function, we can eliminate the sub-query and
transform it into a window function. For example, TPC-H Q17 is:
```sql
select
sum(l_extendedprice) / 7.0 as avg_yearly
from
lineitem,
part
where
p_partkey = l_partkey
and p_brand = 'Brand#23'
and p_container = 'MED BOX'
and l_quantity < (
select
0.2 * avg(l_quantity)
from
lineitem
where
l_partkey = p_partkey
);
```
we rewrite it to
```sql
select
sum(l_extendedprice) / 7.0 as avg_yearly
from (
select
l_extendedprice, l_quantity, avg(l_quantity) over(partition by
l_partkey) avg_l_quantity
from
lineitem,
part
where
p_partkey = l_partkey
and p_brand = 'Brand#23'
and p_container = 'MED BOX' )
where l_quantity < 0.2 * avg_l_quantity
```
Currently the rule only handles the case where the outer WHERE conjuncts
contain exactly one sub-query and the conjunct containing that sub-query is a
comparison predicate. Support for compound predicates and for multiple
conjuncts containing sub-queries will be added later.
---
.../doris/nereids/jobs/batch/NereidsRewriter.java | 3 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../logical/AggScalarSubQueryToWindowFunction.java | 385 +++++++++++++++++++++
.../rules/rewrite/logical/ExistsApplyToJoin.java | 16 +-
.../rules/rewrite/logical/InApplyToJoin.java | 4 +-
...CorrelatedFilterUnderApplyAggregateProject.java | 2 +-
.../rewrite/logical/PullUpProjectUnderApply.java | 2 +-
.../rules/rewrite/logical/ScalarApplyToJoin.java | 8 +-
.../logical/UnCorrelatedApplyAggregateFilter.java | 2 +-
.../rewrite/logical/UnCorrelatedApplyFilter.java | 2 +-
.../logical/UnCorrelatedApplyProjectFilter.java | 2 +-
.../org/apache/doris/nereids/trees/TreeNode.java | 13 +
.../nereids/trees/plans/logical/LogicalApply.java | 26 +-
.../org/apache/doris/nereids/util/PlanUtils.java | 16 +
.../AggScalarSubQueryToWindowFunctionTest.java | 177 ++++++++++
.../rules/rewrite/logical/PushdownLimitTest.java | 5 +-
16 files changed, 629 insertions(+), 35 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
index ef830e4e9c..d00790dffc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
@@ -32,6 +32,7 @@ import
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewrite;
import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate;
import
org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.AdjustNullable;
+import
org.apache.doris.nereids.rules.rewrite.logical.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.logical.BuildAggForUnion;
import
org.apache.doris.nereids.rules.rewrite.logical.CheckAndStandardizeWindowFunctionAndFrame;
import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
@@ -104,6 +105,8 @@ public class NereidsRewriter extends BatchRewriteJob {
),
topic("Subquery unnesting",
+ custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION,
AggScalarSubQueryToWindowFunction::new),
+
bottomUp(
new EliminateUselessPlanUnderApply(),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index c85fdfeef8..cd7d7a7c90 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -108,6 +108,7 @@ public enum RuleType {
ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE),
ELIMINATE_SORT_UNDER_APPLY_PROJECT(RuleTypeClass.REWRITE),
PULL_UP_PROJECT_UNDER_APPLY(RuleTypeClass.REWRITE),
+ AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_PROJECT_FILTER(RuleTypeClass.REWRITE),
UN_CORRELATED_APPLY_AGGREGATE_FILTER(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
new file mode 100644
index 0000000000..2603ca611c
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * Rewrites a correlated scalar sub-query whose only output is one aggregate function into a window function.
+ * change the plan:
+ * logicalFilter(logicalApply(any(), logicalAggregate()))
+ * (optionally with a logicalProject between the filter and the apply)
+ * to
+ * logicalFilter(logicalWindow(logicalFilter(any())))
+ * refer paper: WinMagic - Subquery Elimination Using Window Aggregation
+ */
+public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
+    private static final Set<Class<? extends AggregateFunction>> SUPPORTED_FUNCTION = ImmutableSet.of(
+            Min.class, Max.class, Count.class, Sum.class, Avg.class
+    );
+    private static final Set<Class<? extends LogicalPlan>> LEFT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalRelation.class, LogicalJoin.class, LogicalProject.class, LogicalFilter.class, LogicalLimit.class
+    );
+    private static final Set<Class<? extends LogicalPlan>> RIGHT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalRelation.class, LogicalJoin.class, LogicalProject.class, LogicalFilter.class, LogicalAggregate.class
+    );
+    // populated by check() for the candidate filter currently being examined; `functions` is read by trans()
+    private List<LogicalPlan> outerPlans = null;
+    private List<LogicalPlan> innerPlans = null;
+    private LogicalAggregate aggOp = null;
+    private List<AggregateFunction> functions = null;
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext context) {
+        return plan.accept(this, context);
+    }
+
+    /**
+     * Rewrites {@code filter} when it sits on a supported scalar-subquery apply. Otherwise the default
+     * traversal continues into the children, so filters deeper in the plan are still considered.
+     */
+    @Override
+    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext context) {
+        LogicalApply<Plan, LogicalAggregate<Plan>> apply = checkPattern(filter);
+        if (apply == null || !check(filter, apply)) {
+            // not rewritable here: keep descending instead of stopping the whole traversal at this node
+            return super.visitLogicalFilter(filter, context);
+        }
+        return trans(filter, apply);
+    }
+
+    /**
+     * Returns the apply under {@code filter} (directly or below one project) if it is a correlated scalar apply
+     * with a comparison conjunct and an aggregate on its right side; otherwise null.
+     */
+    private LogicalApply<Plan, LogicalAggregate<Plan>> checkPattern(LogicalFilter<? extends Plan> filter) {
+        LogicalPlan plan = ((LogicalPlan) filter.child());
+        if (plan instanceof LogicalProject) {
+            plan = ((LogicalPlan) ((LogicalProject) plan).child());
+        }
+        if (!(plan instanceof LogicalApply)) {
+            return null;
+        }
+        LogicalApply apply = (LogicalApply) plan;
+        if (!checkApplyNode(apply)) {
+            return null;
+        }
+        return apply.right() instanceof LogicalAggregate ? apply : null;
+    }
+
+    /**
+     * Checks every precondition of the rewrite; as a side effect fills outerPlans, innerPlans, aggOp and functions.
+     */
+    private boolean check(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        LogicalPlan outer = ((LogicalPlan) apply.child(0));
+        LogicalPlan inner = ((LogicalPlan) apply.child(1));
+        outerPlans = PlanCollector.INSTANCE.collect(outer);
+        innerPlans = PlanCollector.INSTANCE.collect(inner);
+        Optional<LogicalFilter> innerFilter = innerPlans.stream()
+                .filter(LogicalFilter.class::isInstance)
+                .map(LogicalFilter.class::cast).findFirst();
+        return innerFilter.isPresent()
+                && checkPlanType() && checkAggType()
+                && checkAggArguments(apply)
+                && checkRelation(apply.getCorrelationSlot())
+                && checkPredicate(Sets.newHashSet(filter.getConjuncts()),
+                Sets.newHashSet(innerFilter.get().getConjuncts()));
+    }
+
+    // check children's nodes because query process will be changed
+    private boolean checkPlanType() {
+        return outerPlans.stream().allMatch(p -> LEFT_SUPPORTED_PLAN.stream().anyMatch(c -> c.isInstance(p)))
+                && innerPlans.stream().allMatch(p -> RIGHT_SUPPORTED_PLAN.stream().anyMatch(c -> c.isInstance(p)));
+    }
+
+    private boolean checkApplyNode(LogicalApply apply) {
+        return apply.isScalar() && apply.isCorrelated() && apply.getSubCorrespondingConjunct().isPresent()
+                && apply.getSubCorrespondingConjunct().get() instanceof ComparisonPredicate;
+    }
+
+    /**
+     * Checks the aggregation of the inner scope: exactly one aggregate node, without group by, producing a single
+     * aliased column that contains exactly one supported, non-distinct aggregate function.
+     * Unsupported shapes make the rule decline instead of throwing.
+     */
+    private boolean checkAggType() {
+        List<LogicalAggregate> aggSet = innerPlans.stream().filter(LogicalAggregate.class::isInstance)
+                .map(LogicalAggregate.class::cast)
+                .collect(Collectors.toList());
+        if (aggSet.size() != 1) {
+            // window functions don't support nesting.
+            return false;
+        }
+        aggOp = aggSet.get(0);
+        if (!aggOp.getGroupByExpressions().isEmpty()) {
+            // a grouped sub-query is not a per-row scalar aggregate over the correlated partition
+            return false;
+        }
+        if (aggOp.getOutputExpressions().size() != 1 || !(aggOp.getOutputExpressions().get(0) instanceof Alias)) {
+            // trans() unwraps the single output alias of the aggregate
+            return false;
+        }
+        functions = ((List<AggregateFunction>) ExpressionUtils.<AggregateFunction>collectAll(
+                aggOp.getOutputExpressions(), AggregateFunction.class::isInstance));
+        if (functions.size() != 1) {
+            // only one aggregate function can be replaced by the single window function we build
+            return false;
+        }
+        return functions.stream().allMatch(f -> SUPPORTED_FUNCTION.contains(f.getClass()) && !f.isDistinct());
+    }
+
+    /**
+     * trans() builds the window function from the slots on the outer side of the sub-query comparison.
+     * That is only valid when those slots name exactly the columns read by the inner aggregate function.
+     * Otherwise a different column would be aggregated, e.g. `l_extendedprice < (select avg(l_quantity) ...)`,
+     * or count(*) would turn into count(col).
+     */
+    private boolean checkAggArguments(LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        Expression conjunct = PlanUtils.maybeCommuteComparisonPredicate(
+                (ComparisonPredicate) apply.getSubCorrespondingConjunct().get(), apply.left());
+        List<Slot> outerSlots = conjunct.child(0).collectToList(Slot.class::isInstance);
+        List<Slot> innerSlots = functions.get(0).collectToList(Slot.class::isInstance);
+        return slotNames(outerSlots).equals(slotNames(innerSlots));
+    }
+
+    private static List<String> slotNames(List<Slot> slots) {
+        return slots.stream().map(Slot::getName).collect(Collectors.toList());
+    }
+
+    // check if the relations of the outer's includes the inner's
+    private boolean checkRelation(List<Expression> correlatedSlots) {
+        List<LogicalRelation> outerTables = outerPlans.stream().filter(LogicalRelation.class::isInstance)
+                .map(LogicalRelation.class::cast)
+                .collect(Collectors.toList());
+        List<LogicalRelation> innerTables = innerPlans.stream().filter(LogicalRelation.class::isInstance)
+                .map(LogicalRelation.class::cast)
+                .collect(Collectors.toList());
+
+        Set<Long> outerIds = outerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toSet());
+        Set<Long> innerIds = innerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toSet());
+
+        // after this, outerIds holds the outer-only tables and innerIds the inner-only tables
+        Set<Long> outerCopy = Sets.newHashSet(outerIds);
+        outerIds.removeAll(innerIds);
+        innerIds.removeAll(outerCopy);
+        if (outerIds.isEmpty() || !innerIds.isEmpty()) {
+            return false;
+        }
+
+        // the correlated slots must all come from outer-only tables
+        Set<ExprId> correlatedRelationOutput = outerTables.stream()
+                .filter(node -> outerIds.contains(node.getTable().getId()))
+                .map(LogicalRelation::getOutputExprIdSet).flatMap(Collection::stream).collect(Collectors.toSet());
+        return ExpressionUtils.collect(correlatedSlots, NamedExpression.class::isInstance).stream()
+                .map(NamedExpression.class::cast)
+                .allMatch(e -> correlatedRelationOutput.contains(e.getExprId()));
+    }
+
+    /**
+     * Returns true when every inner conjunct has a structurally identical outer conjunct.
+     * Both sets are consumed: matched conjuncts are removed from each.
+     */
+    private boolean checkPredicate(Set<Expression> outerConjuncts, Set<Expression> innerConjuncts) {
+        Iterator<Expression> innerIter = innerConjuncts.iterator();
+        // inner predicate should be the sub-set of outer predicate.
+        while (innerIter.hasNext()) {
+            Expression innerExpr = innerIter.next();
+            Iterator<Expression> outerIter = outerConjuncts.iterator();
+            while (outerIter.hasNext()) {
+                Expression outerExpr = outerIter.next();
+                if (ExpressionIdenticalChecker.INSTANCE.check(innerExpr, outerExpr)) {
+                    innerIter.remove();
+                    outerIter.remove();
+                    // innerExpr is consumed; a second innerIter.remove() would throw IllegalStateException
+                    break;
+                }
+            }
+        }
+        return innerConjuncts.isEmpty();
+    }
+
+    private Plan trans(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        LogicalAggregate<Plan> agg = apply.right();
+
+        // transform algorithm
+        // first: take the slots on the outer side of the sub-query comparison; check() has verified that they
+        // name the same columns as the slots read by the inner aggregate function.
+        // second: find the aggregation function in inner scope, and replace it with a window function whose
+        // arguments are the outer slots from the first step, partitioned by the correlated slots.
+        // third: the expression containing the aggregation function in inner scope is the child of an alias,
+        // so in the predicate between outer and inner, we replace the alias by its child expression,
+        // and change the aggregation function inside it to the alias of the window function.
+
+        // for example, in tpc-h Q17
+        // window filter conjuncts is
+        // cast(l_quantity#id1 as decimal(27, 9)) < `0.2 * avg(l_quantity)`#id2
+        // and
+        // 0.2 * avg(l_quantity#id3) as `0.2 * avg(l_quantity)`#id2
+        // is agg's output expression
+        // we change it to
+        // cast(l_quantity#id1 as decimal(27, 9)) < 0.2 * `avg(l_quantity#id1) over(window)`#id4
+        // and
+        // avg(l_quantity#id1) over(window) as `avg(l_quantity#id1) over(window)`#id4
+
+        // it's a simple case, but we may meet some complex cases in ut.
+        // TODO: support compound predicate and multi apply node.
+
+        Expression windowFilterConjunct = apply.getSubCorrespondingConjunct().get();
+        windowFilterConjunct = PlanUtils.maybeCommuteComparisonPredicate(
+                (ComparisonPredicate) windowFilterConjunct, apply.left());
+
+        // build window function, replace the slot
+        List<Expression> windowAggSlots = windowFilterConjunct.child(0).collectToList(Slot.class::isInstance);
+
+        AggregateFunction function = functions.get(0);
+        if (function instanceof NullableAggregateFunction) {
+            // adjust agg function's nullable.
+            function = ((NullableAggregateFunction) function).withAlwaysNullable(false);
+        }
+
+        WindowExpression windowFunction = createWindowFunction(apply.getCorrelationSlot(),
+                function.withChildren(windowAggSlots));
+        NamedExpression windowFunctionAlias = new Alias(windowFunction, windowFunction.toSql());
+
+        // build filter conjunct, get the alias of the agg output and extract its child.
+        // then replace the agg to window function, then build conjunct
+        // checkAggType() ensures the single agg output is an Alias.
+        NamedExpression aggOut = agg.getOutputExpressions().get(0);
+        Expression aggOutExpr = aggOut.child(0);
+        // change the agg function to window function alias.
+        aggOutExpr = MapReplacer.INSTANCE.replace(aggOutExpr, ImmutableMap
+                .of(AggregateFunction.class, e -> windowFunctionAlias.toSlot()));
+
+        // we change the child contains the original agg output to agg output expr.
+        // for comparison predicate, it is always the child(1), since we ensure the window agg slot is in child(0)
+        // for in predicate, we should extract the options and find the corresponding child.
+        windowFilterConjunct = windowFilterConjunct
+                .withChildren(windowFilterConjunct.child(0), aggOutExpr);
+
+        LogicalFilter newFilter = ((LogicalFilter) filter.withChildren(apply.left()));
+        LogicalWindow newWindow = new LogicalWindow<>(ImmutableList.of(windowFunctionAlias), newFilter);
+        LogicalFilter windowFilter = new LogicalFilter<>(ImmutableSet.of(windowFilterConjunct), newWindow);
+        return windowFilter;
+    }
+
+    private WindowExpression createWindowFunction(List<Expression> correlatedSlots, AggregateFunction function) {
+        // partition by clause is set by all the correlated slots.
+        Preconditions.checkArgument(correlatedSlots.stream().allMatch(Slot.class::isInstance));
+        return new WindowExpression(function, correlatedSlots, Collections.emptyList());
+    }
+
+    /** Collects every node of a plan tree (parent before children) into a list. */
+    private static class PlanCollector extends DefaultPlanVisitor<Void, List<LogicalPlan>> {
+        public static final PlanCollector INSTANCE = new PlanCollector();
+
+        public List<LogicalPlan> collect(LogicalPlan plan) {
+            List<LogicalPlan> buffer = Lists.newArrayList();
+            plan.accept(this, buffer);
+            return buffer;
+        }
+
+        @Override
+        public Void visit(Plan plan, List<LogicalPlan> buffer) {
+            Preconditions.checkArgument(plan instanceof LogicalPlan);
+            buffer.add(((LogicalPlan) plan));
+            plan.children().forEach(child -> child.accept(this, buffer));
+            return null;
+        }
+    }
+
+    /**
+     * Structural equality of two expressions: same classes and children, named expressions compared by name
+     * (not ExprId, since inner and outer scan different instances of the same table) and literals by value.
+     */
+    private static class ExpressionIdenticalChecker extends DefaultExpressionVisitor<Boolean, Expression> {
+        public static final ExpressionIdenticalChecker INSTANCE = new ExpressionIdenticalChecker();
+
+        public boolean check(Expression expression, Expression expression1) {
+            return expression.accept(this, expression1);
+        }
+
+        private boolean isClassMatch(Object o1, Object o2) {
+            return o1.getClass().equals(o2.getClass());
+        }
+
+        private boolean isSameChild(Expression expression, Expression expression1) {
+            if (expression.children().size() != expression1.children().size()) {
+                return false;
+            }
+            for (int i = 0; i < expression.children().size(); ++i) {
+                if (!expression.children().get(i).accept(this, expression1.children().get(i))) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        private boolean isSameObjects(Object... o) {
+            Preconditions.checkArgument(o.length % 2 == 0);
+            for (int i = 0; i < o.length; i += 2) {
+                if (!Objects.equals(o[i], o[i + 1])) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        private boolean isSameOperator(Expression expression, Expression expression1, Object... o) {
+            return isSameObjects(o) && isSameChild(expression, expression1);
+        }
+
+        @Override
+        public Boolean visit(Expression expression, Expression expression1) {
+            return isClassMatch(expression, expression1) && isSameChild(expression, expression1);
+        }
+
+        @Override
+        public Boolean visitNamedExpression(NamedExpression namedExpression, Expression expr) {
+            return isClassMatch(namedExpression, expr)
+                    && isSameOperator(namedExpression, expr, namedExpression.getName(),
+                    ((NamedExpression) expr).getName());
+        }
+
+        @Override
+        public Boolean visitLiteral(Literal literal, Expression expr) {
+            return isClassMatch(literal, expr)
+                    && isSameOperator(literal, expr, literal.getValue(), ((Literal) expr).getValue());
+        }
+
+        @Override
+        public Boolean visitEqualTo(EqualTo equalTo, Expression expr) {
+            // EqualTo is symmetric, so the commuted form also matches, but only against another EqualTo:
+            // comparing children alone would make `a = 1` identical to `a < 1`.
+            return isClassMatch(equalTo, expr)
+                    && (isSameChild(equalTo, expr) || isSameChild(equalTo.commute(), expr));
+        }
+    }
+
+    /**
+     * Rewrites an expression tree top-down: a node that is an instance of a key class is replaced by the
+     * mapped function's result, then the children of the (possibly replaced) node are rewritten.
+     */
+    private static class MapReplacer extends DefaultExpressionRewriter<Map<Class<? extends Expression>,
+            Function<Expression, Expression>>> {
+        public static final MapReplacer INSTANCE = new MapReplacer();
+
+        public Expression replace(Expression e, Map<Class<? extends Expression>,
+                Function<Expression, Expression>> context) {
+            return e.accept(this, context);
+        }
+
+        @Override
+        public Expression visit(Expression e, Map<Class<? extends Expression>,
+                Function<Expression, Expression>> context) {
+            Expression replaced = e;
+            for (Class c : context.keySet()) {
+                if (c.isInstance(e)) {
+                    replaced = context.get(c).apply(e);
+                    break;
+                }
+            }
+            return super.visit(replaced, context);
+        }
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
index 26a3f7de80..5e6a4ac265 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
@@ -92,11 +92,11 @@ public class ExistsApplyToJoin extends
OneRewriteRuleFactory {
private Plan correlatedToJoin(LogicalApply apply) {
Optional<Expression> correlationFilter = apply.getCorrelationFilter();
Expression predicate = null;
- if (correlationFilter.isPresent() &&
apply.getSubCorrespondingConject().isPresent()) {
+ if (correlationFilter.isPresent() &&
apply.getSubCorrespondingConjunct().isPresent()) {
predicate = ExpressionUtils.and(correlationFilter.get(),
- (Expression) apply.getSubCorrespondingConject().get());
- } else if (apply.getSubCorrespondingConject().isPresent()) {
- predicate = (Expression) apply.getSubCorrespondingConject().get();
+ (Expression) apply.getSubCorrespondingConjunct().get());
+ } else if (apply.getSubCorrespondingConjunct().isPresent()) {
+ predicate = (Expression) apply.getSubCorrespondingConjunct().get();
} else if (correlationFilter.isPresent()) {
predicate = correlationFilter.get();
}
@@ -134,8 +134,8 @@ public class ExistsApplyToJoin extends
OneRewriteRuleFactory {
LogicalAggregate newAgg = new LogicalAggregate<>(new ArrayList<>(),
ImmutableList.of(alias), newLimit);
LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
- unapply.getSubCorrespondingConject().isPresent()
- ? ExpressionUtils.extractConjunction((Expression)
unapply.getSubCorrespondingConject().get())
+ unapply.getSubCorrespondingConjunct().isPresent()
+ ? ExpressionUtils.extractConjunction((Expression)
unapply.getSubCorrespondingConjunct().get())
: ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE,
unapply.getMarkJoinSlotReference(),
(LogicalPlan) unapply.left(), newAgg);
return new LogicalFilter<>(ImmutableSet.of(new
EqualTo(newAgg.getOutput().get(0),
@@ -145,8 +145,8 @@ public class ExistsApplyToJoin extends
OneRewriteRuleFactory {
private Plan unCorrelatedExist(LogicalApply unapply) {
LogicalLimit newLimit = new LogicalLimit<>(1, 0, LimitPhase.ORIGIN,
(LogicalPlan) unapply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
- unapply.getSubCorrespondingConject().isPresent()
- ? ExpressionUtils.extractConjunction((Expression)
unapply.getSubCorrespondingConject().get())
+ unapply.getSubCorrespondingConjunct().isPresent()
+ ? ExpressionUtils.extractConjunction((Expression)
unapply.getSubCorrespondingConjunct().get())
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE, unapply.getMarkJoinSlotReference(), (LogicalPlan)
unapply.left(), newLimit);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index 8325ad18fb..a0d9cfd6f5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -102,8 +102,8 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
predicate = new EqualTo(left, right);
}
- if (apply.getSubCorrespondingConject().isPresent()) {
- predicate = ExpressionUtils.and(predicate,
apply.getSubCorrespondingConject().get());
+ if (apply.getSubCorrespondingConjunct().isPresent()) {
+ predicate = ExpressionUtils.and(predicate,
apply.getSubCorrespondingConjunct().get());
}
List<Expression> conjuncts =
ExpressionUtils.extractConjunction(predicate);
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
index 9c2b22aa3e..99b6b123fa 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
@@ -80,7 +80,7 @@ public class PullUpCorrelatedFilterUnderApplyAggregateProject
extends OneRewrite
LogicalAggregate newAgg =
agg.withChildren(ImmutableList.of(newFilter));
return new LogicalApply<>(apply.getCorrelationSlot(),
apply.getSubqueryExpr(),
apply.getCorrelationFilter(),
apply.getMarkJoinSlotReference(),
- apply.getSubCorrespondingConject(), apply.left(),
newAgg);
+ apply.getSubCorrespondingConjunct(), apply.left(),
newAgg);
}).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
index bca11e2351..a8105775ba 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
@@ -58,7 +58,7 @@ public class PullUpProjectUnderApply extends
OneRewriteRuleFactory {
LogicalProject<Plan> project = apply.right();
LogicalApply newCorrelate = new
LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
apply.getCorrelationFilter(),
apply.getMarkJoinSlotReference(),
- apply.getSubCorrespondingConject(),
apply.left(), project.child());
+ apply.getSubCorrespondingConjunct(),
apply.left(), project.child());
List<NamedExpression> newProjects = new ArrayList<>();
newProjects.addAll(apply.left().getOutput());
if (apply.getSubqueryExpr() instanceof ScalarSubquery) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
index 97a1315825..999f0b3a83 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
@@ -61,8 +61,8 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
(LogicalPlan) apply.right());
return new LogicalJoin<>(JoinType.CROSS_JOIN,
ExpressionUtils.EMPTY_CONDITION,
- apply.getSubCorrespondingConject().isPresent()
- ? ExpressionUtils.extractConjunction((Expression)
apply.getSubCorrespondingConject().get())
+ apply.getSubCorrespondingConjunct().isPresent()
+ ? ExpressionUtils.extractConjunction((Expression)
apply.getSubCorrespondingConjunct().get())
: ExpressionUtils.EMPTY_CONDITION,
JoinHint.NONE,
apply.getMarkJoinSlotReference(),
@@ -86,9 +86,9 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN,
ExpressionUtils.EMPTY_CONDITION,
ExpressionUtils.extractConjunction(
- apply.getSubCorrespondingConject().isPresent()
+ apply.getSubCorrespondingConjunct().isPresent()
? ExpressionUtils.and(
- (Expression)
apply.getSubCorrespondingConject().get(),
+ (Expression)
apply.getSubCorrespondingConjunct().get(),
correlationFilter.get())
: correlationFilter.get()),
JoinHint.NONE,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
index ab05f23921..cc6590f7c1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
@@ -89,7 +89,7 @@ public class UnCorrelatedApplyAggregateFilter extends
OneRewriteRuleFactory {
apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
apply.getMarkJoinSlotReference(),
- apply.getSubCorrespondingConject(),
+ apply.getSubCorrespondingConjunct(),
apply.left(), newAgg);
}).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
index 95959a6205..d33897f260 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
@@ -69,7 +69,7 @@ public class UnCorrelatedApplyFilter extends
OneRewriteRuleFactory {
Plan child =
PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate),
filter.child());
return new LogicalApply<>(apply.getCorrelationSlot(),
apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
apply.getMarkJoinSlotReference(),
- apply.getSubCorrespondingConject(),
+ apply.getSubCorrespondingConjunct(),
apply.left(), child);
}).toRule(RuleType.UN_CORRELATED_APPLY_FILTER);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
index c4cff23293..1907a26ced 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
@@ -90,7 +90,7 @@ public class UnCorrelatedApplyProjectFilter extends
OneRewriteRuleFactory {
LogicalProject newProject =
project.withProjectsAndChild(projects, child);
return new LogicalApply<>(apply.getCorrelationSlot(),
apply.getSubqueryExpr(),
ExpressionUtils.optionalAnd(correlatedPredicate),
apply.getMarkJoinSlotReference(),
- apply.getSubCorrespondingConject(),
+ apply.getSubCorrespondingConjunct(),
apply.left(), newProject);
}).toRule(RuleType.UN_CORRELATED_APPLY_PROJECT_FILTER);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 92b99ec68e..0394ebea87 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -222,6 +222,19 @@ public interface TreeNode<NODE_TYPE extends
TreeNode<NODE_TYPE>> {
return (T) result.build();
}
+    /** Gather every node accepted by {@code predicate}, in {@code foreach} visit order, into an immutable list. */
+    default <T> List<T> collectToList(Predicate<TreeNode<NODE_TYPE>> predicate) {
+        ImmutableList.Builder<TreeNode<NODE_TYPE>> matched = ImmutableList.builder();
+        foreach(current -> {
+            if (!predicate.test(current)) {
+                return;
+            }
+            matched.add(current);
+        });
+        // unchecked cast: T is picked by the caller and the element types are not verified here
+        return (List<T>) matched.build();
+    }
+
/**
* iterate top down and test predicate if contains any instance of the
classes
* @param types classes array
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
index 4c3a092244..df28b35059 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
@@ -55,7 +55,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
// The slot replaced by the subquery in MarkJoin
private final Optional<MarkJoinSlotReference> markJoinSlotReference;
- private final Optional<Expression> subCorrespondingConject;
+ private final Optional<Expression> subCorrespondingConjunct;
/**
* Constructor.
@@ -65,22 +65,22 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
List<Expression> correlationSlot,
SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter,
Optional<MarkJoinSlotReference> markJoinSlotReference,
- Optional<Expression> subCorrespondingConject,
+ Optional<Expression> subCorrespondingConjunct,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties,
leftChild, rightChild);
this.correlationSlot = correlationSlot == null ? ImmutableList.of() :
ImmutableList.copyOf(correlationSlot);
this.subqueryExpr = Objects.requireNonNull(subqueryExpr, "subquery can
not be null");
this.correlationFilter = correlationFilter;
this.markJoinSlotReference = markJoinSlotReference;
- this.subCorrespondingConject = subCorrespondingConject;
+ this.subCorrespondingConjunct = subCorrespondingConjunct;
}
public LogicalApply(List<Expression> correlationSlot, SubqueryExpr
subqueryExpr,
Optional<Expression> correlationFilter,
Optional<MarkJoinSlotReference> markJoinSlotReference,
- Optional<Expression> subCorrespondingConject,
+ Optional<Expression> subCorrespondingConjunct,
LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) {
this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr,
- correlationFilter, markJoinSlotReference,
subCorrespondingConject, input, subquery);
+ correlationFilter, markJoinSlotReference,
subCorrespondingConjunct, input, subquery);
}
public List<Expression> getCorrelationSlot() {
@@ -123,8 +123,8 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
return markJoinSlotReference;
}
- public Optional<Expression> getSubCorrespondingConject() {
- return subCorrespondingConject;
+ public Optional<Expression> getSubCorrespondingConjunct() {
+ return subCorrespondingConjunct;
}
@Override
@@ -142,7 +142,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
"isMarkJoin", markJoinSlotReference.isPresent(),
"MarkJoinSlotReference", markJoinSlotReference.isPresent() ?
markJoinSlotReference.get() : "empty",
"scalarSubCorrespondingSlot",
- subCorrespondingConject.isPresent() ?
subCorrespondingConject.get() : "empty");
+ subCorrespondingConjunct.isPresent() ?
subCorrespondingConjunct.get() : "empty");
}
@Override
@@ -158,13 +158,13 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
&& Objects.equals(subqueryExpr, that.getSubqueryExpr())
&& Objects.equals(correlationFilter,
that.getCorrelationFilter())
&& Objects.equals(markJoinSlotReference,
that.getMarkJoinSlotReference())
- && Objects.equals(subCorrespondingConject,
that.getSubCorrespondingConject());
+ && Objects.equals(subCorrespondingConjunct,
that.getSubCorrespondingConjunct());
}
@Override
public int hashCode() {
return Objects.hash(
- correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, subCorrespondingConject);
+ correlationSlot, subqueryExpr, correlationFilter,
markJoinSlotReference, subCorrespondingConjunct);
}
@Override
@@ -189,7 +189,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalApply<>(correlationSlot, subqueryExpr,
correlationFilter,
- markJoinSlotReference, subCorrespondingConject,
+ markJoinSlotReference, subCorrespondingConjunct,
children.get(0), children.get(1));
}
@@ -197,13 +197,13 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
public Plan withGroupExpression(Optional<GroupExpression> groupExpression)
{
return new LogicalApply<>(groupExpression,
Optional.of(getLogicalProperties()),
correlationSlot, subqueryExpr, correlationFilter,
- markJoinSlotReference, subCorrespondingConject, left(),
right());
+ markJoinSlotReference, subCorrespondingConjunct, left(),
right());
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties>
logicalProperties) {
return new LogicalApply<>(Optional.empty(), logicalProperties,
correlationSlot, subqueryExpr, correlationFilter,
- markJoinSlotReference, subCorrespondingConject, left(),
right());
+ markJoinSlotReference, subCorrespondingConjunct, left(),
right());
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 1a08fed6dd..f68664bc32 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -17,10 +17,14 @@
package org.apache.doris.nereids.util;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import com.google.common.collect.Sets;
+
import java.util.Optional;
import java.util.Set;
@@ -38,4 +42,16 @@ public class PlanUtils {
public static Plan filterOrSelf(Set<Expression> predicates, Plan plan) {
return filter(predicates, plan).map(Plan.class::cast).orElse(plan);
}
+
+    /**
+     * Keep {@code expression} if all slots of its left operand are output by {@code left}; otherwise commute it.
+     */
+    public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
+        Set<Slot> operandSlots = expression.left().collect(Slot.class::isInstance);
+        Set<Slot> leftOutput = left.getOutputSet();
+        // commute as soon as the left operand references a slot the left child does not output
+        boolean needCommute = !Sets.difference(operandSlots, leftOutput).isEmpty();
+        return needCommute ? expression.commute() : expression;
+    }
+
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
new file mode 100644
index 0000000000..d63385f656
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
+import org.apache.doris.nereids.datasets.tpch.TPCHUtils;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for {@link AggScalarSubQueryToWindowFunction}: checks whether rewriting produces a {@link LogicalWindow}.
+ */
+public class AggScalarSubQueryToWindowFunctionTest extends TPCHTestBase implements MemoPatternMatchSupported {
+    private static final String SQL_TEMPLATE = " select\n"
+            + " sum(l_extendedprice) / 7.0 as avg_yearly\n"
+            + " from\n"
+            + " lineitem,\n"
+            + " part\n"
+            + " where\n"
+            + " p_partkey = l_partkey\n"
+            + " and p_brand = 'Brand#23'\n"
+            + " and p_container = 'MED BOX'\n"
+            + " (p1) (p2)";
+    private static final String SUB_QUERY_TEMPLATE = "(\n"
+            + " select\n"
+            + " %s\n"
+            + " from\n"
+            + " lineitem\n"
+            + " where\n"
+            + " l_partkey = p_partkey\n"
+            + " )";
+    private static final String AVG = "0.2 * avg(l_quantity)";
+    private static final String MAX = "max(l_quantity) / 2";
+    private static final String MIN = "min(l_extendedprice) * 5";
+
+    private static final String[] QUERIES = {buildSubQuery(AVG), buildSubQuery(MAX), buildSubQuery(MIN)};
+
+    /** Fill (pN) in SQL_TEMPLATE with predicates[N-1], (qM) expanded to subQueries[M-1]; arrays are not mutated. */
+    private static String buildFromTemplate(String[] predicates, String[] subQueries) {
+        String sql = SQL_TEMPLATE;
+        for (int i = 0; i < predicates.length; ++i) {
+            String predicate = predicates[i];
+            for (int j = 0; j < subQueries.length; ++j) {
+                predicate = predicate.replace(String.format("(q%d)", j + 1), subQueries[j]);
+            }
+            sql = sql.replace(String.format("(p%d)", i + 1), predicate);
+        }
+        return sql;
+    }
+
+    /** Wrap an aggregate select expression in the correlated sub-query template. */
+    private static String buildSubQuery(String res) {
+        return String.format(SUB_QUERY_TEMPLATE, res);
+    }
+
+    @Test
+    public void testRuleOnTPCHTest() {
+        check(TPCHUtils.Q2);
+        check(TPCHUtils.Q17);
+    }
+
+    @Disabled
+    @Test
+    public void testComplexPredicates() {
+        // The rule currently handles only a single comparison conjunct that contains one sub-query.
+        // These cases cover compound predicates and sub-queries in more than one conjunct,
+        // so they stay disabled until the rule supports them.
+        String[] testCases = {
+                "and l_quantity > 10 or l_quantity < (q1)",
+                "or l_quantity < (q3)",
+                "and l_extendedprice > (q2)",
+        };
+
+        check(buildFromTemplate(new String[] {testCases[0], testCases[1]}, QUERIES));
+        check(buildFromTemplate(new String[] {testCases[0], testCases[2]}, QUERIES));
+        check(buildFromTemplate(new String[] {testCases[1], testCases[2]}, QUERIES));
+    }
+
+    @Test
+    public void testNotMatchTheRule() {
+        String[] testCases = {
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + " from lineitem, part\n"
+                        + " where p_partkey = l_partkey\n"
+                        + " and p_brand = 'Brand#23'\n"
+                        + " and p_container = 'MED BOX'\n"
+                        + " and l_quantity < (\n"
+                        + " select 0.2 * avg(l_quantity)\n"
+                        + " from lineitem);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + " from lineitem, part\n"
+                        + " where p_partkey = l_partkey\n"
+                        + " and p_brand = 'Brand#23'\n"
+                        + " and p_container = 'MED BOX'\n"
+                        + " and l_quantity < (\n"
+                        + " select 0.2 * avg(l_quantity)\n"
+                        + " from lineitem, part\n"
+                        + " where l_partkey = p_partkey);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + " from lineitem, part\n"
+                        + " where p_partkey = l_partkey\n"
+                        + " and p_brand = 'Brand#23'\n"
+                        + " and p_container = 'MED BOX'\n"
+                        + " and l_quantity < (\n"
+                        + " select 0.2 * avg(l_quantity)\n"
+                        + " from lineitem, partsupp\n"
+                        + " where l_partkey = p_partkey);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + " from lineitem, part\n"
+                        + " where\n"
+                        + " p_partkey = l_partkey\n"
+                        + " and p_brand = 'Brand#23'\n"
+                        + " and p_container = 'MED BOX'\n"
+                        + " and l_quantity < (\n"
+                        + " select 0.2 * avg(l_quantity)\n"
+                        + " from lineitem\n"
+                        + " where l_partkey = p_partkey\n"
+                        + " and p_brand = 'Brand#24');",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + " from lineitem, part\n"
+                        + " where\n"
+                        + " p_partkey = l_partkey\n"
+                        + " and p_brand = 'Brand#23'\n"
+                        + " and p_container = 'MED BOX'\n"
+                        + " and l_quantity < (\n"
+                        + " select 0.2 * avg(l_quantity)\n"
+                        + " from lineitem\n"
+                        + " where l_partkey = p_partkey\n"
+                        + " and l_partkey = 10);"
+        };
+        // NOTE: cases 4 and 5 could in principle be rewritten by the rule, but they are not supported yet.
+        for (String s : testCases) {
+            checkNot(s);
+        }
+    }
+
+    private void check(String sql) {
+        Assertions.assertTrue(applyWinMagic(sql).anyMatch(LogicalWindow.class::isInstance));
+    }
+
+    private void checkNot(String sql) {
+        Assertions.assertFalse(applyWinMagic(sql).anyMatch(LogicalWindow.class::isInstance));
+    }
+
+    /** Analyze {@code sql}, apply the rule top-down, run the rewriter, then print and return the plan. */
+    private Plan applyWinMagic(String sql) {
+        System.out.printf("Test:\n%s\n\n", sql);
+        Plan plan = PlanChecker.from(createCascadesContext(sql))
+                .analyze(sql)
+                .applyTopDown(new AggScalarSubQueryToWindowFunction())
+                .rewrite()
+                .getPlan();
+        System.out.println(plan.treeString());
+        return plan;
+    }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
index 7a4c518d28..fe7eb62d73 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -265,8 +264,8 @@ class PushdownLimitTest extends TestWithFeService
implements MemoPatternMatchSup
LogicalJoin<? extends Plan, ? extends Plan> join = new LogicalJoin<>(
joinType,
joinConditions,
- new LogicalOlapScan(new ObjectId(0), PlanConstructor.score),
- new LogicalOlapScan(new ObjectId(1), PlanConstructor.student)
+ new LogicalOlapScan(((LogicalOlapScan) scanScore).getId(),
PlanConstructor.score),
+ new LogicalOlapScan(((LogicalOlapScan) scanStudent).getId(),
PlanConstructor.student)
);
if (hasProject) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]