This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new c30c0ff7891 [opt](nereids) optimize push down project (#58370)
c30c0ff7891 is described below

commit c30c0ff78912a05c2451f6d52c149842b19163ed
Author: 924060929 <[email protected]>
AuthorDate: Wed Nov 26 11:39:03 2025 +0800

    [opt](nereids) optimize push down project (#58370)
    
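    Optimize project push-down: push down the push-down-preferred
    sub-expressions (e.g. struct_element) even when they are nested inside
    another expression (e.g. coalesce), instead of only when they are the
    top-level project expression. This reduces scan bytes and shuffle bytes
    by pruning nested columns. Related to #57204.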
    
    For the SQL:
    ```sql
    select coalesce(struct_element(t1.s, 'city'), 'beijing')
    from t1 join t2
    on t1.id = t2.id
    ```
    
    Original plan:
    ```
    Project(coalesce(struct_element(t1.s, 'city'), 'beijing'))
                                 |
                        Join(t1.id=t2.id)
                        /               \
                Project(t1.id, t1.s)    Project(t2.id)
                     |                    |
                Scan(t1)                Scan(t2)
    ```
    
    Optimized plan:
    ```
    
                           Project(coalesce(slot#3, 'beijing'))
                                          |
                                   Join(t1.id=t2.id)
                        /                                       \
    Project(t1.id, struct_element(t1.s, 'city')#3)              Project(t2.id)
                  |                                                |
                Scan(t1)                                       Scan(t2)
    ```
---
 .../rules/rewrite/AccessPathPlanCollector.java     | 25 ++++++++++++
 .../nereids/rules/rewrite/PushDownProject.java     | 47 ++++++++++++----------
 .../rules/rewrite/PruneNestedColumnTest.java       | 21 +++++++---
 3 files changed, 65 insertions(+), 28 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
index 514f7bb1e8c..ed253167b38 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
@@ -19,7 +19,9 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.StatementContext;
 import 
org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectAccessPathResult;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
@@ -28,6 +30,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
@@ -53,6 +56,28 @@ public class AccessPathPlanCollector extends 
DefaultPlanVisitor<Void, StatementC
         return scanSlotToAccessPaths;
     }
 
+    @Override
+    public Void visitLogicalProject(LogicalProject<? extends Plan> project, 
StatementContext context) {
+        AccessPathExpressionCollector exprCollector
+                = new AccessPathExpressionCollector(context, 
allSlotToAccessPaths, false);
+        for (NamedExpression output : project.getProjects()) {
+            // e.g. select struct_element(s, 'city') from (select s from tbl)a;
+            // we will not treat the inner `s` access all path
+            if (output instanceof Slot && 
allSlotToAccessPaths.containsKey(output.getExprId().asInt())) {
+                continue;
+            } else if (output instanceof Alias && output.child(0) instanceof 
Slot
+                    && 
allSlotToAccessPaths.containsKey(output.getExprId().asInt())) {
+                Slot innerSlot = (Slot) output.child(0);
+                Collection<CollectAccessPathResult> outerSlotAccessPaths = 
allSlotToAccessPaths.get(
+                        output.getExprId().asInt());
+                allSlotToAccessPaths.putAll(innerSlot.getExprId().asInt(), 
outerSlotAccessPaths);
+            } else {
+                exprCollector.collect(output);
+            }
+        }
+        return project.child().accept(this, context);
+    }
+
     @Override
     public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter, 
StatementContext context) {
         boolean bottomFilter = filter.child().arity() == 0;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
index 8b1fbaac8ae..9fbc9413b29 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java
@@ -48,7 +48,6 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 import java.util.function.Function;
 
 /** push down project if the expression instance of PreferPushDownProject */
@@ -320,13 +319,13 @@ public class PushDownProject implements 
RewriteRuleFactory, NormalizeToSlot {
     private static class PushdownProjectHelper {
         private final Plan plan;
         private final StatementContext statementContext;
-        private final Map<Expression, Pair<Slot, Plan>> exprToChildAndSlot;
+        private final Map<Expression, Expression> oldExprToNewExpr;
         private final Multimap<Plan, NamedExpression> childToPushDownProjects;
 
         public PushdownProjectHelper(StatementContext statementContext, Plan 
plan) {
             this.statementContext = statementContext;
             this.plan = plan;
-            this.exprToChildAndSlot = new LinkedHashMap<>();
+            this.oldExprToNewExpr = new LinkedHashMap<>();
             this.childToPushDownProjects = ArrayListMultimap.create();
         }
 
@@ -357,32 +356,36 @@ public class PushDownProject implements 
RewriteRuleFactory, NormalizeToSlot {
         }
 
         public <E extends Expression> Optional<E> pushDownExpression(E 
expression) {
-            if (!(expression instanceof PreferPushDownProject
-                    || (expression instanceof Alias && expression.child(0) 
instanceof PreferPushDownProject))) {
+            if (!expression.containsType(PreferPushDownProject.class)) {
                 return Optional.empty();
             }
-            Pair<Slot, Plan> existPushdown = 
exprToChildAndSlot.get(expression);
+            Expression existPushdown = oldExprToNewExpr.get(expression);
             if (existPushdown != null) {
-                return Optional.of((E) existPushdown.first);
+                return Optional.of((E) existPushdown);
             }
 
-            Alias pushDownAlias = null;
-            if (expression instanceof Alias) {
-                pushDownAlias = (Alias) expression;
-            } else {
-                pushDownAlias = new Alias(statementContext.getNextExprId(), 
expression);
-            }
-
-            Set<Slot> inputSlots = expression.getInputSlots();
-            for (Plan child : plan.children()) {
-                if (child.getOutputSet().containsAll(inputSlots)) {
-                    Slot remaimSlot = pushDownAlias.toSlot();
-                    exprToChildAndSlot.put(expression, Pair.of(remaimSlot, 
child));
-                    childToPushDownProjects.put(child, pushDownAlias);
-                    return Optional.of((E) remaimSlot);
+            Expression newExpression = expression.rewriteDownShortCircuit(e -> 
{
+                if (e instanceof PreferPushDownProject) {
+                    List<Plan> children = plan.children();
+                    for (int i = 0; i < children.size(); i++) {
+                        Plan child = children.get(i);
+                        if 
(child.getOutputSet().containsAll(e.getInputSlots())) {
+                            Alias alias = new 
Alias(statementContext.getNextExprId(), e);
+                            Slot slot = alias.toSlot();
+                            childToPushDownProjects.put(child, alias);
+                            return slot;
+                        }
+                    }
                 }
+                return e;
+            });
+
+            if (newExpression != expression) {
+                oldExprToNewExpr.put(expression, newExpression);
+                return Optional.of((E) newExpression);
+            } else {
+                return Optional.empty();
             }
-            return Optional.empty();
         }
 
         public List<Plan> buildNewChildren() {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
index 22e9c5e949f..3d22e1e6c28 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
@@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce;
 import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
@@ -72,6 +73,7 @@ public class PruneNestedColumnTest extends TestWithFeService 
implements MemoPatt
 
         createTable("create table tbl(\n"
                 + "  id int,\n"
+                + "  value int,\n"
                 + "  s struct<\n"
                 + "    city: string,\n"
                 + "    data: array<map<\n"
@@ -83,6 +85,7 @@ public class PruneNestedColumnTest extends TestWithFeService 
implements MemoPatt
 
         createTable("create table tbl2(\n"
                 + "  id2 int,\n"
+                + "  value int,\n"
                 + "  s2 struct<\n"
                 + "    city2: string,\n"
                 + "    data2: array<map<\n"
@@ -376,13 +379,13 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
 
     @Test
     public void testUnion() throws Throwable {
-        assertColumn("select struct_element(s, 'city') from (select s from tbl 
union all select null)a",
+        assertColumn("select coalesce(struct_element(s, 'city'), 'abc') from 
(select s from tbl union all select null)a",
                 "struct<city:text>",
                 ImmutableList.of(path("s", "city")),
                 ImmutableList.of()
         );
 
-        assertColumn("select * from (select struct_element(s, 'city') from tbl 
union all select null)a",
+        assertColumn("select * from (select coalesce(struct_element(s, 
'city'), 'abc') from tbl union all select null)a",
                 "struct<city:text>",
                 ImmutableList.of(path("s", "city")),
                 ImmutableList.of()
@@ -407,7 +410,7 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
     @Test
     public void testPushDownThroughJoin() {
         PlanChecker.from(connectContext)
-                .analyze("select struct_element(s, 'city') from (select * from 
tbl)a join (select 100 id, 'f1' name)b on a.id=b.id")
+                .analyze("select coalesce(struct_element(s, 'city'), 'abc') 
from (select * from tbl)a join (select 100 id, 'f1' name)b on a.id=b.id")
                 .rewrite()
                 .matches(
                     logicalResultSink(
@@ -426,7 +429,9 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
                                 logicalOneRowRelation()
                             )
                         ).when(p -> {
-                            Assertions.assertTrue(p.getProjects().size() == 1 
&& p.getProjects().get(0) instanceof SlotReference);
+                            Assertions.assertTrue(p.getProjects().size() == 1 
&& p.getProjects().get(0) instanceof Alias
+                                    && p.getProjects().get(0).child(0) 
instanceof Coalesce
+                                    && 
p.getProjects().get(0).child(0).child(0) instanceof Slot);
                             return true;
                         })
                     )
@@ -479,7 +484,9 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
                                 })
                             )
                         ).when(p -> {
-                            Assertions.assertTrue(p.getProjects().size() == 2 
&& p.getProjects().get(0) instanceof SlotReference);
+                            Assertions.assertTrue(p.getProjects().size() == 2
+                                    && (p.getProjects().get(0) instanceof 
SlotReference
+                                        || (p.getProjects().get(0) instanceof 
Alias && p.getProjects().get(0).child(0) instanceof SlotReference)));
                             return true;
                         })
                     )
@@ -509,7 +516,9 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
                                         )
                                     )
                                 ).when(p -> {
-                                    
Assertions.assertTrue(p.getProjects().size() == 2 && p.getProjects().get(0) 
instanceof SlotReference);
+                                    
Assertions.assertTrue(p.getProjects().size() == 2
+                                            && (p.getProjects().get(0) 
instanceof SlotReference
+                                                || p.getProjects().get(0) 
instanceof Alias && p.getProjects().get(0).child(0) instanceof SlotReference));
                                     return true;
                                 })
                             )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to