This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new db27399d28a [opt](fe) Bound not-null inference cost (#63318)
db27399d28a is described below

commit db27399d28a590c5604df44ebba0ed159368604a
Author: minghong <[email protected]>
AuthorDate: Tue May 26 18:42:07 2026 +0800

    [opt](fe) Bound not-null inference cost (#63318)
    
    ### What problem does this PR solve?
    
    Issue Number: close #xxx
    
    Related PR: #xxx
    
    Problem Summary:
    Not-null inference replaces each candidate slot with NULL and folds the
    predicate, so wide, deep, or many-slot expressions can make rewrite
    rules spend excessive time on repeated replace-and-fold work. Aggregate
    not-null inference also needs to handle multiple aggregate outputs
    conservatively instead of inferring from all aggregate arguments as one
    set.
    
    This change adds a shared bounded guard for not-null inference
    (expression width at most 256, depth at most 64, and at most 32
    distinct input slots), reuses it from aggregate inference, and lets
    general callers skip only the expensive predicates while preserving the
    original query predicates. It also reworks join inference to compute
    the null-rejecting slots once and reuse them for both sides, and makes
    aggregate inference require a common inferred not-null predicate across
    supported aggregate functions.
    
    Optimized rules:
    InferAggNotNull, InferFilterNotNull, InferJoinNotNull, EliminateNotNull
---
 .../nereids/rules/rewrite/InferAggNotNull.java     | 66 ++++++++++++++----
 .../nereids/rules/rewrite/InferJoinNotNull.java    | 29 +++++---
 .../nereids/trees/plans/algebra/Aggregate.java     | 18 ++++-
 .../apache/doris/nereids/util/ExpressionUtils.java | 55 ++++++++++++++-
 .../rules/rewrite/EliminateNotNullTest.java        | 77 ++++++++++++++++++++
 .../nereids/rules/rewrite/InferAggNotNullTest.java | 81 ++++++++++++++++++++++
 .../rules/rewrite/InferFilterNotNullTest.java      | 28 ++++++++
 .../rules/rewrite/InferJoinNotNullTest.java        | 33 +++++++++
 8 files changed, 363 insertions(+), 24 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java
index e30190592a6..4daf320811a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -36,8 +37,8 @@ import org.apache.doris.nereids.util.PlanUtils;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 /**
  * InferNotNull from Agg count(distinct);
@@ -47,19 +48,10 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
     public Rule build() {
         return logicalAggregate()
                 .when(agg -> agg.getGroupByExpressions().size() == 0)
-                .when(agg -> agg.getAggregateFunctions().size() == 1)
-                .when(agg -> {
-                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
-                    return funcs.stream().allMatch(f -> f instanceof Count)
-                            || funcs.stream().allMatch(f -> f instanceof Avg)
-                            || funcs.stream().allMatch(f -> f instanceof Sum)
-                            || funcs.stream().allMatch(f -> f instanceof Max)
-                            || funcs.stream().allMatch(f -> f instanceof Min);
-                }).thenApply(ctx -> {
+                .thenApply(ctx -> {
                     LogicalAggregate<Plan> agg = ctx.root;
-                    Set<Expression> exprs = 
agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream())
-                            .collect(Collectors.toSet());
-                    Set<Expression> isNotNulls = 
ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext);
+                    Set<AggregateFunction> aggregateFunctions = 
agg.getAggregateFunctions();
+                    Set<Expression> isNotNulls = 
inferCommonNotNulls(aggregateFunctions, ctx.cascadesContext);
                     Set<Expression> predicates = Collections.emptySet();
                     if ((agg.child() instanceof Filter)) {
                         predicates = ((Filter) agg.child()).getConjuncts();
@@ -80,4 +72,52 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
                     return 
agg.withChildren(PlanUtils.filter(needGenerateNotNulls, agg.child()).get());
                 }).toRule(RuleType.INFER_AGG_NOT_NULL);
     }
+
+    private Set<Expression> inferCommonNotNulls(
+            Set<AggregateFunction> aggregateFunctions, CascadesContext 
cascadesContext) {
+        if (aggregateFunctions.isEmpty()) {
+            return Collections.emptySet();
+        }
+        for (AggregateFunction aggregateFunction : aggregateFunctions) {
+            if (!canInferFunctionNotNull(aggregateFunction)) {
+                return Collections.emptySet();
+            }
+        }
+        Set<Expression> commonNotNulls = null;
+        for (AggregateFunction aggregateFunction : aggregateFunctions) {
+            Set<Expression> functionNotNulls = 
inferFunctionNotNulls(aggregateFunction, cascadesContext);
+            if (functionNotNulls.isEmpty()) {
+                return Collections.emptySet();
+            }
+            if (commonNotNulls == null) {
+                commonNotNulls = new HashSet<>(functionNotNulls);
+            } else {
+                commonNotNulls.retainAll(functionNotNulls);
+                if (commonNotNulls.isEmpty()) {
+                    return Collections.emptySet();
+                }
+            }
+        }
+        return commonNotNulls == null ? Collections.emptySet() : 
commonNotNulls;
+    }
+
+    private Set<Expression> inferFunctionNotNulls(
+            AggregateFunction aggregateFunction, CascadesContext 
cascadesContext) {
+        return 
ExpressionUtils.inferNotNull(ImmutableSet.copyOf(aggregateFunction.children()), 
cascadesContext);
+    }
+
+    private boolean canInferFunctionNotNull(AggregateFunction 
aggregateFunction) {
+        return isSupportedAggregateFunction(aggregateFunction)
+                && !aggregateFunction.children().isEmpty()
+                && 
ExpressionUtils.isCheapEnoughToInferNotNull(aggregateFunction.children());
+    }
+
+    private boolean isSupportedAggregateFunction(AggregateFunction 
aggregateFunction) {
+        return aggregateFunction instanceof Count
+                || aggregateFunction instanceof Avg
+                || aggregateFunction instanceof Sum
+                || aggregateFunction instanceof Max
+                || aggregateFunction instanceof Min;
+    }
+
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java
index 5f87a3ec940..c3fb0e068cb 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNull.java
@@ -20,12 +20,17 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.PlanUtils;
 
+import com.google.common.collect.ImmutableSet;
+
 import java.util.LinkedHashSet;
 import java.util.Set;
 
@@ -50,23 +55,21 @@ public class InferJoinNotNull extends OneRewriteRuleFactory 
{
                 Set<Expression> conjuncts = new LinkedHashSet<>();
                 conjuncts.addAll(join.getHashJoinConjuncts());
                 conjuncts.addAll(join.getOtherJoinConjuncts());
+                Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(
+                        conjuncts, ctx.cascadesContext);
 
                 Plan left = join.left();
                 Plan right = join.right();
                 if (join.getJoinType().isInnerJoin() || 
join.getJoinType().isAsofInnerJoin()) {
-                    Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
-                            conjuncts, join.left().getOutputSet(), 
ctx.cascadesContext);
-                    Set<Expression> rightNotNull = 
ExpressionUtils.inferNotNull(
-                            conjuncts, join.right().getOutputSet(), 
ctx.cascadesContext);
+                    Set<Expression> leftNotNull = inferNotNull(notNullSlots, 
join.left().getOutputSet());
+                    Set<Expression> rightNotNull = inferNotNull(notNullSlots, 
join.right().getOutputSet());
                     left = PlanUtils.filterOrSelf(leftNotNull, join.left());
                     right = PlanUtils.filterOrSelf(rightNotNull, join.right());
                 } else if (join.getJoinType() == JoinType.LEFT_SEMI_JOIN) {
-                    Set<Expression> leftNotNull = ExpressionUtils.inferNotNull(
-                            conjuncts, join.left().getOutputSet(), 
ctx.cascadesContext);
+                    Set<Expression> leftNotNull = inferNotNull(notNullSlots, 
join.left().getOutputSet());
                     left = PlanUtils.filterOrSelf(leftNotNull, join.left());
                 } else {
-                    Set<Expression> rightNotNull = 
ExpressionUtils.inferNotNull(
-                            conjuncts, join.right().getOutputSet(), 
ctx.cascadesContext);
+                    Set<Expression> rightNotNull = inferNotNull(notNullSlots, 
join.right().getOutputSet());
                     right = PlanUtils.filterOrSelf(rightNotNull, join.right());
                 }
 
@@ -76,4 +79,14 @@ public class InferJoinNotNull extends OneRewriteRuleFactory {
                 return join.withChildren(left, right);
             }).toRule(RuleType.INFER_JOIN_NOT_NULL);
     }
+
+    private Set<Expression> inferNotNull(Set<Slot> notNullSlots, Set<Slot> 
outputSlots) {
+        ImmutableSet.Builder<Expression> predicates = 
ImmutableSet.builderWithExpectedSize(notNullSlots.size());
+        for (Slot slot : notNullSlots) {
+            if (outputSlots.contains(slot)) {
+                predicates.add(new Not(new IsNull(slot), true));
+            }
+        }
+        return predicates.build();
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
index 12fc9608fd3..f67703c23c9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java
@@ -28,7 +28,6 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisit
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.UnaryPlan;
 import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
-import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableSet;
@@ -60,8 +59,23 @@ public interface Aggregate<CHILD_TYPE extends Plan> extends 
UnaryPlan<CHILD_TYPE
         return withAggOutput(prunedOutputs);
     }
 
+    /**
+     * get aggregate functions
+     * aggregate functions cannot be nested, so we stop recursion when we find 
an aggregate function,
+     * and do not need to traverse its children.
+     */
     default Set<AggregateFunction> getAggregateFunctions() {
-        return ExpressionUtils.collect(getOutputExpressions(), 
AggregateFunction.class::isInstance);
+        ImmutableSet.Builder<AggregateFunction> aggregateFunctions = 
ImmutableSet.builder();
+        for (Expression outputExpression : getOutputExpressions()) {
+            outputExpression.foreach(expression -> {
+                if (expression instanceof AggregateFunction) {
+                    aggregateFunctions.add((AggregateFunction) expression);
+                    return true;
+                }
+                return false;
+            });
+        }
+        return aggregateFunctions.build();
     }
 
     /**getAggregateFunctionWithGuardExpr*/
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index e594556eda3..5dcd65d69c6 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -122,6 +122,9 @@ import java.util.stream.Collectors;
 public class ExpressionUtils {
 
     public static final List<Expression> EMPTY_CONDITION = ImmutableList.of();
+    private static final int MAX_INFER_NOT_NULL_EXPR_WIDTH = 256;
+    private static final int MAX_INFER_NOT_NULL_EXPR_DEPTH = 64;
+    private static final int MAX_INFER_NOT_NULL_INPUT_SLOTS = 32;
 
     public static List<Expression> extractConjunction(Expression expr) {
         return extract(And.class, expr);
@@ -784,7 +787,7 @@ public class ExpressionUtils {
      */
     public static Set<Slot> inferNotNullSlots(Set<Expression> predicates, 
CascadesContext cascadesContext) {
         ImmutableSet.Builder<Slot> notNullSlots = 
ImmutableSet.builderWithExpectedSize(predicates.size());
-        for (Expression predicate : predicates) {
+        for (Expression predicate : 
filterCheapPredicatesForNotNull(predicates)) {
             for (Slot slot : predicate.getInputSlots()) {
                 Map<Expression, Expression> replaceMap = new HashMap<>();
                 Literal nullLiteral = new NullLiteral(slot.getDataType());
@@ -800,6 +803,56 @@ public class ExpressionUtils {
         return notNullSlots.build();
     }
 
+    /**
+     * Return whether all predicates are cheap enough for not-null inference.
+     */
+    public static boolean isCheapEnoughToInferNotNull(Collection<? extends 
Expression> predicates) {
+        Set<Slot> inputSlots = new HashSet<>();
+        for (Expression predicate : predicates) {
+            Optional<Set<Slot>> mergedInputSlots = 
mergeInputSlotsIfCheap(predicate, inputSlots);
+            if (!mergedInputSlots.isPresent()) {
+                return false;
+            }
+            inputSlots = mergedInputSlots.get();
+        }
+        return true;
+    }
+
+    /**
+     * Filter predicates that are cheap enough for not-null inference.
+     */
+    public static Set<Expression> filterCheapPredicatesForNotNull(
+            Collection<? extends Expression> predicates) {
+        Set<Slot> inputSlots = new HashSet<>();
+        Set<Expression> cheapPredicates = Sets.newLinkedHashSet();
+        for (Expression predicate : predicates) {
+            Optional<Set<Slot>> mergedInputSlots = 
mergeInputSlotsIfCheap(predicate, inputSlots);
+            if (!mergedInputSlots.isPresent()) {
+                continue;
+            }
+            inputSlots = mergedInputSlots.get();
+            cheapPredicates.add(predicate);
+        }
+        return cheapPredicates;
+    }
+
+    private static Optional<Set<Slot>> mergeInputSlotsIfCheap(Expression 
predicate, Set<Slot> inputSlots) {
+        if (predicate.getWidth() > MAX_INFER_NOT_NULL_EXPR_WIDTH
+                || predicate.getDepth() > MAX_INFER_NOT_NULL_EXPR_DEPTH) {
+            return Optional.empty();
+        }
+        Set<Slot> predicateInputSlots = predicate.getInputSlots();
+        if (predicateInputSlots.size() > MAX_INFER_NOT_NULL_INPUT_SLOTS) {
+            return Optional.empty();
+        }
+        Set<Slot> mergedInputSlots = new HashSet<>(inputSlots);
+        mergedInputSlots.addAll(predicateInputSlots);
+        if (mergedInputSlots.size() > MAX_INFER_NOT_NULL_INPUT_SLOTS) {
+            return Optional.empty();
+        }
+        return Optional.of(mergedInputSlots);
+    }
+
     /**
      * infer notNulls slot from predicate
      */
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java
new file mode 100644
index 00000000000..5486aae91f7
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNullTest.java
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.junit.jupiter.api.Test;
+
+class EliminateNotNullTest implements MemoPatternMatchSupported {
+    private final SlotReference slot = new SlotReference("nullable_col", 
IntegerType.INSTANCE, true);
+    private final LogicalOneRowRelation relation = new 
LogicalOneRowRelation(new RelationId(1), ImmutableList.of(slot));
+
+    @Test
+    void testEliminateNotNullForSimplePredicate() {
+        Expression simplePredicate = new EqualTo(slot, Literal.of(1));
+        Expression explicitNotNull = new Not(new IsNull(slot));
+        LogicalPlan plan = new LogicalPlanBuilder(relation)
+                .filter(ImmutableSet.of(simplePredicate, explicitNotNull))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateNotNull())
+                .matches(logicalFilter().when(filter -> 
filter.getConjuncts().size() == 1));
+    }
+
+    @Test
+    void testKeepNotNullWhenOnlyWidePredicateCanProveIt() {
+        Expression widePredicate = new EqualTo(repeatAdd(slot, 257), 
Literal.of(1));
+        Expression explicitNotNull = new Not(new IsNull(slot));
+        LogicalPlan plan = new LogicalPlanBuilder(relation)
+                .filter(ImmutableSet.of(widePredicate, explicitNotNull))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateNotNull())
+                .matches(logicalFilter().when(filter -> 
filter.getConjuncts().size() == 2));
+    }
+
+    private Expression repeatAdd(Expression expression, int width) {
+        if (width == 1) {
+            return expression;
+        }
+        int leftWidth = width / 2;
+        return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, 
width - leftWidth));
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java
index 7d20c2f22a6..23b108c8347 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java
@@ -19,7 +19,11 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Not;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
@@ -29,8 +33,11 @@ import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.Set;
+
 class InferAggNotNullTest implements MemoPatternMatchSupported {
     private final LogicalOlapScan scan1 = 
PlanConstructor.newLogicalOlapScan(0, "t1", 0);
 
@@ -51,6 +58,62 @@ class InferAggNotNullTest implements 
MemoPatternMatchSupported {
                 );
     }
 
+    @Test
+    void testInferMultipleAggregateSameInput() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .aggGroupUsingIndex(ImmutableList.of(),
+                        ImmutableList.of(
+                                new Alias(new Avg(scan1.getOutput().get(1)), 
"avg_k"),
+                                new Alias(new Sum(scan1.getOutput().get(1)), 
"sum_k")))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferAggNotNull())
+                .matches(
+                        logicalAggregate(
+                                logicalFilter().when(filter -> 
filter.getConjuncts().size() == 1
+                                        && filter.getConjuncts().stream()
+                                        .allMatch(e -> ((Not) 
e).isGeneratedIsNotNull()))
+                        )
+                );
+    }
+
+    @Test
+    void testNotInferMultipleAggregateDifferentInputs() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .aggGroupUsingIndex(ImmutableList.of(),
+                        ImmutableList.of(
+                                new Alias(new Avg(scan1.getOutput().get(1)), 
"avg_k1"),
+                                new Alias(new Sum(scan1.getOutput().get(0)), 
"sum_k2")))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferAggNotNull())
+                .matches(
+                        logicalAggregate(
+                                logicalOlapScan()
+                        )
+                );
+    }
+
+    @Test
+    void testNotInferMultipleAggregateWithCountStar() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .aggGroupUsingIndex(ImmutableList.of(),
+                        ImmutableList.of(
+                                new Alias(new Avg(scan1.getOutput().get(1)), 
"avg_k"),
+                                new Alias(new Count(), "count_star")))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferAggNotNull())
+                .matches(
+                        logicalAggregate(
+                                logicalOlapScan()
+                        )
+                );
+    }
+
     @Test
     void testCountStar() {
         LogicalPlan plan = new LogicalPlanBuilder(scan1)
@@ -66,4 +129,22 @@ class InferAggNotNullTest implements 
MemoPatternMatchSupported {
                         )
                 );
     }
