This is an automated email from the ASF dual-hosted git repository.

panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 703badd9367 [Fix](agg) fix push agg op in nullable column before projection (#58234)
703badd9367 is described below

commit 703badd9367c5d9f52819622094efb113d1529dc
Author: HappenLee <[email protected]>
AuthorDate: Sat Nov 22 13:08:50 2025 +0800

    [Fix](agg) fix push agg op in nullable column before projection (#58234)
    
    [Fix](agg) Fix pushing down aggregate operations on nullable columns
    that are read through a projection
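    # Changes
    Core logic adjustment (AggregateStrategies.java):
    - Replaced the stream-based `allMatch` argument check with an explicit
      loop over the aggregate-function arguments after they are resolved
      through the projection. Each resolved argument must be a
      SlotReference, or a numeric Cast whose child is a numeric
      SlotReference; otherwise the pushdown is rejected.
    - Arguments that were already collected for nullability checking before
      the projection (a SlotReference, or a Cast over a SlotReference) are
      recorded in `expressionAfterProject`. When such an argument is
      resolved through the projection, the underlying projected slot (or
      the Cast's child slot) is also added to `checkNullSlots`. Nullability
      is therefore validated on the column that is actually scanned, not
      only on the alias.
    - Fixed inconsistent aggregate pushdown decisions for nested projection
      scenarios (e.g. `count(a)` over `select nullable_col as a ...`), so
      that aggregate-function arguments are validated consistently.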
---
 .../rules/implementation/AggregateStrategies.java  | 55 ++++++++++++++--------
 .../explain/test_pushdown_explain.groovy           | 34 +++++++++++++
 2 files changed, 69 insertions(+), 20 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index 1634b9208d7..8b79d7e7fb7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -61,6 +61,7 @@ import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -565,6 +566,7 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
 
         boolean containsCount = false;
         Set<SlotReference> checkNullSlots = new HashSet<>();
+        Set<Expression> expressionAfterProject = new HashSet<>();
 
         // Single loop through aggregateFunctions to handle multiple logic
         for (AggregateFunction function : aggregateFunctions) {
@@ -582,10 +584,12 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                     Expression arg0 = function.getArguments().get(0);
                     if (arg0 instanceof SlotReference) {
                         checkNullSlots.add((SlotReference) arg0);
+                        expressionAfterProject.add(arg0);
                     } else if (arg0 instanceof Cast) {
                         Expression child0 = arg0.child(0);
                         if (child0 instanceof SlotReference) {
                             checkNullSlots.add((SlotReference) child0);
+                            expressionAfterProject.add(arg0);
                         }
                     }
                 }
@@ -633,28 +637,39 @@ public class AggregateStrategies implements 
ImplementationRuleFactory {
                 .collect(ImmutableList.toImmutableList());
 
         if (project != null) {
-            argumentsOfAggregateFunction = Project.findProject(
-                        argumentsOfAggregateFunction, project.getProjects())
-                    .stream()
-                    .map(p -> p instanceof Alias ? p.child(0) : p)
-                    .collect(ImmutableList.toImmutableList());
-        }
-
-        onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
-                .stream()
-                .allMatch(argument -> {
-                    if (argument instanceof SlotReference) {
-                        return true;
+            List<Expression> processedExpressions = new ArrayList<>();
+            List<? extends Expression> projections = 
Project.findProject(argumentsOfAggregateFunction,
+                    project.getProjects());
+
+            for (int i = 0, size = projections.size(); i < size; i++) {
+                // Process the expression (replace Alias with its child)
+                boolean needCheckSlotNull = 
expressionAfterProject.contains(argumentsOfAggregateFunction.get(i));
+                Expression p = projections.get(i);
+                Expression argument = p instanceof Alias ? p.child(0) : p;
+                processedExpressions.add(argument);
+
+                // Check if the argument matches the required pattern
+                if (argument instanceof SlotReference) {
+                    // Argument is valid, continue
+                    if (needCheckSlotNull) {
+                        checkNullSlots.add((SlotReference) argument);
                     }
-                    if (argument instanceof Cast) {
-                        return argument.child(0) instanceof SlotReference
-                                && argument.getDataType().isNumericType()
-                                && 
argument.child(0).getDataType().isNumericType();
+                } else if (argument instanceof Cast) {
+                    boolean castMatch = argument.child(0) instanceof 
SlotReference
+                            && argument.getDataType().isNumericType()
+                            && argument.child(0).getDataType().isNumericType();
+                    if (!castMatch) {
+                        return canNotPush;
+                    } else {
+                        if (needCheckSlotNull) {
+                            checkNullSlots.add((SlotReference) 
argument.child(0));
+                        }
                     }
-                    return false;
-                });
-        if (!onlyContainsSlotOrNumericCastSlot) {
-            return canNotPush;
+                } else {
+                    return canNotPush;
+                }
+            }
+            argumentsOfAggregateFunction = processedExpressions;
         }
 
         Set<PushDownAggOp> pushDownAggOps = functionClasses.stream()
diff --git 
a/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy 
b/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
index 25053f24023..220ab4038c5 100644
--- a/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
+++ b/regression-test/suites/nereids_p0/explain/test_pushdown_explain.groovy
@@ -107,6 +107,40 @@ suite("test_pushdown_explain") {
         contains "pushAggOp=COUNT"
     }
 
+    // test projection in agg pushdown rule
+    explain {
+        sql("select count(a) from (select non_nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=COUNT"
+    }
+    explain {
+        sql("select count(a) from (select nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(a), min(a) from (select non_nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(a), min(a) from (select nullable_col as a from 
test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(a), min(b) from (select nullable_col as a, 
non_nullable_col as b from test_null_columns) t1;")
+        contains "pushAggOp=NONE"
+    }
+    explain {
+        sql("select count(b), min(a) from (select nullable_col as a, 
non_nullable_col as b from test_null_columns) t1;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(non_nullable_col), max(nullable_col) from 
test_null_columns;")
+        contains "pushAggOp=MIX"
+    }
+    explain {
+        sql("select count(nullable_col), max(non_nullable_col) from 
test_null_columns;")
+        contains "pushAggOp=NONE"
+    }
+
     explain {
         sql("select count(non_nullable_col), min(non_nullable_col), 
max(non_nullable_col) from test_null_columns;")
         contains "pushAggOp=MIX"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to