This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 5e5615f27a8 branch-4.0: [Fix](agg) fix push agg op in nullable column before projection #58234 (#58281)
5e5615f27a8 is described below

commit 5e5615f27a87167aebe78c3e68d8eb3af3ffd6cc
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Nov 25 10:04:23 2025 +0800

    branch-4.0: [Fix](agg) fix push agg op in nullable column before projection #58234 (#58281)
    
    Cherry-picked from #58234
    
    Co-authored-by: HappenLee <[email protected]>
---
 .../rules/implementation/AggregateStrategies.java  | 55 ++++++++++++++--------
 .../explain/test_pushdown_explain.groovy           | 34 +++++++++++++
 2 files changed, 69 insertions(+), 20 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index a5304c6e11b..ca2f115f096 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -61,6 +61,7 @@ import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -565,6 +566,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
         boolean containsCount = false;
         Set<SlotReference> checkNullSlots = new HashSet<>();
+        Set<Expression> expressionAfterProject = new HashSet<>();
 
         // Single loop through aggregateFunctions to handle multiple logic
         for (AggregateFunction function : aggregateFunctions) {
@@ -582,10 +584,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     Expression arg0 = function.getArguments().get(0);
                     if (arg0 instanceof SlotReference) {
                         checkNullSlots.add((SlotReference) arg0);
+                        expressionAfterProject.add(arg0);
                     } else if (arg0 instanceof Cast) {
                         Expression child0 = arg0.child(0);
                         if (child0 instanceof SlotReference) {
                             checkNullSlots.add((SlotReference) child0);
+                            expressionAfterProject.add(arg0);
                         }
                     }
                 }
@@ -633,28 +637,39 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .collect(ImmutableList.toImmutableList());
 
         if (project != null) {
-            argumentsOfAggregateFunction = Project.findProject(
-                        argumentsOfAggregateFunction, project.getProjects())
-                    .stream()
-                    .map(p -> p instanceof Alias ? p.child(0) : p)
-                    .collect(ImmutableList.toImmutableList());
-        }
-
-        onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
-                .stream()
-                .allMatch(argument -> {
-                    if (argument instanceof SlotReference) {
-                        return true;
+            List<Expression> processedExpressions = new ArrayList<>();
+            List<? extends Expression> projections = 
Project.findProject(argumentsOfAggregateFunction,
+                    project.getProjects());
+
+            for (int i = 0, size = projections.size(); i < size; i++) {
+                // Process the expression (replace Alias with its child)
+                boolean needCheckSlotNull = 
expressionAfterProject.contains(argumentsOfAggregateFunction.get(i));
+                Expression p = projections.get(i);
+                Expression argument = p instanceof Alias ? p.child(0) : p;
+                processedExpressions.add(argument);
+
+                // Check if the argument matches the required pattern
+                if (argument instanceof SlotReference) {
+                    // Argument is valid, continue
+                    if (needCheckSlotNull) {
+                        checkNullSlots.add((SlotReference) argument);
                     }
-                    if (argument instanceof Cast) {
-                        return argument.child(0) instanceof SlotReference
-                                && argument.getDataType().isNumericType()
-                                && 
argument.child(0).getDataType().isNumericType();
+                } else if (argument instanceof Cast) {
+                    boolean castMatch = argument.child(0) instanceof 
SlotReference
+                            && argument.getDataType().isNumericType()
+                            && argument.child(0).getDataType().isNumericType();
+                    if (!castMatch) {
+                        return canNotPush;
+                    } else {
+                        if (needCheckSlotNull) {
+                            checkNullSlots.add((SlotReference) 
argument.child(0));
+                        }
                     }
-                    return false;
-                });
-        if (!onlyContainsSlotOrNumericCastSlot) {
-            return canNotPush;
+                } else {
+                    return canNotPush;
+                }
+            }
+            argumentsOfAggregateFunction = processedExpressions;
         }
 
         Set<PushDownAggOp> pushDownAggOps = functionClasses.stream()
diff --git a/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy b/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
index 25053f24023..220ab4038c5 100644
--- a/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
+++ b/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
@@ -107,6 +107,40 @@ suite("test_pushdown_explain") {
         contains "pushAggOp=COUNT"
     }
 
+    // test projection in agg pushdown rule
+    explain {
+        sql("select count(a) from (select non_nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=COUNT"
+    }
+    explain {
+        sql("select count(a) from (select nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(a), min(a) from (select non_nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(a), min(a) from (select nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(a), min(b) from (select nullable_col as a, 
non_nullable_col as b from test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(b), min(a) from (select nullable_col as a, 
non_nullable_col as b from test_null_columns) t1;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(non_nullable_col), max(nullable_col) from 
test_null_columns;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(nullable_col), max(non_nullable_col) from 
test_null_columns;")
+        contains "pushAggOp=NONE"
+    }
+
     explain {
         sql("select count(non_nullable_col), min(non_nullable_col), 
max(non_nullable_col) from test_null_columns;")
         contains "pushAggOp=MIX"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to