This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 8bc3c3641d9 [fix](Nereids) topn runtime filter only support simplest case (#29312) (#29455)
8bc3c3641d9 is described below

commit 8bc3c3641d90b3b5cfbc4267e47b951dfff07148
Author: morrySnow <[email protected]>
AuthorDate: Wed Jan 3 23:19:43 2024 +0800

    [fix](Nereids) topn runtime filter only support simplest case (#29312) (#29455)
---
 .../doris/nereids/processor/post/TopNScanOpt.java  | 58 +++++++++++++++++-----
 .../nereids/postprocess/TopNRuntimeFilterTest.java | 22 +++++++-
 2 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
index 1a0dddfe5c2..31a2a7aff56 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/TopNScanOpt.java
@@ -24,27 +24,58 @@ import org.apache.doris.nereids.trees.plans.SortPhase;
 import org.apache.doris.nereids.trees.plans.algebra.Filter;
 import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
 import org.apache.doris.nereids.trees.plans.algebra.Project;
+import org.apache.doris.nereids.trees.plans.algebra.TopN;
+import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalSink;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
 import org.apache.doris.qe.ConnectContext;
 
-import com.google.common.collect.ImmutableList;
-
 /**
  * topN opt
  * refer to:
  * <a href="https://github.com/apache/doris/pull/15558";>...</a>
  * <a href="https://github.com/apache/doris/pull/15663";>...</a>
+ *
+ * // only support simple case: select ... from tbl [where ...] order by ... 
limit ...
  */
 
 public class TopNScanOpt extends PlanPostProcessor {
 
+    @Override
+    public Plan visit(Plan plan, CascadesContext context) {
+        return plan;
+    }
+
+    @Override
+    public Plan visitPhysicalSink(PhysicalSink<? extends Plan> physicalSink, 
CascadesContext context) {
+        if (physicalSink.child() instanceof TopN) {
+            return super.visit(physicalSink, context);
+        }
+        return physicalSink;
+    }
+
+    @Override
+    public Plan visitPhysicalDistribute(PhysicalDistribute<? extends Plan> 
distribute, CascadesContext context) {
+        if (distribute.child() instanceof TopN && distribute.child() 
instanceof AbstractPhysicalSort
+                && ((AbstractPhysicalSort<?>) 
distribute.child()).getSortPhase() == SortPhase.LOCAL_SORT) {
+            return super.visit(distribute, context);
+        }
+        return distribute;
+    }
+
     @Override
     public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<? 
extends Plan> topN, CascadesContext ctx) {
-        Plan child = topN.child().accept(this, ctx);
-        topN = rewriteTopN(topN);
-        if (child != topN.child()) {
-            topN = ((PhysicalTopN) 
topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
+        if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
+            Plan child = topN.child();
+            topN = rewriteTopN(topN);
+            if (child != topN.child()) {
+                topN = ((PhysicalTopN<? extends Plan>) 
topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
+            }
+            return topN;
+        } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
+            return (PhysicalTopN<? extends Plan>) super.visit(topN, ctx);
         }
         return topN;
     }
@@ -52,13 +83,14 @@ public class TopNScanOpt extends PlanPostProcessor {
     @Override
     public Plan 
visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? extends Plan> 
topN,
             CascadesContext context) {
-        Plan child = topN.child().accept(this, context);
-        if (child != topN.child()) {
-            topN = 
topN.withChildren(ImmutableList.of(child)).copyStatsAndGroupIdFrom(topN);
-        }
-        PhysicalTopN<? extends Plan> rewrittenTopN = 
rewriteTopN(topN.getPhysicalTopN());
-        if (topN.getPhysicalTopN() != rewrittenTopN) {
-            topN = 
topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
+        if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
+            PhysicalTopN<? extends Plan> rewrittenTopN = 
rewriteTopN(topN.getPhysicalTopN());
+            if (topN.getPhysicalTopN() != rewrittenTopN) {
+                topN = 
topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
+            }
+            return topN;
+        } else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
+            return super.visit(topN, context);
         }
         return topN;
     }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
index f4fdf6f44f0..0ac233898ad 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/TopNRuntimeFilterTest.java
@@ -20,8 +20,10 @@ package org.apache.doris.nereids.postprocess;
 import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
 import org.apache.doris.nereids.processor.post.PlanPostProcessors;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.SortPhase;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
 import org.apache.doris.nereids.util.PlanChecker;
 
 import org.junit.jupiter.api.Assertions;
@@ -41,7 +43,7 @@ public class TopNRuntimeFilterTest extends SSBTestBase {
                 .implement();
         PhysicalPlan plan = checker.getPhysicalPlan();
         plan = new 
PlanPostProcessors(checker.getCascadesContext()).process(plan);
-        Assertions.assertTrue(plan.children().get(0).child(0) instanceof 
PhysicalDeferMaterializeTopN);
+        Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, 
plan.children().get(0).child(0));
         PhysicalDeferMaterializeTopN<? extends Plan> localTopN
                 = (PhysicalDeferMaterializeTopN<? extends Plan>) 
plan.child(0).child(0);
         
Assertions.assertTrue(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
@@ -56,9 +58,25 @@ public class TopNRuntimeFilterTest extends SSBTestBase {
                 .implement();
         PhysicalPlan plan = checker.getPhysicalPlan();
         plan = new 
PlanPostProcessors(checker.getCascadesContext()).process(plan);
-        Assertions.assertTrue(plan.children().get(0).child(0) instanceof 
PhysicalDeferMaterializeTopN);
+        Assertions.assertInstanceOf(PhysicalDeferMaterializeTopN.class, 
plan.children().get(0).child(0));
         PhysicalDeferMaterializeTopN<? extends Plan> localTopN
                 = (PhysicalDeferMaterializeTopN<? extends Plan>) 
plan.child(0).child(0);
         
Assertions.assertFalse(localTopN.getPhysicalTopN().isEnableRuntimeFilter());
     }
+
+    @Test
+    public void testNotUseTopNRfForComplexCase() {
+        String sql = "select * from (select 1) tl join (select * from customer 
order by c_custkey limit 5) tb";
+        PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
+                .rewrite()
+                .implement();
+        PhysicalPlan plan = checker.getPhysicalPlan();
+        plan = new 
PlanPostProcessors(checker.getCascadesContext()).process(plan);
+        Assertions.assertInstanceOf(PhysicalTopN.class, 
plan.child(0).child(0).child(1).child(0));
+        Assertions.assertEquals(SortPhase.LOCAL_SORT, ((PhysicalTopN<? extends 
Plan>) plan
+                .child(0).child(0).child(1).child(0)).getSortPhase());
+        PhysicalTopN<? extends Plan> localTopN = (PhysicalTopN<? extends 
Plan>) plan
+                .child(0).child(0).child(1).child(0);
+        Assertions.assertFalse(localTopN.isEnableRuntimeFilter());
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to