This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 8f580b523f3 [opt](nereids) support partitionTopn for multi window exprs (#39687)
8f580b523f3 is described below

commit 8f580b523f3dc8a1c9b3a8b680c3a60b1bdb950c
Author: xzj7019 <[email protected]>
AuthorDate: Thu Aug 22 10:34:36 2024 +0800

    [opt](nereids) support partitionTopn for multi window exprs (#39687)
    
    ## Proposed changes
    
    pick from https://github.com/apache/doris/pull/38393
    
    Co-authored-by: xiongzhongjian <[email protected]>
---
 .../rewrite/CreatePartitionTopNFromWindow.java     | 101 ++---------
 .../doris/nereids/rules/rewrite/PushDownLimit.java |  27 ++-
 .../rules/rewrite/PushDownTopNThroughWindow.java   |  20 ++-
 .../nereids/trees/plans/logical/LogicalWindow.java | 196 +++++++++++++++++----
 .../push_down_multi_filter_through_window.groovy   | 160 +++++++++++++++++
 5 files changed, 361 insertions(+), 143 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CreatePartitionTopNFromWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CreatePartitionTopNFromWindow.java
index 8a4d7a42d3d..834367cdec6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CreatePartitionTopNFromWindow.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CreatePartitionTopNFromWindow.java
@@ -17,32 +17,16 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.BinaryOperator;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.LessThan;
-import org.apache.doris.nereids.trees.expressions.LessThanEqual;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
-import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
 import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
 
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableSet;
-
-import java.util.List;
-import java.util.Optional;
-import java.util.Set;
-import java.util.function.Predicate;
-
 /**
  * Push down the 'partitionTopN' into the 'window'.
  * It will convert the filter condition to the 'limit value' and push down 
below the 'window'.
@@ -89,82 +73,17 @@ public class CreatePartitionTopNFromWindow extends OneRewriteRuleFactory {
                 return filter;
             }
 
-            List<NamedExpression> windowExprs = window.getWindowExpressions();
-            if (windowExprs.size() != 1) {
-                return filter;
-            }
-            NamedExpression windowExpr = windowExprs.get(0);
-            if (windowExpr.children().size() != 1 || !(windowExpr.child(0) 
instanceof WindowExpression)) {
-                return filter;
-            }
-
-            // Check the filter conditions. Now, we currently only support 
simple conditions of the form
-            // 'column </ <=/ = constant'. We will extract some related 
conjuncts and do some check.
-            Set<Expression> conjuncts = filter.getConjuncts();
-            Set<Expression> relatedConjuncts = 
extractRelatedConjuncts(conjuncts, windowExpr.getExprId());
-
-            boolean hasPartitionLimit = false;
-            long partitionLimit = Long.MAX_VALUE;
-
-            for (Expression conjunct : relatedConjuncts) {
-                Preconditions.checkArgument(conjunct instanceof 
BinaryOperator);
-                BinaryOperator op = (BinaryOperator) conjunct;
-                Expression leftChild = op.children().get(0);
-                Expression rightChild = op.children().get(1);
-
-                Preconditions.checkArgument(leftChild instanceof SlotReference
-                        && rightChild instanceof IntegerLikeLiteral);
-
-                long limitVal = ((IntegerLikeLiteral) 
rightChild).getLongValue();
-                // Adjust the value for 'limitVal' based on the comparison 
operators.
-                if (conjunct instanceof LessThan) {
-                    limitVal--;
-                }
-                if (limitVal < 0) {
-                    return new 
LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), 
filter.getOutput());
-                }
-                if (hasPartitionLimit) {
-                    partitionLimit = Math.min(partitionLimit, limitVal);
-                } else {
-                    partitionLimit = limitVal;
-                    hasPartitionLimit = true;
-                }
-            }
-
-            if (!hasPartitionLimit) {
-                return filter;
-            }
-
-            Optional<Plan> newWindow = 
window.pushPartitionLimitThroughWindow(partitionLimit, false);
-            if (!newWindow.isPresent()) {
+            Pair<WindowExpression, Long> windowFuncPair = 
window.getPushDownWindowFuncAndLimit(filter, Long.MAX_VALUE);
+            if (windowFuncPair == null) {
                 return filter;
+            } else if (windowFuncPair.second == -1) {
+                // limit -1 indicating a empty relation case
+                return new 
LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), 
filter.getOutput());
+            } else {
+                Plan newWindow = 
window.pushPartitionLimitThroughWindow(windowFuncPair.first,
+                        windowFuncPair.second, false);
+                return filter.withChildren(newWindow);
             }
-            return filter.withChildren(newWindow.get());
         }).toRule(RuleType.CREATE_PARTITION_TOPN_FOR_WINDOW);
     }
-
-    private Set<Expression> extractRelatedConjuncts(Set<Expression> conjuncts, 
ExprId slotRefID) {
-        Predicate<Expression> condition = conjunct -> {
-            if (!(conjunct instanceof BinaryOperator)) {
-                return false;
-            }
-            BinaryOperator op = (BinaryOperator) conjunct;
-            Expression leftChild = op.children().get(0);
-            Expression rightChild = op.children().get(1);
-
-            if (!(conjunct instanceof LessThan || conjunct instanceof 
LessThanEqual || conjunct instanceof EqualTo)) {
-                return false;
-            }
-
-            // TODO: Now, we only support the column on the left side.
-            if (!(leftChild instanceof SlotReference) || !(rightChild 
instanceof IntegerLikeLiteral)) {
-                return false;
-            }
-            return ((SlotReference) leftChild).getExprId() == slotRefID;
-        };
-
-        return conjuncts.stream()
-                .filter(condition)
-                .collect(ImmutableSet.toImmutableSet());
-    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimit.java
index 69a7faa6aa7..68cf03d7f5b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimit.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownLimit.java
@@ -17,8 +17,10 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.algebra.Limit;
 import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
@@ -31,7 +33,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
 import com.google.common.collect.ImmutableList;
 
 import java.util.List;
-import java.util.Optional;
 
 /**
  * Rules to push {@link 
org.apache.doris.nereids.trees.plans.logical.LogicalLimit} down.
@@ -72,11 +73,17 @@ public class PushDownLimit implements RewriteRuleFactory {
                         .then(limit -> {
                             LogicalWindow<Plan> window = limit.child();
                             long partitionLimit = limit.getLimit() + 
limit.getOffset();
-                            Optional<Plan> newWindow = 
window.pushPartitionLimitThroughWindow(partitionLimit, true);
-                            if (!newWindow.isPresent()) {
+                            if (partitionLimit <= 0) {
                                 return limit;
                             }
-                            return limit.withChildren(newWindow.get());
+                            Pair<WindowExpression, Long> windowFuncLongPair = 
window
+                                    .getPushDownWindowFuncAndLimit(null, 
partitionLimit);
+                            if (windowFuncLongPair == null) {
+                                return limit;
+                            }
+                            Plan newWindow = 
window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
+                                    windowFuncLongPair.second, true);
+                            return limit.withChildren(newWindow);
                         }).toRule(RuleType.PUSH_LIMIT_THROUGH_WINDOW),
 
                 // limit -> project -> window
@@ -85,11 +92,17 @@ public class PushDownLimit implements RewriteRuleFactory {
                             LogicalProject<LogicalWindow<Plan>> project = 
limit.child();
                             LogicalWindow<Plan> window = project.child();
                             long partitionLimit = limit.getLimit() + 
limit.getOffset();
-                            Optional<Plan> newWindow = 
window.pushPartitionLimitThroughWindow(partitionLimit, true);
-                            if (!newWindow.isPresent()) {
+                            if (partitionLimit <= 0) {
+                                return limit;
+                            }
+                            Pair<WindowExpression, Long> windowFuncLongPair = 
window
+                                    .getPushDownWindowFuncAndLimit(null, 
partitionLimit);
+                            if (windowFuncLongPair == null) {
                                 return limit;
                             }
-                            return 
limit.withChildren(project.withChildren(newWindow.get()));
+                            Plan newWindow = 
window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
+                                    windowFuncLongPair.second, true);
+                            return 
limit.withChildren(project.withChildren(newWindow));
                         }).toRule(RuleType.PUSH_LIMIT_THROUGH_PROJECT_WINDOW),
 
                 // limit -> union
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
index 7a0eb068873..8dc4cd6f73a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNThroughWindow.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.properties.OrderKey;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
@@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
 import com.google.common.collect.ImmutableList;
 
 import java.util.List;
-import java.util.Optional;
 
 /**
  * PushdownTopNThroughWindow push down the TopN through the Window and 
generate the PartitionTopN.
@@ -54,11 +54,14 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
                     return topn;
                 }
                 long partitionLimit = topn.getLimit() + topn.getOffset();
-                Optional<Plan> newWindow = 
window.pushPartitionLimitThroughWindow(partitionLimit, true);
-                if (!newWindow.isPresent()) {
+                Pair<WindowExpression, Long> windowFuncLongPair = window
+                        .getPushDownWindowFuncAndLimit(null, partitionLimit);
+                if (windowFuncLongPair == null) {
                     return topn;
                 }
-                return topn.withChildren(newWindow.get());
+                Plan newWindow = 
window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
+                        windowFuncLongPair.second, true);
+                return topn.withChildren(newWindow);
             }).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_WINDOW),
 
             // topn -> projection -> window
@@ -74,11 +77,14 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
                     return topn;
                 }
                 long partitionLimit = topn.getLimit() + topn.getOffset();
-                Optional<Plan> newWindow = 
window.pushPartitionLimitThroughWindow(partitionLimit, true);
-                if (!newWindow.isPresent()) {
+                Pair<WindowExpression, Long> windowFuncLongPair = window
+                        .getPushDownWindowFuncAndLimit(null, partitionLimit);
+                if (windowFuncLongPair == null) {
                     return topn;
                 }
-                return 
topn.withChildren(project.withChildren(newWindow.get()));
+                Plan newWindow = 
window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
+                        windowFuncLongPair.second, true);
+                return topn.withChildren(project.withChildren(newWindow));
             }).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_PROJECT_WINDOW)
         );
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java
index ed99c265161..fceb36e677f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalWindow.java
@@ -17,19 +17,27 @@
 
 package org.apache.doris.nereids.trees.plans.logical;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.FdItem;
 import org.apache.doris.nereids.properties.FunctionalDependencies;
 import org.apache.doris.nereids.properties.FunctionalDependencies.Builder;
 import org.apache.doris.nereids.properties.LogicalProperties;
+import org.apache.doris.nereids.trees.expressions.BinaryOperator;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.LessThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import org.apache.doris.nereids.trees.expressions.WindowFrame;
 import org.apache.doris.nereids.trees.expressions.functions.window.DenseRank;
 import org.apache.doris.nereids.trees.expressions.functions.window.Rank;
 import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.PlanType;
 import org.apache.doris.nereids.trees.plans.algebra.Window;
@@ -44,6 +52,8 @@ import com.google.common.collect.ImmutableSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
+import java.util.function.Predicate;
 
 /**
  * logical node to deal with window functions;
@@ -171,61 +181,171 @@ public class LogicalWindow<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
     }
 
     /**
-     * pushPartitionLimitThroughWindow is used to push the partitionLimit 
through the window
-     * and generate the partitionTopN. If the window can not meet the 
requirement,
-     * it will return null. So when we use this function, we need check the 
null in the outside.
+     * Get push down window function candidate and corresponding partition 
limit.
+     *
+     * @param filter
+     *              For partition topN filter cases, it means the topN filter;
+     *              For partition limit cases, it will be null.
+     * @param partitionLimit
+     *              For partition topN filter cases, it means the filter 
boundary,
+     *                  e.g, 100 for the case rn <= 100;
+     *              For partition limit cases, it means the limit.
+     * @return
+     *              Return null means invalid cases or the opt option is 
disabled,
+     *              else return the chosen window function and the chosen 
partition limit.
+     *              A special limit -1 means the case can be optimized as 
empty relation.
      */
-    public Optional<Plan> pushPartitionLimitThroughWindow(long partitionLimit, 
boolean hasGlobalLimit) {
+    public Pair<WindowExpression, Long> 
getPushDownWindowFuncAndLimit(LogicalFilter<?> filter, long partitionLimit) {
         if 
(!ConnectContext.get().getSessionVariable().isEnablePartitionTopN()) {
-            return Optional.empty();
+            return null;
         }
         // We have already done such optimization rule, so just ignore it.
-        if (child(0) instanceof LogicalPartitionTopN) {
-            return Optional.empty();
+        if (child(0) instanceof LogicalPartitionTopN
+                || (child(0) instanceof LogicalFilter
+                && child(0).child(0) != null
+                && child(0).child(0) instanceof LogicalPartitionTopN)) {
+            return null;
         }
 
         // Check the window function. There are some restrictions for window 
function:
-        // 1. The number of window function should be 1.
-        // 2. The window function should be one of the 'row_number()', 
'rank()', 'dense_rank()'.
-        // 3. The window frame should be 'UNBOUNDED' to 'CURRENT'.
-        // 4. The 'PARTITION' key and 'ORDER' key can not be empty at the same 
time.
-        if (windowExpressions.size() != 1) {
-            return Optional.empty();
-        }
-        NamedExpression windowExpr = windowExpressions.get(0);
-        if (windowExpr.children().size() != 1 || !(windowExpr.child(0) 
instanceof WindowExpression)) {
-            return Optional.empty();
-        }
+        // 1. The window function should be one of the 'row_number()', 
'rank()', 'dense_rank()'.
+        // 2. The window frame should be 'UNBOUNDED' to 'CURRENT'.
+        // 3. The 'PARTITION' key and 'ORDER' key can not be empty at the same 
time.
+        WindowExpression chosenWindowFunc = null;
+        long chosenPartitionLimit = Long.MAX_VALUE;
+        long chosenRowNumberPartitionLimit = Long.MAX_VALUE;
+        boolean hasRowNumber = false;
+        for (NamedExpression windowExpr : windowExpressions) {
+            if (windowExpr == null || windowExpr.children().size() != 1
+                    || !(windowExpr.child(0) instanceof WindowExpression)) {
+                continue;
+            }
+            WindowExpression windowFunc = (WindowExpression) 
windowExpr.child(0);
 
-        WindowExpression windowFunc = (WindowExpression) windowExpr.child(0);
-        // Check the window function name.
-        if (!(windowFunc.getFunction() instanceof RowNumber
-                || windowFunc.getFunction() instanceof Rank
-                || windowFunc.getFunction() instanceof DenseRank)) {
-            return Optional.empty();
-        }
+            // Check the window function name.
+            if (!(windowFunc.getFunction() instanceof RowNumber
+                    || windowFunc.getFunction() instanceof Rank
+                    || windowFunc.getFunction() instanceof DenseRank)) {
+                continue;
+            }
 
-        // Check the partition key and order key.
-        if (windowFunc.getPartitionKeys().isEmpty() && 
windowFunc.getOrderKeys().isEmpty()) {
-            return Optional.empty();
-        }
+            // Check the partition key and order key.
+            if (windowFunc.getPartitionKeys().isEmpty() && 
windowFunc.getOrderKeys().isEmpty()) {
+                continue;
+            }
 
-        // Check the window type and window frame.
-        Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
-        if (windowFrame.isPresent()) {
-            WindowFrame frame = windowFrame.get();
-            if (!(frame.getLeftBoundary().getFrameBoundType() == 
WindowFrame.FrameBoundType.UNBOUNDED_PRECEDING
-                    && frame.getRightBoundary().getFrameBoundType() == 
WindowFrame.FrameBoundType.CURRENT_ROW)) {
-                return Optional.empty();
+            // Check the window type and window frame.
+            Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
+            if (windowFrame.isPresent()) {
+                WindowFrame frame = windowFrame.get();
+                if (!(frame.getLeftBoundary().getFrameBoundType() == 
WindowFrame.FrameBoundType.UNBOUNDED_PRECEDING
+                        && frame.getRightBoundary().getFrameBoundType() == 
WindowFrame.FrameBoundType.CURRENT_ROW)) {
+                    continue;
+                }
+            } else {
+                continue;
             }
+
+            // Check filter conditions.
+            if (filter != null) {
+                // We currently only support simple conditions of the form 
'column </ <=/ = constant'.
+                // We will extract some related conjuncts and do some check.
+                boolean hasPartitionLimit = false;
+                long curPartitionLimit = Long.MAX_VALUE;
+                Set<Expression> conjuncts = filter.getConjuncts();
+                Set<Expression> relatedConjuncts = 
extractRelatedConjuncts(conjuncts, windowExpr.getExprId());
+                for (Expression conjunct : relatedConjuncts) {
+                    // Pre-checking has been done in former extraction
+                    BinaryOperator op = (BinaryOperator) conjunct;
+                    Expression rightChild = op.children().get(1);
+                    long limitVal = ((IntegerLikeLiteral) 
rightChild).getLongValue();
+                    // Adjust the value for 'limitVal' based on the comparison 
operators.
+                    if (conjunct instanceof LessThan) {
+                        limitVal--;
+                    }
+                    if (limitVal < 0) {
+                        // Set return limit value as -1 for indicating a empty 
relation opt case
+                        chosenPartitionLimit = -1;
+                        chosenRowNumberPartitionLimit = -1;
+                        break;
+                    }
+                    if (hasPartitionLimit) {
+                        curPartitionLimit = Math.min(curPartitionLimit, 
limitVal);
+                    } else {
+                        curPartitionLimit = limitVal;
+                        hasPartitionLimit = true;
+                    }
+                }
+                if (chosenPartitionLimit == -1) {
+                    chosenWindowFunc = windowFunc;
+                    break;
+                } else if (windowFunc.getFunction() instanceof RowNumber) {
+                    // choose row_number first any way
+                    // if multiple exists, choose the one with minimal limit 
value
+                    if (curPartitionLimit < chosenRowNumberPartitionLimit) {
+                        chosenRowNumberPartitionLimit = curPartitionLimit;
+                        chosenWindowFunc = windowFunc;
+                        hasRowNumber = true;
+                    }
+                } else if (!hasRowNumber) {
+                    // if no row_number, choose the one with minimal limit 
value
+                    if (curPartitionLimit < chosenPartitionLimit) {
+                        chosenPartitionLimit = curPartitionLimit;
+                        chosenWindowFunc = windowFunc;
+                    }
+                }
+            } else {
+                // limit
+                chosenWindowFunc = windowFunc;
+                chosenPartitionLimit = partitionLimit;
+                if (windowFunc.getFunction() instanceof RowNumber) {
+                    break;
+                }
+            }
+        }
+        if (chosenWindowFunc == null || (chosenPartitionLimit == Long.MAX_VALUE
+                && chosenRowNumberPartitionLimit == Long.MAX_VALUE)) {
+            return null;
         } else {
-            return Optional.empty();
+            return Pair.of(chosenWindowFunc, hasRowNumber ? 
chosenRowNumberPartitionLimit : chosenPartitionLimit);
         }
+    }
 
+    /**
+     * pushPartitionLimitThroughWindow is used to push the partitionLimit 
through the window
+     * and generate the partitionTopN. If the window can not meet the 
requirement,
+     * it will return null. So when we use this function, we need check the 
null in the outside.
+     */
+    public Plan pushPartitionLimitThroughWindow(WindowExpression windowFunc,
+            long partitionLimit, boolean hasGlobalLimit) {
         LogicalWindow<?> window = (LogicalWindow<?>) withChildren(new 
LogicalPartitionTopN<>(windowFunc, hasGlobalLimit,
                 partitionLimit, child(0)));
+        return window;
+    }
 
-        return Optional.ofNullable(window);
+    private Set<Expression> extractRelatedConjuncts(Set<Expression> conjuncts, 
ExprId slotRefID) {
+        Predicate<Expression> condition = conjunct -> {
+            if (!(conjunct instanceof BinaryOperator)) {
+                return false;
+            }
+            BinaryOperator op = (BinaryOperator) conjunct;
+            Expression leftChild = op.children().get(0);
+            Expression rightChild = op.children().get(1);
+
+            if (!(conjunct instanceof LessThan || conjunct instanceof 
LessThanEqual || conjunct instanceof EqualTo)) {
+                return false;
+            }
+
+            // TODO: Now, we only support the column on the left side.
+            if (!(leftChild instanceof SlotReference) || !(rightChild 
instanceof IntegerLikeLiteral)) {
+                return false;
+            }
+            return ((SlotReference) leftChild).getExprId() == slotRefID;
+        };
+
+        return conjuncts.stream()
+                .filter(condition)
+                .collect(ImmutableSet.toImmutableSet());
     }
 
     private boolean isUnique(NamedExpression namedExpression) {
diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_multi_filter_through_window.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_multi_filter_through_window.groovy
new file mode 100644
index 00000000000..d808d30f8eb
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_multi_filter_through_window.groovy
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("push_down_multi_filter_through_window") {
+    // Checks, via EXPLAIN output, when a VPartitionTopN node is (or is not) planned for
+    // filters and outer limits applied on top of one or more window functions.
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "set ignore_shape_nodes='PhysicalDistribute'"
+    // Fixture: one table with two INT columns and one VARCHAR column, four rows.
+    sql "drop table if exists push_down_multi_predicate_through_window_t"
+    multi_sql """
+    CREATE TABLE push_down_multi_predicate_through_window_t (c1 INT, c2 INT, 
c3 VARCHAR(50)) properties("replication_num"="1");
+    INSERT INTO push_down_multi_predicate_through_window_t (c1, c2, c3) 
VALUES(1, 10, 'A'),(2, 20, 'B'),(3, 30, 'C'),(4, 40, 'D');
+    """
+    // Single window function with a <= filter: expect a PartitionTopN using that bound.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn from push_down_multi_predicate_through_window_t) t where rn 
<= 1;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1, c2 order by 
c3) as rk from push_down_multi_predicate_through_window_t) t where rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    // A > filter alone is not expected to produce a PartitionTopN.
+    explain {
+        sql ("select * from (select rank() over(partition by c1, c2 order by 
c3) as rk from push_down_multi_predicate_through_window_t) t where rk > 1;")
+        notContains "VPartitionTopN"
+    }
+
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn from push_down_multi_predicate_through_window_t) t where rn 
> 1;")
+        notContains "VPartitionTopN"
+    }
+
+    // row_number and rank both filtered: the expected PartitionTopN uses row_number with
+    // row_number's bound, even when rank's bound is smaller or rank comes first in the list.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, rank() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t where rn <= 1 and rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, row_number() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn <= 10 and rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 10"
+    }
+
+    // Two window functions, only one filtered: the PartitionTopN uses the filtered one.
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, row_number() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, row_number() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn <= 10;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 10"
+    }
+
+    // Two rank functions both filtered with <=: the expected limit is the smaller bound.
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, rank() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn <= 1 and rk <= 10;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, rank() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn <= 10 and rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    // A > filter on one rank still allows a PartitionTopN from the other rank's <= filter.
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, rank() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn > 1 and rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1 order by c3) 
as rk, rank() over(partition by c1, c2 order by c3) as rn from 
push_down_multi_predicate_through_window_t) t where rn <= 1 and rk > 1;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 1"
+    }
+
+    // No filter, only an outer LIMIT: expect a PartitionTopN whose limit equals the LIMIT.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, rank() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t limit 10;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 10"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1, c2 order by 
c3) as rn, rank() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t limit 10;")
+        contains "VPartitionTopN"
+        contains "functions: rank"
+        contains "partition limit: 10"
+    }
+
+    // Two row_number functions both filtered with <=: the expected limit is the smaller bound.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, row_number() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t where rn <= 10 and rk <= 1;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, row_number() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t where rn <= 1 and rk <= 10;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    // Three filtered window functions: the expected PartitionTopN uses row_number with the
+    // smallest row_number bound.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, rank() over(partition by c1 order by c3) as rk1, rank() 
over(partition by c2 order by c3) as rk2 from 
push_down_multi_predicate_through_window_t) t where rn <= 1 and rk1 <= 10 and 
rk2 <= 100;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select row_number() over(partition by c1 order by 
c3) as rn1, row_number() over(partition by c2 order by c3) as rn2, rank() 
over(partition by c1, c2 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t where rn1 <= 10 and rn2 <= 1 and 
rk <= 100;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    explain {
+        sql ("select * from (select rank() over(partition by c1, c2 order by 
c3) as rk, row_number() over(partition by c1 order by c3) as rn1, row_number() 
over(partition by c2 order by c3) as rn2 from 
push_down_multi_predicate_through_window_t) t where rn1 <= 1 and rn2 <= 10 and 
rk <= 100;")
+        contains "VPartitionTopN"
+        contains "functions: row_number"
+        contains "partition limit: 1"
+    }
+
+    // Bounds combined with OR are not expected to produce a PartitionTopN.
+    explain {
+        sql ("select * from (select row_number() over(partition by c1, c2 
order by c3) as rn, rank() over(partition by c1 order by c3) as rk from 
push_down_multi_predicate_through_window_t) t where rn <= 1 or rk <= 1;")
+        notContains "VPartitionTopN"
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to