+
+    @Test
+    void testGetAggregateFunctionsStopsAtAggregateFunction() {
+        // Use different agg function types for inner (Avg) and outer (Count),
+        // so we can verify by instanceof regardless of how the plan builder
+        // clones/transforms expressions internally.
+        Avg inner = new Avg(scan1.getOutput().get(1));
+        Count outer = new Count(false, inner);
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .aggGroupUsingIndex(ImmutableList.of(), ImmutableList.of(new 
Alias(outer, "cnt")))
+                .build();
+
+        Set<AggregateFunction> aggregateFunctions = ((LogicalAggregate<?>) 
plan).getAggregateFunctions();
+        System.out.println("aggregateFunctions: " + aggregateFunctions);
+        Assertions.assertEquals(1, aggregateFunctions.size());
+        Assertions.assertTrue(aggregateFunctions.stream().allMatch(f -> f 
instanceof Count),
+                "should collect only the outer Count, got: " + 
aggregateFunctions);
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java
index 9e4335db3ed..bf9d1d31f71 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNullTest.java
@@ -19,7 +19,9 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@@ -30,6 +32,7 @@ import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
+import com.google.common.collect.ImmutableSet;
 import org.junit.jupiter.api.Test;
 
 class InferFilterNotNullTest implements MemoPatternMatchSupported {
@@ -77,4 +80,29 @@ class InferFilterNotNullTest implements 
MemoPatternMatchSupported {
                         logicalFilter().when(filter -> 
filter.getConjuncts().size() == 1)
                 );
     }
+
+    @Test
+    void testSkipWidePredicateButKeepSimplePredicate() {
+        Expression widePredicate = new 
EqualTo(repeatAdd(scan1.getOutput().get(0), 257), Literal.of(1));
+        Expression simplePredicate = new EqualTo(scan1.getOutput().get(1), 
Literal.of(1));
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .filter(ImmutableSet.of(widePredicate, simplePredicate))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new InferFilterNotNull())
+                .matches(
+                        logicalFilter().when(filter -> 
filter.getConjuncts().stream()
+                                .filter(e -> e instanceof Not && ((Not) 
e).isGeneratedIsNotNull())
+                                .count() == 1)
+                );
+    }
+
+    private Expression repeatAdd(Expression expression, int width) {
+        if (width == 1) {
+            return expression;
+        }
+        int leftWidth = width / 2;
+        return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, 
width - leftWidth));
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java
index d963363a379..4ea2466edb2 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferJoinNotNullTest.java
@@ -18,6 +18,10 @@
 package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -27,6 +31,7 @@ import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
