This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 84c6f47e4f [Feature](Nereids) add WinMagic rule to rewrite scalar sub-query to window function (#17968)
84c6f47e4f is described below

commit 84c6f47e4f922766cd7dfa8f8b5c78665f41d896
Author: mch_ucchi <[email protected]>
AuthorDate: Mon Mar 27 23:58:41 2023 +0800

    [Feature](Nereids) add WinMagic rule to rewrite scalar sub-query to window function (#17968)
    
    refer paper: WinMagic - Subquery Elimination Using Window Aggregation
    
    For SQL such as TPC-H Q2 and Q17, which contain a correlated sub-query whose only output is a single aggregate function, we can eliminate the sub-query and transform it into a window function. For example, TPC-H Q17 is:
    
    ```sql
    select
            sum(l_extendedprice) / 7.0 as avg_yearly
        from
            lineitem,
            part
        where
            p_partkey = l_partkey
            and p_brand = 'Brand#23'
            and p_container = 'MED BOX'
            and l_quantity < (
                select
                    0.2 * avg(l_quantity)
                from
                    lineitem
                where
                    l_partkey = p_partkey
            );
    ```
    
    We rewrite it to:
    
    ```sql
    select
            sum(l_extendedprice) / 7.0 as avg_yearly
        from (
        select
            l_extendedprice, l_quantity, avg(l_quantity) over(partition by l_partkey) avg_l_quantity
        from
            lineitem,
            part
        where
            p_partkey = l_partkey
            and p_brand = 'Brand#23'
            and p_container = 'MED BOX' ) t
        where l_quantity < 0.2 * avg_l_quantity
    ```
    
    Currently the rule only handles the case where the outer WHERE conjuncts contain exactly one sub-query and the conjunct containing that sub-query is a comparison predicate. Support for compound predicates and for multiple conjuncts containing sub-queries will be added later.
---
 .../doris/nereids/jobs/batch/NereidsRewriter.java  |   3 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../logical/AggScalarSubQueryToWindowFunction.java | 385 +++++++++++++++++++++
 .../rules/rewrite/logical/ExistsApplyToJoin.java   |  16 +-
 .../rules/rewrite/logical/InApplyToJoin.java       |   4 +-
 ...CorrelatedFilterUnderApplyAggregateProject.java |   2 +-
 .../rewrite/logical/PullUpProjectUnderApply.java   |   2 +-
 .../rules/rewrite/logical/ScalarApplyToJoin.java   |   8 +-
 .../logical/UnCorrelatedApplyAggregateFilter.java  |   2 +-
 .../rewrite/logical/UnCorrelatedApplyFilter.java   |   2 +-
 .../logical/UnCorrelatedApplyProjectFilter.java    |   2 +-
 .../org/apache/doris/nereids/trees/TreeNode.java   |  13 +
 .../nereids/trees/plans/logical/LogicalApply.java  |  26 +-
 .../org/apache/doris/nereids/util/PlanUtils.java   |  16 +
 .../AggScalarSubQueryToWindowFunctionTest.java     | 177 ++++++++++
 .../rules/rewrite/logical/PushdownLimitTest.java   |   5 +-
 16 files changed, 629 insertions(+), 35 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
index ef830e4e9c..d00790dffc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
@@ -32,6 +32,7 @@ import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewrite;
 import org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithAggregate;
 import 
org.apache.doris.nereids.rules.mv.SelectMaterializedIndexWithoutAggregate;
 import org.apache.doris.nereids.rules.rewrite.logical.AdjustNullable;
+import 
org.apache.doris.nereids.rules.rewrite.logical.AggScalarSubQueryToWindowFunction;
 import org.apache.doris.nereids.rules.rewrite.logical.BuildAggForUnion;
 import 
org.apache.doris.nereids.rules.rewrite.logical.CheckAndStandardizeWindowFunctionAndFrame;
 import org.apache.doris.nereids.rules.rewrite.logical.ColumnPruning;
@@ -104,6 +105,8 @@ public class NereidsRewriter extends BatchRewriteJob {
             ),
 
             topic("Subquery unnesting",
+                custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, 
AggScalarSubQueryToWindowFunction::new),
+
                 bottomUp(
                     new EliminateUselessPlanUnderApply(),
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index c85fdfeef8..cd7d7a7c90 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -108,6 +108,7 @@ public enum RuleType {
     ELIMINATE_SORT_UNDER_APPLY(RuleTypeClass.REWRITE),
     ELIMINATE_SORT_UNDER_APPLY_PROJECT(RuleTypeClass.REWRITE),
     PULL_UP_PROJECT_UNDER_APPLY(RuleTypeClass.REWRITE),
+    AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION(RuleTypeClass.REWRITE),
     UN_CORRELATED_APPLY_FILTER(RuleTypeClass.REWRITE),
     UN_CORRELATED_APPLY_PROJECT_FILTER(RuleTypeClass.REWRITE),
     UN_CORRELATED_APPLY_AGGREGATE_FILTER(RuleTypeClass.REWRITE),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
new file mode 100644
index 0000000000..2603ca611c
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunction.java
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * Rewrite a correlated scalar sub-query whose only output is a single aggregate function into a window function
+ * evaluated over the outer query.
+ *
+ * <p>change the plan:
+ * logicalFilter(logicalApply(any(), logicalAggregate()))
+ * to
+ * logicalFilter(logicalWindow(logicalFilter(any())))
+ *
+ * <p>The rewrite is only applied when:
+ * every relation of the inner query also appears in the outer query;
+ * every inner filter conjunct also appears (by name) in the outer filter;
+ * the inner aggregate is a single, non-distinct min/max/count/sum/avg;
+ * the aggregate's argument columns are the columns referenced on the outer side of the comparison.
+ *
+ * <p>refer paper: WinMagic - Subquery Elimination Using Window Aggregation
+ */
+public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
+    private static final Set<Class<? extends AggregateFunction>> SUPPORTED_FUNCTION = ImmutableSet.of(
+            Min.class, Max.class, Count.class, Sum.class, Avg.class
+    );
+    private static final Set<Class<? extends LogicalPlan>> LEFT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalRelation.class, LogicalJoin.class, LogicalProject.class, LogicalFilter.class, LogicalLimit.class
+    );
+    private static final Set<Class<? extends LogicalPlan>> RIGHT_SUPPORTED_PLAN = ImmutableSet.of(
+            LogicalRelation.class, LogicalJoin.class, LogicalProject.class, LogicalFilter.class, LogicalAggregate.class
+    );
+    // mutable per-match state written by check() and read by trans(); instances are therefore not thread-safe.
+    private List<LogicalPlan> outerPlans = null;
+    private List<LogicalPlan> innerPlans = null;
+    private LogicalAggregate aggOp = null;
+    private List<AggregateFunction> functions = null;
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext context) {
+        return plan.accept(this, context);
+    }
+
+    @Override
+    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext context) {
+        LogicalApply<Plan, LogicalAggregate<Plan>> apply = checkPattern(filter);
+        if (apply == null || !check(filter, apply)) {
+            // not rewritable here: keep searching the children for candidates instead of stopping at this filter.
+            return super.visitLogicalFilter(filter, context);
+        }
+        // a matched apply's outer side only contains LEFT_SUPPORTED_PLAN nodes (no apply), and its inner side only
+        // RIGHT_SUPPORTED_PLAN nodes, so there is nothing further to rewrite below the new plan.
+        return trans(filter, apply);
+    }
+
+    // match filter(apply) or filter(project(apply)) where the apply is a correlated scalar sub-query
+    // whose right child is an aggregate; returns null when the shape does not match.
+    private LogicalApply<Plan, LogicalAggregate<Plan>> checkPattern(LogicalFilter<? extends Plan> filter) {
+        LogicalPlan plan = ((LogicalPlan) filter.child());
+        if (plan instanceof LogicalProject) {
+            plan = ((LogicalPlan) ((LogicalProject) plan).child());
+        }
+        if (!(plan instanceof LogicalApply)) {
+            return null;
+        }
+        LogicalApply apply = (LogicalApply) plan;
+        if (!checkApplyNode(apply)) {
+            return null;
+        }
+        return apply.right() instanceof LogicalAggregate ? apply : null;
+    }
+
+    // collect both sides of the apply and run every applicability check; fills outerPlans/innerPlans/aggOp/functions.
+    private boolean check(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        LogicalPlan outer = ((LogicalPlan) apply.child(0));
+        LogicalPlan inner = ((LogicalPlan) apply.child(1));
+        outerPlans = PlanCollector.INSTANCE.collect(outer);
+        innerPlans = PlanCollector.INSTANCE.collect(inner);
+        Optional<LogicalFilter> innerFilter = innerPlans.stream()
+                .filter(LogicalFilter.class::isInstance)
+                .map(LogicalFilter.class::cast).findFirst();
+        return innerFilter.isPresent()
+                && checkPlanType() && checkAggType()
+                && checkAggArguments(apply)
+                && checkRelation(apply.getCorrelationSlot())
+                && checkPredicate(Sets.newHashSet(filter.getConjuncts()),
+                Sets.newHashSet(innerFilter.get().getConjuncts()));
+    }
+
+    // check children's nodes because query process will be changed
+    private boolean checkPlanType() {
+        return outerPlans.stream().allMatch(p -> LEFT_SUPPORTED_PLAN.stream().anyMatch(c -> c.isInstance(p)))
+                && innerPlans.stream().allMatch(p -> RIGHT_SUPPORTED_PLAN.stream().anyMatch(c -> c.isInstance(p)));
+    }
+
+    // only correlated scalar sub-queries used in a single comparison predicate are supported.
+    private boolean checkApplyNode(LogicalApply apply) {
+        return apply.isScalar() && apply.isCorrelated() && apply.getSubCorrespondingConjunct().isPresent()
+                && apply.getSubCorrespondingConjunct().get() instanceof ComparisonPredicate;
+    }
+
+    // check aggregation of inner scope
+    private boolean checkAggType() {
+        List<LogicalAggregate> aggSet = innerPlans.stream().filter(LogicalAggregate.class::isInstance)
+                .map(LogicalAggregate.class::cast)
+                .collect(Collectors.toList());
+        if (aggSet.size() != 1) {
+            // exactly one aggregate is required: window functions don't support nesting.
+            return false;
+        }
+        aggOp = aggSet.get(0);
+        functions = ((List<AggregateFunction>) ExpressionUtils.<AggregateFunction>collectAll(
+                aggOp.getOutputExpressions(), AggregateFunction.class::isInstance));
+        // one aggregate function maps to one window function; for anything else (e.g. `sum(a) - sum(b)`)
+        // skip the rewrite instead of failing the whole query.
+        if (functions.size() != 1) {
+            return false;
+        }
+        return functions.stream().allMatch(f -> SUPPORTED_FUNCTION.contains(f.getClass()) && !f.isDistinct());
+    }
+
+    // trans() builds the window function from the slots on the outer side of the comparison, so the rewrite is
+    // only valid when those are the columns the inner aggregate function is applied to (compared by name, like
+    // ExpressionIdenticalChecker). Without this, `a < (select avg(b) ...)` would become `a < avg(a) over(...)`
+    // and count(*) would become count(a).
+    private boolean checkAggArguments(LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        Expression conjunct = PlanUtils.maybeCommuteComparisonPredicate(
+                (ComparisonPredicate) apply.getSubCorrespondingConjunct().get(), apply.left());
+        List<Expression> outerSlots = conjunct.child(0).collectToList(Slot.class::isInstance);
+        List<Expression> innerSlots = functions.get(0).collectToList(Slot.class::isInstance);
+        List<String> outerNames = outerSlots.stream().map(e -> ((Slot) e).getName()).collect(Collectors.toList());
+        List<String> innerNames = innerSlots.stream().map(e -> ((Slot) e).getName()).collect(Collectors.toList());
+        return outerNames.equals(innerNames);
+    }
+
+    // check if the relations of the outer's includes the inner's
+    private boolean checkRelation(List<Expression> correlatedSlots) {
+        List<LogicalRelation> outerTables = outerPlans.stream().filter(LogicalRelation.class::isInstance)
+                .map(LogicalRelation.class::cast)
+                .collect(Collectors.toList());
+        List<LogicalRelation> innerTables = innerPlans.stream().filter(LogicalRelation.class::isInstance)
+                .map(LogicalRelation.class::cast)
+                .collect(Collectors.toList());
+
+        Set<Long> outerIds = outerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toSet());
+        Set<Long> innerIds = innerTables.stream().map(node -> node.getTable().getId()).collect(Collectors.toSet());
+
+        Set<Long> outerCopy = Sets.newHashSet(outerIds);
+        outerIds.removeAll(innerIds);
+        innerIds.removeAll(outerCopy);
+        // inner tables must be a strict subset of outer tables.
+        if (outerIds.isEmpty() || !innerIds.isEmpty()) {
+            return false;
+        }
+
+        // the correlated slots must all come from the outer-only tables.
+        Set<ExprId> correlatedRelationOutput = outerTables.stream()
+                .filter(node -> outerIds.contains(node.getTable().getId()))
+                .map(LogicalRelation::getOutputExprIdSet).flatMap(Collection::stream).collect(Collectors.toSet());
+        return ExpressionUtils.collect(correlatedSlots, NamedExpression.class::isInstance).stream()
+                .map(NamedExpression.class::cast)
+                .allMatch(e -> correlatedRelationOutput.contains(e.getExprId()));
+    }
+
+    // inner predicate should be the sub-set of outer predicate; both sets are consumed.
+    private boolean checkPredicate(Set<Expression> outerConjuncts, Set<Expression> innerConjuncts) {
+        Iterator<Expression> innerIter = innerConjuncts.iterator();
+        while (innerIter.hasNext()) {
+            Expression innerExpr = innerIter.next();
+            Iterator<Expression> outerIter = outerConjuncts.iterator();
+            while (outerIter.hasNext()) {
+                Expression outerExpr = outerIter.next();
+                if (ExpressionIdenticalChecker.INSTANCE.check(innerExpr, outerExpr)) {
+                    innerIter.remove();
+                    outerIter.remove();
+                    // innerExpr is consumed; another match would call innerIter.remove() twice and throw.
+                    break;
+                }
+            }
+        }
+        // every inner conjunct found an identical outer conjunct.
+        return innerConjuncts.isEmpty();
+    }
+
+    private Plan trans(LogicalFilter<? extends Plan> filter, LogicalApply<Plan, LogicalAggregate<Plan>> apply) {
+        LogicalAggregate<Plan> agg = apply.right();
+
+        // transform algorithm
+        // first: find the slot in outer scope corresponding to the slot in aggregate function in inner scope
+        // (checkAggArguments guarantees they are the slots on the outer side of the comparison).
+        // second: find the aggregation function in inner scope, and replace it to window function, and the aggregate
+        // slot is the slot in outer scope in the first step.
+        // third: the expression containing aggregation function in inner scope will be the child of an alias,
+        // so in the predicate between outer and inner, we change the alias to expression which is the alias's child,
+        // and change the aggregation function to the alias of window function.
+
+        // for example, in tpc-h Q17
+        // window filter conjuncts is
+        // cast(l_quantity#id1 as decimal(27, 9)) < `0.2 * avg(l_quantity)`#id2
+        // and
+        // 0.2 * avg(l_quantity#id3) as `0.2 * avg(l_quantity)`#id2
+        // is agg's output expression
+        // we change it to
+        // cast(l_quantity#id1 as decimal(27, 9)) < 0.2 * `avg(l_quantity#id1) over(window)`#id4
+        // and
+        // avg(l_quantity#id1) over(window) as `avg(l_quantity#id1) over(window)`#id4
+
+        // it's a simple case, but we may meet some complex cases in ut.
+        // TODO: support compound predicate and multi apply node.
+
+        Expression windowFilterConjunct = apply.getSubCorrespondingConjunct().get();
+        windowFilterConjunct = PlanUtils.maybeCommuteComparisonPredicate(
+                (ComparisonPredicate) windowFilterConjunct, apply.left());
+
+        // build window function, replace the slot
+        List<Expression> windowAggSlots = windowFilterConjunct.child(0).collectToList(Slot.class::isInstance);
+
+        AggregateFunction function = functions.get(0);
+        if (function instanceof NullableAggregateFunction) {
+            // adjust agg function's nullable.
+            function = ((NullableAggregateFunction) function).withAlwaysNullable(false);
+        }
+
+        WindowExpression windowFunction = createWindowFunction(apply.getCorrelationSlot(),
+                function.withChildren(windowAggSlots));
+        NamedExpression windowFunctionAlias = new Alias(windowFunction, windowFunction.toSql());
+
+        // build filter conjunct, get the alias of the agg output and extract its child.
+        // then replace the agg to window function, then build conjunct
+        // NOTE(review): assumes the first agg output is an Alias wrapping the aggregate expression - not checked here.
+        NamedExpression aggOut = agg.getOutputExpressions().get(0);
+        Expression aggOutExpr = aggOut.child(0);
+        // change the agg function to window function alias.
+        aggOutExpr = MapReplacer.INSTANCE.replace(aggOutExpr, ImmutableMap
+                .of(AggregateFunction.class, e -> windowFunctionAlias.toSlot()));
+
+        // we change the child contains the original agg output to agg output expr.
+        // for comparison predicate, it is always the child(1), since we ensure the window agg slot is in child(0)
+        // for in predicate, we should extract the options and find the corresponding child.
+        windowFilterConjunct = windowFilterConjunct
+                .withChildren(windowFilterConjunct.child(0), aggOutExpr);
+
+        // NOTE(review): if checkPattern skipped a project between filter and apply, that project is not kept here,
+        // and the returned plan additionally outputs the window alias slot; confirm upper nodes tolerate this.
+        LogicalFilter newFilter = ((LogicalFilter) filter.withChildren(apply.left()));
+        LogicalWindow newWindow = new LogicalWindow<>(ImmutableList.of(windowFunctionAlias), newFilter);
+        LogicalFilter windowFilter = new LogicalFilter<>(ImmutableSet.of(windowFilterConjunct), newWindow);
+        return windowFilter;
+    }
+
+    // partition by clause is set by all the correlated slots; no order by.
+    private WindowExpression createWindowFunction(List<Expression> correlatedSlots, AggregateFunction function) {
+        Preconditions.checkArgument(correlatedSlots.stream().allMatch(Slot.class::isInstance));
+        return new WindowExpression(function, correlatedSlots, Collections.emptyList());
+    }
+
+    // pre-order collection of every node of a logical plan tree.
+    private static class PlanCollector extends DefaultPlanVisitor<Void, List<LogicalPlan>> {
+        public static final PlanCollector INSTANCE = new PlanCollector();
+
+        public List<LogicalPlan> collect(LogicalPlan plan) {
+            List<LogicalPlan> buffer = Lists.newArrayList();
+            plan.accept(this, buffer);
+            return buffer;
+        }
+
+        @Override
+        public Void visit(Plan plan, List<LogicalPlan> buffer) {
+            Preconditions.checkArgument(plan instanceof LogicalPlan);
+            buffer.add(((LogicalPlan) plan));
+            plan.children().forEach(child -> child.accept(this, buffer));
+            return null;
+        }
+    }
+
+    // structural equality of two expressions where named expressions are compared by name (not ExprId),
+    // literals by value, and equalities regardless of operand order.
+    private static class ExpressionIdenticalChecker extends DefaultExpressionVisitor<Boolean, Expression> {
+        public static final ExpressionIdenticalChecker INSTANCE = new ExpressionIdenticalChecker();
+
+        public boolean check(Expression expression, Expression expression1) {
+            return expression.accept(this, expression1);
+        }
+
+        private boolean isClassMatch(Object o1, Object o2) {
+            return o1.getClass().equals(o2.getClass());
+        }
+
+        private boolean isSameChild(Expression expression, Expression expression1) {
+            if (expression.children().size() != expression1.children().size()) {
+                return false;
+            }
+            for (int i = 0; i < expression.children().size(); ++i) {
+                if (!expression.children().get(i).accept(this, expression1.children().get(i))) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        private boolean isSameObjects(Object... o) {
+            Preconditions.checkArgument(o.length % 2 == 0);
+            for (int i = 0; i < o.length; i += 2) {
+                if (!Objects.equals(o[i], o[i + 1])) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        private boolean isSameOperator(Expression expression, Expression expression1, Object... o) {
+            return isSameObjects(o) && isSameChild(expression, expression1);
+        }
+
+        @Override
+        public Boolean visit(Expression expression, Expression expression1) {
+            return isClassMatch(expression, expression1) && isSameChild(expression, expression1);
+        }
+
+        @Override
+        public Boolean visitNamedExpression(NamedExpression namedExpression, Expression expr) {
+            return isClassMatch(namedExpression, expr)
+                    && isSameOperator(namedExpression, expr, namedExpression.getName(),
+                    ((NamedExpression) expr).getName());
+        }
+
+        @Override
+        public Boolean visitLiteral(Literal literal, Expression expr) {
+            return isClassMatch(literal, expr)
+                    && isSameOperator(literal, expr, literal.getValue(), ((Literal) expr).getValue());
+        }
+
+        @Override
+        public Boolean visitEqualTo(EqualTo equalTo, Expression expr) {
+            // an equality matches another equality in either operand order, but never a different
+            // operator that merely has the same operands (e.g. `a < b`).
+            return isClassMatch(equalTo, expr)
+                    && (isSameChild(equalTo, expr) || isSameChild(equalTo.commute(), expr));
+        }
+    }
+
+    // rewrites every expression that is an instance of a key class with the mapped function, then recurses.
+    private static class MapReplacer extends DefaultExpressionRewriter<Map<Class<? extends Expression>,
+                Function<Expression, Expression>>> {
+        public static final MapReplacer INSTANCE = new MapReplacer();
+
+        public Expression replace(Expression e, Map<Class<? extends Expression>,
+                Function<Expression, Expression>> context) {
+            return e.accept(this, context);
+        }
+
+        @Override
+        public Expression visit(Expression e, Map<Class<? extends Expression>,
+                Function<Expression, Expression>> context) {
+            Expression replaced = e;
+            for (Class c : context.keySet()) {
+                if (c.isInstance(e)) {
+                    replaced = context.get(c).apply(e);
+                    break;
+                }
+            }
+            return super.visit(replaced, context);
+        }
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
index 26a3f7de80..5e6a4ac265 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java
@@ -92,11 +92,11 @@ public class ExistsApplyToJoin extends 
OneRewriteRuleFactory {
     private Plan correlatedToJoin(LogicalApply apply) {
         Optional<Expression> correlationFilter = apply.getCorrelationFilter();
         Expression predicate = null;
-        if (correlationFilter.isPresent() && 
apply.getSubCorrespondingConject().isPresent()) {
+        if (correlationFilter.isPresent() && 
apply.getSubCorrespondingConjunct().isPresent()) {
             predicate = ExpressionUtils.and(correlationFilter.get(),
-                (Expression) apply.getSubCorrespondingConject().get());
-        } else if (apply.getSubCorrespondingConject().isPresent()) {
-            predicate = (Expression) apply.getSubCorrespondingConject().get();
+                (Expression) apply.getSubCorrespondingConjunct().get());
+        } else if (apply.getSubCorrespondingConjunct().isPresent()) {
+            predicate = (Expression) apply.getSubCorrespondingConjunct().get();
         } else if (correlationFilter.isPresent()) {
             predicate = correlationFilter.get();
         }
@@ -134,8 +134,8 @@ public class ExistsApplyToJoin extends 
OneRewriteRuleFactory {
         LogicalAggregate newAgg = new LogicalAggregate<>(new ArrayList<>(),
                 ImmutableList.of(alias), newLimit);
         LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, 
ExpressionUtils.EMPTY_CONDITION,
-                unapply.getSubCorrespondingConject().isPresent()
-                    ? ExpressionUtils.extractConjunction((Expression) 
unapply.getSubCorrespondingConject().get())
+                unapply.getSubCorrespondingConjunct().isPresent()
+                    ? ExpressionUtils.extractConjunction((Expression) 
unapply.getSubCorrespondingConjunct().get())
                     : ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, 
unapply.getMarkJoinSlotReference(),
                 (LogicalPlan) unapply.left(), newAgg);
         return new LogicalFilter<>(ImmutableSet.of(new 
EqualTo(newAgg.getOutput().get(0),
@@ -145,8 +145,8 @@ public class ExistsApplyToJoin extends 
OneRewriteRuleFactory {
     private Plan unCorrelatedExist(LogicalApply unapply) {
         LogicalLimit newLimit = new LogicalLimit<>(1, 0, LimitPhase.ORIGIN, 
(LogicalPlan) unapply.right());
         return new LogicalJoin<>(JoinType.CROSS_JOIN, 
ExpressionUtils.EMPTY_CONDITION,
-            unapply.getSubCorrespondingConject().isPresent()
-                ? ExpressionUtils.extractConjunction((Expression) 
unapply.getSubCorrespondingConject().get())
+            unapply.getSubCorrespondingConjunct().isPresent()
+                ? ExpressionUtils.extractConjunction((Expression) 
unapply.getSubCorrespondingConjunct().get())
                 : ExpressionUtils.EMPTY_CONDITION,
             JoinHint.NONE, unapply.getMarkJoinSlotReference(), (LogicalPlan) 
unapply.left(), newLimit);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index 8325ad18fb..a0d9cfd6f5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -102,8 +102,8 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
                 predicate = new EqualTo(left, right);
             }
 
-            if (apply.getSubCorrespondingConject().isPresent()) {
-                predicate = ExpressionUtils.and(predicate, 
apply.getSubCorrespondingConject().get());
+            if (apply.getSubCorrespondingConjunct().isPresent()) {
+                predicate = ExpressionUtils.and(predicate, 
apply.getSubCorrespondingConjunct().get());
             }
             List<Expression> conjuncts = 
ExpressionUtils.extractConjunction(predicate);
             if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
index 9c2b22aa3e..99b6b123fa 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpCorrelatedFilterUnderApplyAggregateProject.java
@@ -80,7 +80,7 @@ public class PullUpCorrelatedFilterUnderApplyAggregateProject 
extends OneRewrite
                     LogicalAggregate newAgg = 
agg.withChildren(ImmutableList.of(newFilter));
                     return new LogicalApply<>(apply.getCorrelationSlot(), 
apply.getSubqueryExpr(),
                             apply.getCorrelationFilter(), 
apply.getMarkJoinSlotReference(),
-                            apply.getSubCorrespondingConject(), apply.left(), 
newAgg);
+                            apply.getSubCorrespondingConjunct(), apply.left(), 
newAgg);
                 
}).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT);
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
index bca11e2351..a8105775ba 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PullUpProjectUnderApply.java
@@ -58,7 +58,7 @@ public class PullUpProjectUnderApply extends 
OneRewriteRuleFactory {
                     LogicalProject<Plan> project = apply.right();
                     LogicalApply newCorrelate = new 
LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),
                                 apply.getCorrelationFilter(), 
apply.getMarkJoinSlotReference(),
-                                apply.getSubCorrespondingConject(), 
apply.left(), project.child());
+                                apply.getSubCorrespondingConjunct(), 
apply.left(), project.child());
                     List<NamedExpression> newProjects = new ArrayList<>();
                     newProjects.addAll(apply.left().getOutput());
                     if (apply.getSubqueryExpr() instanceof ScalarSubquery) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
index 97a1315825..999f0b3a83 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java
@@ -61,8 +61,8 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
                 (LogicalPlan) apply.right());
         return new LogicalJoin<>(JoinType.CROSS_JOIN,
                 ExpressionUtils.EMPTY_CONDITION,
-                apply.getSubCorrespondingConject().isPresent()
-                    ? ExpressionUtils.extractConjunction((Expression) 
apply.getSubCorrespondingConject().get())
+                apply.getSubCorrespondingConjunct().isPresent()
+                    ? ExpressionUtils.extractConjunction((Expression) 
apply.getSubCorrespondingConjunct().get())
                     : ExpressionUtils.EMPTY_CONDITION,
                 JoinHint.NONE,
                 apply.getMarkJoinSlotReference(),
@@ -86,9 +86,9 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory {
         return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN,
                 ExpressionUtils.EMPTY_CONDITION,
                 ExpressionUtils.extractConjunction(
-                    apply.getSubCorrespondingConject().isPresent()
+                    apply.getSubCorrespondingConjunct().isPresent()
                         ? ExpressionUtils.and(
-                            (Expression) 
apply.getSubCorrespondingConject().get(),
+                            (Expression) 
apply.getSubCorrespondingConjunct().get(),
                             correlationFilter.get())
                         : correlationFilter.get()),
                 JoinHint.NONE,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
index ab05f23921..cc6590f7c1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyAggregateFilter.java
@@ -89,7 +89,7 @@ public class UnCorrelatedApplyAggregateFilter extends 
OneRewriteRuleFactory {
                     apply.getSubqueryExpr(),
                     ExpressionUtils.optionalAnd(correlatedPredicate),
                     apply.getMarkJoinSlotReference(),
-                    apply.getSubCorrespondingConject(),
+                    apply.getSubCorrespondingConjunct(),
                     apply.left(), newAgg);
         }).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
index 95959a6205..d33897f260 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyFilter.java
@@ -69,7 +69,7 @@ public class UnCorrelatedApplyFilter extends 
OneRewriteRuleFactory {
             Plan child = 
PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), 
filter.child());
             return new LogicalApply<>(apply.getCorrelationSlot(), 
apply.getSubqueryExpr(),
                     ExpressionUtils.optionalAnd(correlatedPredicate), 
apply.getMarkJoinSlotReference(),
-                    apply.getSubCorrespondingConject(),
+                    apply.getSubCorrespondingConjunct(),
                     apply.left(), child);
         }).toRule(RuleType.UN_CORRELATED_APPLY_FILTER);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
index c4cff23293..1907a26ced 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/UnCorrelatedApplyProjectFilter.java
@@ -90,7 +90,7 @@ public class UnCorrelatedApplyProjectFilter extends 
OneRewriteRuleFactory {
                     LogicalProject newProject = 
project.withProjectsAndChild(projects, child);
                     return new LogicalApply<>(apply.getCorrelationSlot(), 
apply.getSubqueryExpr(),
                             ExpressionUtils.optionalAnd(correlatedPredicate), 
apply.getMarkJoinSlotReference(),
-                            apply.getSubCorrespondingConject(),
+                            apply.getSubCorrespondingConjunct(),
                             apply.left(), newProject);
                 }).toRule(RuleType.UN_CORRELATED_APPLY_PROJECT_FILTER);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 92b99ec68e..0394ebea87 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -222,6 +222,19 @@ public interface TreeNode<NODE_TYPE extends 
TreeNode<NODE_TYPE>> {
         return (T) result.build();
     }
 
+    /**
+     * Collect the nodes that satisfied the predicate to list.
+     */
+    default <T> List<T> collectToList(Predicate<TreeNode<NODE_TYPE>> 
predicate) {
+        ImmutableList.Builder<TreeNode<NODE_TYPE>> result = 
ImmutableList.builder();
+        foreach(node -> {
+            if (predicate.test(node)) {
+                result.add(node);
+            }
+        });
+        return (List<T>) result.build();
+    }
+
     /**
      * iterate top down and test predicate if contains any instance of the 
classes
      * @param types classes array
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
index 4c3a092244..df28b35059 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
@@ -55,7 +55,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
     // The slot replaced by the subquery in MarkJoin
     private final Optional<MarkJoinSlotReference> markJoinSlotReference;
 
-    private final Optional<Expression> subCorrespondingConject;
+    private final Optional<Expression> subCorrespondingConjunct;
 
     /**
      * Constructor.
@@ -65,22 +65,22 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
             List<Expression> correlationSlot,
             SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter,
             Optional<MarkJoinSlotReference> markJoinSlotReference,
-            Optional<Expression> subCorrespondingConject,
+            Optional<Expression> subCorrespondingConjunct,
             LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
         super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties, 
leftChild, rightChild);
         this.correlationSlot = correlationSlot == null ? ImmutableList.of() : 
ImmutableList.copyOf(correlationSlot);
         this.subqueryExpr = Objects.requireNonNull(subqueryExpr, "subquery can 
not be null");
         this.correlationFilter = correlationFilter;
         this.markJoinSlotReference = markJoinSlotReference;
-        this.subCorrespondingConject = subCorrespondingConject;
+        this.subCorrespondingConjunct = subCorrespondingConjunct;
     }
 
     public LogicalApply(List<Expression> correlationSlot, SubqueryExpr 
subqueryExpr,
             Optional<Expression> correlationFilter, 
Optional<MarkJoinSlotReference> markJoinSlotReference,
-            Optional<Expression> subCorrespondingConject,
+            Optional<Expression> subCorrespondingConjunct,
             LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) {
         this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr,
-                correlationFilter, markJoinSlotReference, 
subCorrespondingConject, input, subquery);
+                correlationFilter, markJoinSlotReference, 
subCorrespondingConjunct, input, subquery);
     }
 
     public List<Expression> getCorrelationSlot() {
@@ -123,8 +123,8 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
         return markJoinSlotReference;
     }
 
-    public Optional<Expression> getSubCorrespondingConject() {
-        return subCorrespondingConject;
+    public Optional<Expression> getSubCorrespondingConjunct() {
+        return subCorrespondingConjunct;
     }
 
     @Override
@@ -142,7 +142,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
                 "isMarkJoin", markJoinSlotReference.isPresent(),
                 "MarkJoinSlotReference", markJoinSlotReference.isPresent() ? 
markJoinSlotReference.get() : "empty",
                 "scalarSubCorrespondingSlot",
-                subCorrespondingConject.isPresent() ? 
subCorrespondingConject.get() : "empty");
+                subCorrespondingConjunct.isPresent() ? 
subCorrespondingConjunct.get() : "empty");
     }
 
     @Override
@@ -158,13 +158,13 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
                 && Objects.equals(subqueryExpr, that.getSubqueryExpr())
                 && Objects.equals(correlationFilter, 
that.getCorrelationFilter())
                 && Objects.equals(markJoinSlotReference, 
that.getMarkJoinSlotReference())
-                && Objects.equals(subCorrespondingConject, 
that.getSubCorrespondingConject());
+                && Objects.equals(subCorrespondingConjunct, 
that.getSubCorrespondingConjunct());
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(
-                correlationSlot, subqueryExpr, correlationFilter, 
markJoinSlotReference, subCorrespondingConject);
+                correlationSlot, subqueryExpr, correlationFilter, 
markJoinSlotReference, subCorrespondingConjunct);
     }
 
     @Override
@@ -189,7 +189,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
     public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 2);
         return new LogicalApply<>(correlationSlot, subqueryExpr, 
correlationFilter,
-                markJoinSlotReference, subCorrespondingConject,
+                markJoinSlotReference, subCorrespondingConjunct,
                 children.get(0), children.get(1));
     }
 
@@ -197,13 +197,13 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
     public Plan withGroupExpression(Optional<GroupExpression> groupExpression) 
{
         return new LogicalApply<>(groupExpression, 
Optional.of(getLogicalProperties()),
                 correlationSlot, subqueryExpr, correlationFilter,
-                markJoinSlotReference, subCorrespondingConject, left(), 
right());
+                markJoinSlotReference, subCorrespondingConjunct, left(), 
right());
     }
 
     @Override
     public Plan withLogicalProperties(Optional<LogicalProperties> 
logicalProperties) {
         return new LogicalApply<>(Optional.empty(), logicalProperties,
                 correlationSlot, subqueryExpr, correlationFilter,
-                markJoinSlotReference, subCorrespondingConject, left(), 
right());
+                markJoinSlotReference, subCorrespondingConjunct, left(), 
right());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 1a08fed6dd..f68664bc32 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -17,10 +17,14 @@
 
 package org.apache.doris.nereids.util;
 
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 
+import com.google.common.collect.Sets;
+
 import java.util.Optional;
 import java.util.Set;
 
@@ -38,4 +42,16 @@ public class PlanUtils {
     public static Plan filterOrSelf(Set<Expression> predicates, Plan plan) {
         return filter(predicates, plan).map(Plan.class::cast).orElse(plan);
     }
+
+    /**
+     * normalize comparison predicate on a binary plan to its two sides are 
corresponding to the child's output.
+     */
+    public static ComparisonPredicate 
maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
+        Set<Slot> slots = expression.left().collect(Slot.class::isInstance);
+        Set<Slot> leftSlots = left.getOutputSet();
+        Set<Slot> buffer = Sets.newHashSet(slots);
+        buffer.removeAll(leftSlots);
+        return buffer.isEmpty() ? expression : expression.commute();
+    }
+
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
new file mode 100644
index 0000000000..d63385f656
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggScalarSubQueryToWindowFunctionTest.java
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.datasets.tpch.TPCHTestBase;
+import org.apache.doris.nereids.datasets.tpch.TPCHUtils;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+public class AggScalarSubQueryToWindowFunctionTest extends TPCHTestBase 
implements MemoPatternMatchSupported {
+    private static final String SQL_TEMPLATE = "    select\n"
+            + "        sum(l_extendedprice) / 7.0 as avg_yearly\n"
+            + "    from\n"
+            + "        lineitem,\n"
+            + "        part\n"
+            + "    where\n"
+            + "        p_partkey = l_partkey\n"
+            + "        and p_brand = 'Brand#23'\n"
+            + "        and p_container = 'MED BOX'\n"
+            + "        (p1) (p2)";
+    private static final String SUB_QUERY_TEMPLATE = "(\n"
+            + "            select\n"
+            + "                %s\n"
+            + "            from\n"
+            + "                lineitem\n"
+            + "            where\n"
+            + "                l_partkey = p_partkey\n"
+            + "        )";
+    private static final String AVG = "0.2 * avg(l_quantity)";
+    private static final String MAX = "max(l_quantity) / 2";
+    private static final String MIN = "min(l_extendedprice) * 5";
+
+    private static final String[] queries = {
+            buildSubQuery(AVG),
+            buildSubQuery(MAX),
+            buildSubQuery(MIN)
+    };
+
+    private static String buildFromTemplate(String[] predicate, String[] 
query) {
+        String sql = SQL_TEMPLATE;
+        for (int i = 0; i < predicate.length; ++i) {
+            for (int j = 0; j < query.length; ++j) {
+                predicate[i] = predicate[i].replace(String.format("(q%d)", j + 
1), query[j]);
+            }
+            sql = sql.replace(String.format("(p%d)", i + 1), predicate[i]);
+        }
+        return sql;
+    }
+
+    private static String buildSubQuery(String res) {
+        return String.format(SUB_QUERY_TEMPLATE, res);
+    }
+
+    @Test
+    public void testRuleOnTPCHTest() {
+        check(TPCHUtils.Q2);
+        check(TPCHUtils.Q17);
+    }
+
+    @Disabled
+    @Test
+    public void testComplexPredicates() {
+        // we ensure there's one sub-query in a predicate and in-predicates do 
not contain sub-query,
+        // so we test compound predicates and sub-query in more than one 
predicates
+        // now we disabled them temporarily, and enable when the rule support 
the cases.
+        String[] testCases = {
+                "and l_quantity > 10 or l_quantity < (q1)",
+                "or l_quantity < (q3)",
+                "and l_extendedprice > (q2)",
+        };
+
+        check(buildFromTemplate(new String[] {testCases[0], testCases[1]}, 
queries));
+        check(buildFromTemplate(new String[] {testCases[0], testCases[2]}, 
queries));
+        check(buildFromTemplate(new String[] {testCases[1], testCases[2]}, 
queries));
+    }
+
+    @Test
+    public void testNotMatchTheRule() {
+        String[] testCases = {
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + "    from lineitem, part\n"
+                        + "    where p_partkey = l_partkey\n"
+                        + "        and p_brand = 'Brand#23'\n"
+                        + "        and p_container = 'MED BOX'\n"
+                        + "        and l_quantity < (\n"
+                        + "            select 0.2 * avg(l_quantity)\n"
+                        + "            from lineitem);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + "    from lineitem, part\n"
+                        + "    where p_partkey = l_partkey\n"
+                        + "        and p_brand = 'Brand#23'\n"
+                        + "        and p_container = 'MED BOX'\n"
+                        + "        and l_quantity < (\n"
+                        + "            select 0.2 * avg(l_quantity)\n"
+                        + "            from lineitem, part\n"
+                        + "            where l_partkey = p_partkey);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + "    from lineitem, part\n"
+                        + "    where p_partkey = l_partkey\n"
+                        + "        and p_brand = 'Brand#23'\n"
+                        + "        and p_container = 'MED BOX'\n"
+                        + "        and l_quantity < (\n"
+                        + "            select 0.2 * avg(l_quantity)\n"
+                        + "            from lineitem, partsupp\n"
+                        + "            where l_partkey = p_partkey);",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + "    from lineitem, part\n"
+                        + "    where\n"
+                        + "        p_partkey = l_partkey\n"
+                        + "        and p_brand = 'Brand#23'\n"
+                        + "        and p_container = 'MED BOX'\n"
+                        + "        and l_quantity < (\n"
+                        + "            select 0.2 * avg(l_quantity)\n"
+                        + "            from lineitem\n"
+                        + "            where l_partkey = p_partkey\n"
+                        + "            and p_brand = 'Brand#24');",
+                "select sum(l_extendedprice) / 7.0 as avg_yearly\n"
+                        + "    from lineitem, part\n"
+                        + "    where\n"
+                        + "        p_partkey = l_partkey\n"
+                        + "        and p_brand = 'Brand#23'\n"
+                        + "        and p_container = 'MED BOX'\n"
+                        + "        and l_quantity < (\n"
+                        + "            select 0.2 * avg(l_quantity)\n"
+                        + "            from lineitem\n"
+                        + "            where l_partkey = p_partkey\n"
+                        + "            and l_partkey = 10);"
+        };
+        // notice: case 4 and 5 can apply the rule, but we support it later.
+        for (String s : testCases) {
+            checkNot(s);
+        }
+    }
+
+    private void check(String sql) {
+        System.out.printf("Test:\n%s\n\n", sql);
+        Plan plan = PlanChecker.from(createCascadesContext(sql))
+                .analyze(sql)
+                .applyTopDown(new AggScalarSubQueryToWindowFunction())
+                .rewrite()
+                .getPlan();
+        System.out.println(plan.treeString());
+        Assertions.assertTrue(plan.anyMatch(LogicalWindow.class::isInstance));
+    }
+
+    private void checkNot(String sql) {
+        System.out.printf("Test:\n%s\n\n", sql);
+        Plan plan = PlanChecker.from(createCascadesContext(sql))
+                .analyze(sql)
+                .applyTopDown(new AggScalarSubQueryToWindowFunction())
+                .rewrite()
+                .getPlan();
+        System.out.println(plan.treeString());
+        Assertions.assertFalse(plan.anyMatch(LogicalWindow.class::isInstance));
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
index 7a4c518d28..fe7eb62d73 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownLimitTest.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.ObjectId;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -265,8 +264,8 @@ class PushdownLimitTest extends TestWithFeService 
implements MemoPatternMatchSup
         LogicalJoin<? extends Plan, ? extends Plan> join = new LogicalJoin<>(
                 joinType,
                 joinConditions,
-                new LogicalOlapScan(new ObjectId(0), PlanConstructor.score),
-                new LogicalOlapScan(new ObjectId(1), PlanConstructor.student)
+                new LogicalOlapScan(((LogicalOlapScan) scanScore).getId(), 
PlanConstructor.score),
+                new LogicalOlapScan(((LogicalOlapScan) scanStudent).getId(), 
PlanConstructor.student)
         );
 
         if (hasProject) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to