This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0865f74aa14 [nereids](topn-filter) support multi-topn filter (FE part) 
(#31485)
0865f74aa14 is described below

commit 0865f74aa14ea8eba9f64f2b394c72ac2579afb4
Author: minghong <engle...@gmail.com>
AuthorDate: Mon Mar 11 10:43:30 2024 +0800

    [nereids](topn-filter) support multi-topn filter (FE part) (#31485)
    
    support multi-topn-filter
---
 .../org/apache/doris/nereids/CascadesContext.java  |   6 ++
 .../glue/translator/PhysicalPlanTranslator.java    |  39 +++++--
 .../glue/translator/PlanTranslatorContext.java     |   9 +-
 .../doris/nereids/processor/post/TopNScanOpt.java  | 117 ++++++++------------
 .../nereids/processor/post/TopnFilterContext.java  |  93 ++++++++++++++++
 .../nereids/trees/plans/physical/PhysicalTopN.java |  39 +++----
 .../org/apache/doris/planner/OlapScanNode.java     |  25 ++++-
 .../java/org/apache/doris/planner/SortNode.java    |   2 +-
 .../main/java/org/apache/doris/qe/Coordinator.java |  11 +-
 .../nereids/postprocess/TopNRuntimeFilterTest.java |  24 ++++-
 .../data/nereids_tpch_p0/tpch/topn-filter.out      |  29 +++++
 .../suites/nereids_tpch_p0/tpch/topn-filter.groovy | 120 +++++++++++++++++++++
 regression-test/suites/point_query_p0/load.groovy  |   4 +-
 13 files changed, 400 insertions(+), 118 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
index 6e5b5639966..8e4a47938e4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java
@@ -42,6 +42,7 @@ import 
org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
 import org.apache.doris.nereids.memo.Group;
 import org.apache.doris.nereids.memo.Memo;
 import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
+import org.apache.doris.nereids.processor.post.TopnFilterContext;
 import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.rules.RuleFactory;
 import org.apache.doris.nereids.rules.RuleSet;
@@ -110,6 +111,7 @@ public class CascadesContext implements ScheduleContext {
     // subqueryExprIsAnalyzed: whether the subquery has been analyzed.
     private final Map<SubqueryExpr, Boolean> subqueryExprIsAnalyzed;
     private final RuntimeFilterContext runtimeFilterContext;
+    private final TopnFilterContext topnFilterContext = new 
TopnFilterContext();
     private Optional<Scope> outerScope = Optional.empty();
     private Map<Long, TableIf> tables = null;
 
@@ -283,6 +285,10 @@ public class CascadesContext implements ScheduleContext {
         return runtimeFilterContext;
     }
 
+    public TopnFilterContext getTopnFilterContext() {
+        return topnFilterContext;
+    }
+
     public void setCurrentJobContext(JobContext currentJobContext) {
         this.currentJobContext = currentJobContext;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index e366521464c..86db08925db 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -209,6 +209,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.stream.Collectors;
@@ -741,6 +742,10 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 )
         );
         
olapScanNode.setPushDownAggNoGrouping(context.getRelationPushAggOp(olapScan.getRelationId()));
+        if (context.getTopnFilterContext().isTopnFilterTarget(olapScan)) {
+            olapScanNode.setUseTopnOpt(true);
+            context.getTopnFilterContext().addLegacyTarget(olapScan, 
olapScanNode);
+        }
         // TODO: we need to remove all finalizeForNereids
         olapScanNode.finalizeForNereids();
         // Create PlanFragment
@@ -764,6 +769,10 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
             PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan, 
PlanTranslatorContext context) {
         PlanFragment planFragment = 
visitPhysicalOlapScan(deferMaterializeOlapScan.getPhysicalOlapScan(), context);
         OlapScanNode olapScanNode = (OlapScanNode) planFragment.getPlanRoot();
+        if 
(context.getTopnFilterContext().isTopnFilterTarget(deferMaterializeOlapScan)) {
+            olapScanNode.setUseTopnOpt(true);
+            
context.getTopnFilterContext().addLegacyTarget(deferMaterializeOlapScan, 
olapScanNode);
+        }
         TupleDescriptor tupleDescriptor = 
context.getTupleDesc(olapScanNode.getTupleId());
         for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
             if (deferMaterializeOlapScan.getDeferMaterializeSlotIds()
@@ -2026,20 +2035,23 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     public PlanFragment visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, 
PlanTranslatorContext context) {
         PlanFragment inputFragment = topN.child(0).accept(this, context);
         List<List<Expr>> distributeExprLists = 
getDistributeExprs(topN.child(0));
-
         // 2. According to the type of sort, generate physical plan
         if (!topN.getSortPhase().isMerge()) {
             // For localSort or Gather->Sort, we just need to add TopNNode
             SortNode sortNode = translateSortNode(topN, 
inputFragment.getPlanRoot(), context);
             sortNode.setOffset(topN.getOffset());
             sortNode.setLimit(topN.getLimit());
-            if (topN.isEnableRuntimeFilter()) {
+            if (context.getTopnFilterContext().isTopnFilterSource(topN)) {
                 sortNode.setUseTopnOpt(true);
-                PlanNode child = sortNode.getChild(0);
-                Preconditions.checkArgument(child instanceof OlapScanNode,
-                        "topN opt expect OlapScanNode, but we get " + child);
-                OlapScanNode scanNode = ((OlapScanNode) child);
-                scanNode.setUseTopnOpt(true);
+                context.getTopnFilterContext().getTargets(topN).forEach(
+                        olapScan -> {
+                            Optional<OlapScanNode> legacyScan =
+                                    
context.getTopnFilterContext().getLegacyScanNode(olapScan);
+                            Preconditions.checkState(legacyScan.isPresent(),
+                                    "cannot find OlapScanNode for topn 
filter");
+                            legacyScan.get().addTopnFilterSortNode(sortNode);
+                        }
+                );
             }
             // push sort to scan opt
             if (sortNode.getChild(0) instanceof OlapScanNode) {
@@ -2084,12 +2096,23 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
     @Override
     public PlanFragment 
visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan> 
topN,
             PlanTranslatorContext context) {
-
         PlanFragment planFragment = visitPhysicalTopN(topN.getPhysicalTopN(), 
context);
         if (planFragment.getPlanRoot() instanceof SortNode) {
             SortNode sortNode = (SortNode) planFragment.getPlanRoot();
             sortNode.setUseTwoPhaseReadOpt(true);
             sortNode.getSortInfo().setUseTwoPhaseRead();
+            if (context.getTopnFilterContext().isTopnFilterSource(topN)) {
+                sortNode.setUseTopnOpt(true);
+                context.getTopnFilterContext().getTargets(topN).forEach(
+                        olapScan -> {
+                            Optional<OlapScanNode> legacyScan =
+                                    
context.getTopnFilterContext().getLegacyScanNode(olapScan);
+                            Preconditions.checkState(legacyScan.isPresent(),
+                                    "cannot find OlapScanNode for topn 
filter");
+                            legacyScan.get().addTopnFilterSortNode(sortNode);
+                        }
+                );
+            }
             TupleDescriptor tupleDescriptor = 
sortNode.getSortInfo().getSortTupleDescriptor();
             for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
                 if (topN.getDeferMaterializeSlotIds()
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java
index 8a723b1fd1f..90539332791 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java
@@ -29,6 +29,7 @@ import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.TableIf;
 import org.apache.doris.common.IdGenerator;
 import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.processor.post.TopnFilterContext;
 import org.apache.doris.nereids.trees.expressions.CTEId;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -70,7 +71,7 @@ public class PlanTranslatorContext {
     private final DescriptorTable descTable = new DescriptorTable();
 
     private final RuntimeFilterTranslator translator;
-
+    private final TopnFilterContext topnFilterContext;
     /**
      * index from Nereids' slot to legacy slot.
      */
@@ -115,12 +116,14 @@ public class PlanTranslatorContext {
     public PlanTranslatorContext(CascadesContext ctx) {
         this.connectContext = ctx.getConnectContext();
         this.translator = new 
RuntimeFilterTranslator(ctx.getRuntimeFilterContext());
+        this.topnFilterContext = ctx.getTopnFilterContext();
     }
 
     @VisibleForTesting
     public PlanTranslatorContext() {
         this.connectContext = null;
         this.translator = null;
+        this.topnFilterContext = new TopnFilterContext();
     }
 
     /**
@@ -187,6 +190,10 @@ public class PlanTranslatorContext {
         return Optional.ofNullable(translator);
     }
 
+    public TopnFilterContext getTopnFilterContext() {
+        return topnFilterContext;
+    }
+
     public PlanFragmentId nextFragmentId() {
         return fragmentIdGenerator.getNextId();
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
index fab4e34a12f..a9425cb715b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
@@ -19,103 +19,57 @@ package org.apache.doris.nereids.processor.post;
 
 import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.SortPhase;
-import org.apache.doris.nereids.trees.plans.algebra.Filter;
 import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
-import org.apache.doris.nereids.trees.plans.algebra.Project;
-import org.apache.doris.nereids.trees.plans.algebra.TopN;
-import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalCatalogRelation;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalSink;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
 import org.apache.doris.qe.ConnectContext;
 
+import java.util.Optional;
 /**
  * topN opt
  * refer to:
  * <a href="https://github.com/apache/doris/pull/15558";>...</a>
  * <a href="https://github.com/apache/doris/pull/15663";>...</a>
  *
- * // only support simple case: select ... from tbl [where ...] order by ... 
limit ...
+ * // [deprecated] only support simple case: select ... from tbl [where ...] 
order by ... limit ...
  */
 
 public class TopNScanOpt extends PlanPostProcessor {
 
-    @Override
-    public Plan visit(Plan plan, CascadesContext context) {
-        return plan;
-    }
-
-    @Override
-    public Plan visitPhysicalSink(PhysicalSink<? extends Plan> physicalSink, 
CascadesContext context) {
-        if (physicalSink.child() instanceof TopN) {
-            return super.visit(physicalSink, context);
-        } else if (physicalSink.child() instanceof Project && 
physicalSink.child().child(0) instanceof TopN) {
-            PhysicalTopN<?> oldTopN = (PhysicalTopN<?>) 
physicalSink.child().child(0);
-            PhysicalTopN<?> newTopN = (PhysicalTopN<?>) oldTopN.accept(this, 
context);
-            if (newTopN == oldTopN) {
-                return physicalSink;
-            } else {
-                return 
physicalSink.withChildren(physicalSink.child().withChildren(newTopN));
-            }
-        }
-        return physicalSink;
-    }
-
-    @Override
-    public Plan visitPhysicalDistribute(PhysicalDistribute<? extends Plan> 
distribute, CascadesContext context) {
-        if (distribute.child() instanceof TopN && distribute.child() 
instanceof AbstractPhysicalSort
-                && ((AbstractPhysicalSort<?>) 
distribute.child()).getSortPhase() == SortPhase.LOCAL_SORT) {
-            return super.visit(distribute, context);
-        }
-        return distribute;
-    }
-
     @Override
     public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<? 
extends Plan> topN, CascadesContext ctx) {
-        if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
-            Plan child = topN.child();
-            topN = rewriteTopN(topN);
-            if (child != topN.child()) {
-                topN = ((PhysicalTopN<? extends Plan>) 
topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
-            }
-            return topN;
-        } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
-            return (PhysicalTopN<? extends Plan>) super.visit(topN, ctx);
-        }
+        Optional<OlapScan> scanOpt = findScanForTopnFilter(topN);
+        scanOpt.ifPresent(scan -> 
ctx.getTopnFilterContext().addTopnFilter(topN, scan));
+        topN.child().accept(this, ctx);
         return topN;
     }
 
     @Override
     public Plan 
visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan> 
topN,
             CascadesContext context) {
-        if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
-            PhysicalTopN<? extends Plan> rewrittenTopN = 
rewriteTopN(topN.getPhysicalTopN());
-            if (topN.getPhysicalTopN() != rewrittenTopN) {
-                topN = 
topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
-            }
-            return topN;
-        } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
-            return super.visit(topN, context);
-        }
+        Optional<OlapScan> scanOpt = 
findScanForTopnFilter(topN.getPhysicalTopN());
+        scanOpt.ifPresent(scan -> 
context.getTopnFilterContext().addTopnFilter(topN, scan));
+        topN.child().accept(this, context);
         return topN;
     }
 
-    private PhysicalTopN<? extends Plan> rewriteTopN(PhysicalTopN<? extends 
Plan> topN) {
-        Plan child = topN.child();
+    private Optional<OlapScan> findScanForTopnFilter(PhysicalTopN<? extends 
Plan> topN) {
         if (topN.getSortPhase() != SortPhase.LOCAL_SORT) {
-            return topN;
+            return Optional.empty();
         }
         if (topN.getOrderKeys().isEmpty()) {
-            return topN;
+            return Optional.empty();
         }
 
         // topn opt
         long topNOptLimitThreshold = getTopNOptLimitThreshold();
         if (topNOptLimitThreshold == -1 || topN.getLimit() > 
topNOptLimitThreshold) {
-            return topN;
+            return Optional.empty();
         }
         // if firstKey's column is not present, it means the firstKey is not 
an original column from scan node
         // for example: "select cast(k1 as INT) as id from tbl1 order by id 
limit 2;" the firstKey "id" is
@@ -125,27 +79,44 @@ public class TopNScanOpt extends PlanPostProcessor {
         // see Alias::toSlot() method to get how column info is passed around 
by alias of slotReference
         Expression firstKey = topN.getOrderKeys().get(0).getExpr();
         if (!firstKey.isColumnFromTable()) {
-            return topN;
+            return Optional.empty();
         }
         if (firstKey.getDataType().isFloatType()
                 || firstKey.getDataType().isDoubleType()) {
-            return topN;
+            return Optional.empty();
         }
 
-        OlapScan olapScan;
-        while (child instanceof Project || child instanceof Filter) {
-            child = child.child(0);
+        if (! (firstKey instanceof SlotReference)) {
+            return Optional.empty();
         }
-        if (!(child instanceof OlapScan)) {
-            return topN;
+        OlapScan olapScan = findScanNodeBySlotReference(topN, (SlotReference) 
firstKey);
+        if (olapScan != null
+                && olapScan.getTable().isDupKeysOrMergeOnWrite()
+                && olapScan instanceof PhysicalCatalogRelation) {
+            return Optional.of(olapScan);
         }
-        olapScan = (OlapScan) child;
 
-        if (olapScan.getTable().isDupKeysOrMergeOnWrite()) {
-            return 
topN.withEnableRuntimeFilter(true).copyStatsAndGroupIdFrom(topN);
-        }
+        return Optional.empty();
+    }
 
-        return topN;
+    private OlapScan findScanNodeBySlotReference(Plan root, SlotReference 
slot) {
+        OlapScan target = null;
+        if (root instanceof OlapScan && root.getOutputSet().contains(slot)) {
+            return (OlapScan) root;
+        } else {
+            if (! root.children().isEmpty()) {
+                // for join and intersect, push topn-filter to their left 
child.
+                // TODO for union, topn-filter can be pushed down to all of 
its children.
+                Plan child = root.child(0);
+                if (!(child instanceof PhysicalWindow) && 
child.getOutputSet().contains(slot)) {
+                    target = findScanNodeBySlotReference(child, slot);
+                    if (target != null) {
+                        return target;
+                    }
+                }
+            }
+        }
+        return target;
     }
 
     private long getTopNOptLimitThreshold() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
new file mode 100644
index 00000000000..b5f79defef4
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopnFilterContext.java
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
+import org.apache.doris.planner.OlapScanNode;
+import org.apache.doris.planner.SortNode;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Bookkeeping for topn runtime filters: source TopN plans, target scans and their legacy plan nodes.
+ */
+public class TopnFilterContext {
+    private final Map<TopN, List<OlapScan>> filters = Maps.newHashMap();
+    private final Set<TopN> sources = Sets.newHashSet();
+    private final Set<OlapScan> targets = Sets.newHashSet();
+    private final Map<OlapScan, OlapScanNode> legacyTargetsMap = 
Maps.newHashMap();
+    private final Map<TopN, SortNode> legacySourceMap = Maps.newHashMap();
+
+    /**
+     * register scan as a target of the topn filter whose source is topn
+     */
+    public void addTopnFilter(TopN topn, OlapScan scan) {
+        targets.add(scan);
+        sources.add(topn);
+
+        List<OlapScan> scans = filters.get(topn);
+        if (scans == null) {
+            filters.put(topn, Lists.newArrayList(scan));
+        } else {
+            scans.add(scan);
+        }
+    }
+
+    /**
+     * find the legacy OlapScanNode added by addLegacyTarget for scan, empty if there is none
+     */
+    public Optional<OlapScanNode> getLegacyScanNode(OlapScan scan) {
+        return Optional.ofNullable(legacyTargetsMap.get(scan));
+    }
+
+    /**
+     * find the legacy SortNode added by addLegacySource for topn, empty if there is none
+     */
+    public Optional<SortNode> getLegacySortNode(TopN topn) {
+        return Optional.ofNullable(legacySourceMap.get(topn));
+    }
+
+    public boolean isTopnFilterSource(TopN topn) {
+        return sources.contains(topn);
+    }
+
+    public boolean isTopnFilterTarget(OlapScan scan) {
+        return targets.contains(scan);
+    }
+
+    public void addLegacySource(TopN topn, SortNode sort) {
+        legacySourceMap.put(topn, sort);
+    }
+
+    public void addLegacyTarget(OlapScan olapScan, OlapScanNode legacy) {
+        legacyTargetsMap.put(olapScan, legacy);
+    }
+
+    /** targets added for topn by addTopnFilter; null when topn is not a topn filter source */
+    public List<OlapScan> getTargets(TopN topn) {
+        return filters.get(topn);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTopN.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTopN.java
index 6989284e0b9..96dc709bbde 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTopN.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalTopN.java
@@ -43,36 +43,33 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> extends 
AbstractPhysicalSort<
 
     private final long limit;
     private final long offset;
-    private final boolean enableRuntimeFilter;
 
     public PhysicalTopN(List<OrderKey> orderKeys, long limit, long offset,
             SortPhase phase, LogicalProperties logicalProperties, CHILD_TYPE 
child) {
-        this(orderKeys, limit, offset, phase, false, Optional.empty(), 
logicalProperties, child);
+        this(orderKeys, limit, offset, phase, Optional.empty(), 
logicalProperties, child);
     }
 
     /**
      * Constructor of PhysicalHashJoinNode.
      */
     public PhysicalTopN(List<OrderKey> orderKeys, long limit, long offset,
-            SortPhase phase, boolean enableRuntimeFilter,
+            SortPhase phase,
             Optional<GroupExpression> groupExpression, LogicalProperties 
logicalProperties, CHILD_TYPE child) {
-        this(orderKeys, limit, offset, phase, enableRuntimeFilter,
-                groupExpression, logicalProperties, null, null, child);
+        this(orderKeys, limit, offset, phase, groupExpression,
+                logicalProperties, null, null, child);
     }
 
     /**
      * Constructor of PhysicalHashJoinNode.
      */
     public PhysicalTopN(List<OrderKey> orderKeys, long limit, long offset,
-            SortPhase phase, boolean enableRuntimeFilter,
-            Optional<GroupExpression> groupExpression, LogicalProperties 
logicalProperties,
+            SortPhase phase, Optional<GroupExpression> groupExpression, 
LogicalProperties logicalProperties,
             PhysicalProperties physicalProperties, Statistics statistics, 
CHILD_TYPE child) {
         super(PlanType.PHYSICAL_TOP_N, orderKeys, phase, groupExpression, 
logicalProperties, physicalProperties,
                 statistics, child);
         Objects.requireNonNull(orderKeys, "orderKeys should not be null in 
PhysicalTopN.");
         this.limit = limit;
         this.offset = offset;
-        this.enableRuntimeFilter = enableRuntimeFilter;
     }
 
     public long getLimit() {
@@ -83,10 +80,6 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> extends 
AbstractPhysicalSort<
         return offset;
     }
 
-    public boolean isEnableRuntimeFilter() {
-        return enableRuntimeFilter;
-    }
-
     @Override
     public boolean equals(Object o) {
         if (this == o) {
@@ -99,12 +92,12 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> extends 
AbstractPhysicalSort<
             return false;
         }
         PhysicalTopN<?> that = (PhysicalTopN<?>) o;
-        return limit == that.limit && offset == that.offset && 
enableRuntimeFilter == that.enableRuntimeFilter;
+        return limit == that.limit && offset == that.offset;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), limit, offset, 
enableRuntimeFilter);
+        return Objects.hash(super.hashCode(), limit, offset);
     }
 
     @Override
@@ -112,22 +105,17 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> 
extends AbstractPhysicalSort<
         return visitor.visitPhysicalTopN(this, context);
     }
 
-    public PhysicalTopN<Plan> withEnableRuntimeFilter(boolean 
enableRuntimeFilter) {
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter,
-                groupExpression, getLogicalProperties(), child());
-    }
-
     @Override
     public PhysicalTopN<Plan> withChildren(List<Plan> children) {
         Preconditions.checkArgument(children.size() == 1,
                 "PhysicalTopN's children size must be 1, but real is %s", 
children.size());
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter, groupExpression,
+        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
groupExpression,
                 getLogicalProperties(), physicalProperties, statistics, 
children.get(0));
     }
 
     @Override
     public PhysicalTopN<CHILD_TYPE> 
withGroupExpression(Optional<GroupExpression> groupExpression) {
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter,
+        return new PhysicalTopN<>(orderKeys, limit, offset, phase,
                 groupExpression, getLogicalProperties(), child());
     }
 
@@ -136,14 +124,14 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> 
extends AbstractPhysicalSort<
             Optional<LogicalProperties> logicalProperties, List<Plan> 
children) {
         Preconditions.checkArgument(children.size() == 1,
                 "PhysicalTopN's children size must be 1, but real is %s", 
children.size());
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter,
+        return new PhysicalTopN<>(orderKeys, limit, offset, phase,
                 groupExpression, logicalProperties.get(), children.get(0));
     }
 
     @Override
     public PhysicalTopN<CHILD_TYPE> 
withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties,
             Statistics statistics) {
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter,
+        return new PhysicalTopN<>(orderKeys, limit, offset, phase,
                 groupExpression, getLogicalProperties(), physicalProperties, 
statistics, child());
     }
 
@@ -158,8 +146,7 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> extends 
AbstractPhysicalSort<
                 "limit", limit,
                 "offset", offset,
                 "orderKeys", orderKeys,
-                "phase", phase.toString(),
-                "enableRuntimeFilter", enableRuntimeFilter
+                "phase", phase.toString()
         );
     }
 
@@ -170,7 +157,7 @@ public class PhysicalTopN<CHILD_TYPE extends Plan> extends 
AbstractPhysicalSort<
 
     @Override
     public PhysicalTopN<Plan> resetLogicalProperties() {
-        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
enableRuntimeFilter, groupExpression,
+        return new PhysicalTopN<>(orderKeys, limit, offset, phase, 
groupExpression,
                 null, physicalProperties, statistics, child());
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
index 11c402709f3..29243a4abeb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
@@ -174,8 +174,11 @@ public class OlapScanNode extends ScanNode {
     // It's limit for scanner instead of scanNode so we add a new limit.
     private long sortLimit = -1;
 
+    // useTopnOpt is equivalent to !topnFilterSortNodes.isEmpty().
+    // keep this flag for compatibility.
     private boolean useTopnOpt = false;
-
+    // support multi topn filter
+    private final List<SortNode> topnFilterSortNodes = Lists.newArrayList();
 
     // List of tablets will be scanned by current olap_scan_node
     private ArrayList<Long> scanTabletIds = Lists.newArrayList();
@@ -1339,7 +1342,10 @@ public class OlapScanNode extends ScanNode {
             output.append(prefix).append("SORT LIMIT: 
").append(sortLimit).append("\n");
         }
         if (useTopnOpt) {
-            output.append(prefix).append("TOPN OPT\n");
+            String topnFilterSources = String.join(",",
+                    topnFilterSortNodes.stream()
+                            .map(node -> node.getId().asInt() + 
"").collect(Collectors.toList()));
+            output.append(prefix).append("TOPN 
OPT:").append(topnFilterSources).append("\n");
         }
 
         if (!conjuncts.isEmpty()) {
@@ -1513,6 +1519,13 @@ public class OlapScanNode extends ScanNode {
             msg.olap_scan_node.setSortLimit(sortLimit);
         }
         msg.olap_scan_node.setUseTopnOpt(useTopnOpt);
+        List<Integer> topnFilterSourceNodeIds = getTopnFilterSortNodes()
+                .stream()
+                .map(sortNode -> sortNode.getId().asInt())
+                .collect(Collectors.toList());
+        if (!topnFilterSourceNodeIds.isEmpty()) {
+            
msg.olap_scan_node.setTopnFilterSourceNodeIds(topnFilterSourceNodeIds);
+        }
         msg.olap_scan_node.setKeyType(olapTable.getKeysType().toThrift());
         msg.olap_scan_node.setTableName(olapTable.getName());
         
msg.olap_scan_node.setEnableUniqueKeyMergeOnWrite(olapTable.getEnableUniqueKeyMergeOnWrite());
@@ -1785,4 +1798,12 @@ public class OlapScanNode extends ScanNode {
     public int getScanRangeNum() {
         return getScanTabletIds().size();
     }
+
+    public void addTopnFilterSortNode(SortNode sortNode) {
+        topnFilterSortNodes.add(sortNode);
+    }
+
+    public List<SortNode> getTopnFilterSortNodes() {
+        return topnFilterSortNodes;
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
index 33a04c5dfa1..24b384d4453 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/SortNode.java
@@ -61,7 +61,7 @@ public class SortNode extends PlanNode {
     List<Expr> resolvedTupleExprs;
     private final SortInfo info;
     private final boolean  useTopN;
-    private boolean useTopnOpt;
+    private boolean useTopnOpt = false;
     private boolean useTwoPhaseReadOpt;
 
     // If mergeByexchange is set to true, the sort information is pushed to the
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
index 42a1455460f..2671228fc3c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java
@@ -3708,7 +3708,10 @@ public class Coordinator implements CoordInterface {
                 int rate = 
Math.min(Config.query_colocate_join_memory_limit_penalty_factor, 
instanceExecParams.size());
                 memLimit = queryOptions.getMemLimit() / rate;
             }
-
+            Set<Integer> topnFilterSources = scanNodes.stream()
+                    .filter(scanNode -> scanNode instanceof OlapScanNode)
+                    .flatMap(scanNode -> ((OlapScanNode) 
scanNode).getTopnFilterSortNodes().stream())
+                    .map(sort -> 
sort.getId().asInt()).collect(Collectors.toSet());
             Map<TNetworkAddress, TPipelineFragmentParams> res = new HashMap();
             Map<TNetworkAddress, Integer> instanceIdx = new HashMap();
             TPlanFragment fragmentThrift = fragment.toThrift();
@@ -3777,6 +3780,12 @@ public class Coordinator implements CoordInterface {
                 localParams.setBackendNum(backendNum++);
                 localParams.setRuntimeFilterParams(new TRuntimeFilterParams());
                 
localParams.runtime_filter_params.setRuntimeFilterMergeAddr(runtimeFilterMergeAddr);
+                if (!topnFilterSources.isEmpty()) {
+                    // topn_filter_source_node_ids is used only by Nereids, not by the legacy planner.
+                    // If there are no topn filter sources, leave it unset:
+                    // a null topn_filter_source_node_ids means the legacy planner.
+                    localParams.topn_filter_source_node_ids = 
Lists.newArrayList(topnFilterSources);
+                }
                 if 
(instanceExecParam.instanceId.equals(runtimeFilterMergeInstanceId)) {
                     Set<Integer> broadCastRf = 
assignedRuntimeFilters.stream().filter(RuntimeFilter::isBroadcast)
                             .map(r -> 
r.getFilterId().asInt()).collect(Collectors.toSet());
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
index 95a407d1342..57c6c045141 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
@@ -24,12 +24,13 @@ import org.apache.doris.nereids.trees.plans.SortPhase;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
 
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-public class TopNRuntimeFilterTest extends SSBTestBase {
+public class TopNRuntimeFilterTest extends SSBTestBase implements 
MemoPatternMatchSupported {
     @Override
     public void runBeforeAll() throws Exception {
         super.runBeforeAll();
@@ -46,11 +47,11 @@ public class TopNRuntimeFilterTest extends SSBTestBase {
         Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, 
plan.children().get(0).child(0));
         PhysicalDeferMaterializeTopN<? extends Plan> localTopN
                 = (PhysicalDeferMaterializeTopN<? extends Plan>) 
plan.child(0).child(0);
-        
Assertions.assertTrue(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
+        
Assertions.assertTrue(checker.getCascadesContext().getTopnFilterContext().isTopnFilterSource(localTopN));
     }
 
     @Test
-    public void testNotUseTopNRfForComplexCase() {
+    public void testUseTopNRfForComplexCase() {
         String sql = "select * from (select 1) tl join (select * from customer 
order by c_custkey limit 5) tb";
         PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
                 .rewrite()
@@ -62,6 +63,21 @@ public class TopNRuntimeFilterTest extends SSBTestBase {
                 .child(0).child(0).child(1).child(0)).getSortPhase());
         PhysicalTopN<? extends Plan> localTopN = (PhysicalTopN<? extends 
Plan>) plan
                 .child(0).child(0).child(1).child(0);
-        Assertions.assertFalse(localTopN.isEnableRuntimeFilter());
+        
Assertions.assertTrue(checker.getCascadesContext().getTopnFilterContext().isTopnFilterSource(localTopN));
+    }
+
+    @Test
+    public void testNotUseTopNRfOnWindow() {
+        String sql = "select rank() over (partition by c_nation order by 
c_custkey) "
+                + "from customer order by c_custkey limit 3";
+        PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
+                .rewrite().implement();
+        PhysicalPlan plan = checker.getPhysicalPlan();
+        plan = new 
PlanPostProcessors(checker.getCascadesContext()).process(plan);
+        System.out.println(plan.treeString());
+        PhysicalTopN<? extends Plan> localTopN =
+                (PhysicalTopN<? extends Plan>) plan.child(0).child(0).child(0);
+        Assertions.assertTrue(localTopN.getSortPhase().isLocal());
+        
Assertions.assertFalse(checker.getCascadesContext().getTopnFilterContext().isTopnFilterSource(localTopN));
     }
 }
diff --git a/regression-test/data/nereids_tpch_p0/tpch/topn-filter.out 
b/regression-test/data/nereids_tpch_p0/tpch/topn-filter.out
new file mode 100644
index 00000000000..be88d829f25
--- /dev/null
+++ b/regression-test/data/nereids_tpch_p0/tpch/topn-filter.out
@@ -0,0 +1,29 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !simpleTopn --
+1      3691    O       194029.55       1996-01-02      5-LOW   Clerk#000000951 
0       nstructions sleep furiously among 
+2      7801    O       60951.63        1996-12-01      1-URGENT        
Clerk#000000880 0        foxes. pending accounts at the pending, silent asymptot
+3      12332   F       247296.05       1993-10-14      5-LOW   Clerk#000000955 
0       sly final accounts boost. carefully regular ideas cajole carefully. 
depos
+4      13678   O       53829.87        1995-10-11      5-LOW   Clerk#000000124 
0       sits. slyly regular warthogs cajole. regular, regular theodolites acro
+5      4450    F       139660.54       1994-07-30      5-LOW   Clerk#000000925 
0       quickly. bold deposits sleep slyly. packages use slyly
+6      5563    F       65843.52        1992-02-21      4-NOT SPECIFIED 
Clerk#000000058 0       ggle. special, final requests are against the furiously 
specia
+7      3914    O       231037.28       1996-01-10      2-HIGH  Clerk#000000470 
0       ly special requests 
+32     13006   O       166802.63       1995-07-16      2-HIGH  Clerk#000000616 
0       ise blithely bold, regular requests. quickly unusual dep
+33     6697    F       118518.56       1993-10-27      3-MEDIUM        
Clerk#000000409 0       uriously. furiously final request
+34     6101    O       75662.77        1998-07-21      3-MEDIUM        
Clerk#000000223 0       ly final packages. fluffily final deposits wake 
blithely ideas. spe
+
+-- !complexTopn --
+1      3691    10
+2      7801    1
+
+-- !check_result --
+67     5662    0
+102    73      0
+
+-- !check_result2 --
+33     6697    24
+551    8962    24
+
+-- !groupingsets --
+0      50
+1      47
+
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy 
b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
new file mode 100644
index 00000000000..23e742fba15
--- /dev/null
+++ b/regression-test/suites/nereids_tpch_p0/tpch/topn-filter.groovy
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("topn-filter") {
+    String db = context.config.getDbNameByFile(new File(context.file.parent))
+    sql "use ${db}"
+    sql 'set enable_nereids_planner=true'
+    sql 'set enable_fallback_to_original_planner=false'
+    sql 'set disable_join_reorder=true;'
+    sql 'set topn_opt_limit_threshold=1024'
+    def String simpleTopn = """
+        select *
+        from orders
+        order by o_orderkey
+        limit 10;"""
+    
+    explain {
+        sql "${simpleTopn}"
+        contains "TOPN OPT:1"
+    }
+
+    qt_simpleTopn "${simpleTopn}"
+    
+    def String complexTopn = """
+        select o_orderkey, c_custkey, n_nationkey
+        from orders 
+        join[broadcast] customer on o_custkey = c_custkey 
+        join[broadcast] nation on c_nationkey=n_nationkey
+        order by o_orderkey limit 2; 
+        """
+    explain{
+        sql "${complexTopn}"
+        contains "TOPN OPT:7"
+    }
+    qt_complexTopn "${complexTopn}"
+
+    def multi_topn_asc = """
+    select o_orderkey, c_custkey, n_nationkey
+    from orders 
+    join[broadcast] customer on o_custkey = c_custkey 
+    join[broadcast] ( select * from nation order by n_nationkey asc limit 1) 
as n on c_nationkey=n_nationkey
+    order by o_orderkey limit 2; 
+    """
+    qt_check_result "${multi_topn_asc}"
+    explain{
+        sql "${multi_topn_asc}"
+        contains "TOPN OPT:9"
+        contains "TOPN OPT:1"
+    }
+
+    def multi_topn_desc = """
+    select o_orderkey, c_custkey, n_nationkey
+    from orders 
+    join[broadcast] customer on o_custkey = c_custkey 
+    join[broadcast] (select * from nation order by n_nationkey desc limit 1) 
as n on c_nationkey=n_nationkey
+    order by o_orderkey limit 2; 
+    """
+    explain {
+        sql "${multi_topn_desc}"
+        contains "TOPN OPT:9"
+        contains "TOPN OPT:1"
+    }
+
+    qt_check_result2 "${multi_topn_desc}"
+
+    // do not use topn-filter
+    explain {
+        sql """
+                select o_orderkey, c_custkey
+                from orders 
+                join[broadcast] customer on o_custkey = c_custkey 
+                order by c_custkey limit 2; 
+            """
+        notContains "TOPN OPT:"
+    }
+
+    // push topn filter down through AGG
+    explain {
+        sql """
+            select s_nationkey, count(1) from supplier group by s_nationkey 
order by s_nationkey limit 1;
+        """
+        contains "TOPN OPT:"
+    }
+
+    // push topn filter down through AGG + Join
+    explain {
+        sql """
+            select * 
+            from 
+             (select s_nationkey, count(1) as total from supplier group by 
s_nationkey having total > 10 ) T
+            join nation on s_nationkey = n_nationkey 
+            order by s_nationkey limit 1;
+        """
+        contains "TOPN OPT:"
+    }
+
+    explain {
+        sql "select n_regionkey, sum(n_nationkey) from nation group by 
grouping sets((n_regionkey)) order by n_regionkey limit 2;"
+        contains "TOPN OPT"
+    }
+
+    qt_groupingsets "select n_regionkey, sum(n_nationkey) from nation group by 
grouping sets((n_regionkey)) order by n_regionkey limit 2;"
+
+}
\ No newline at end of file
diff --git a/regression-test/suites/point_query_p0/load.groovy 
b/regression-test/suites/point_query_p0/load.groovy
index d5cf8074540..772a7130363 100644
--- a/regression-test/suites/point_query_p0/load.groovy
+++ b/regression-test/suites/point_query_p0/load.groovy
@@ -18,8 +18,7 @@
 import org.codehaus.groovy.runtime.IOGroovyMethods
 
 suite("test_point_query_load", "p0") {
-
-    // nereids do not support point query now
+    // test legacy planner
     sql """set enable_nereids_planner=false"""
 
     def dataFile = 
"""${getS3Url()}/regression/datatypes/test_scalar_types_10w.csv"""
@@ -101,6 +100,7 @@ suite("test_point_query_load", "p0") {
      }
     sql "INSERT INTO ${testTable} SELECT * from ${testTable}"
 
+    // test nereids planner
     sql """set enable_nereids_planner=true;"""
     explain {
         sql("""SELECT


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to