+import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
@@ -92,6 +97,27 @@ class InferJoinNotNullTest implements 
MemoPatternMatchSupported {
                 );
     }
 
+    @Test
+    void testSkipWideOtherConjunctButKeepHashConjunct() {
+        Expression widePredicate = new 
EqualTo(repeatAdd(scan1.getOutput().get(1), 257), Literal.of(1));
+        LogicalPlan innerJoin = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN,
+                        ImmutableList.of(new EqualTo(scan1.getOutput().get(0), 
scan2.getOutput().get(0))),
+                        ImmutableList.of(widePredicate))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), innerJoin)
+                .applyTopDown(new InferJoinNotNull())
+                .matches(
+                        innerLogicalJoin(
+                                logicalFilter().when(f -> 
f.getPredicate().toString()
+                                        .equals("( not id#10000 IS NULL)")),
+                                logicalFilter().when(f -> 
f.getPredicate().toString()
+                                        .equals("( not id#10002 IS NULL)"))
+                        )
+                );
+    }
+
     @Test
     void testInferAndEliminate() {
         LogicalPlan plan = new LogicalPlanBuilder(scan1)
@@ -109,4 +135,11 @@ class InferJoinNotNullTest implements 
MemoPatternMatchSupported {
                 );
     }
 
+    private Expression repeatAdd(Expression expression, int width) {
+        if (width == 1) {
+            return expression;
+        }
+        int leftWidth = width / 2;
+        return new Add(repeatAdd(expression, leftWidth), repeatAdd(expression, 
width - leftWidth));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to