morrySnow commented on code in PR #63318:
URL: https://github.com/apache/doris/pull/63318#discussion_r3273207874


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java:
##########
@@ -47,19 +49,10 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
     public Rule build() {
         return logicalAggregate()
                 .when(agg -> agg.getGroupByExpressions().size() == 0)
-                .when(agg -> agg.getAggregateFunctions().size() == 1)
-                .when(agg -> {
-                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
-                    return funcs.stream().allMatch(f -> f instanceof Count)
-                            || funcs.stream().allMatch(f -> f instanceof Avg)
-                            || funcs.stream().allMatch(f -> f instanceof Sum)
-                            || funcs.stream().allMatch(f -> f instanceof Max)
-                            || funcs.stream().allMatch(f -> f instanceof Min);
-                }).thenApply(ctx -> {
+                .thenApply(ctx -> {
                     LogicalAggregate<Plan> agg = ctx.root;
-                    Set<Expression> exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream())

Review Comment:
   The better way is to optimize `getAggregateFunctions` directly.
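The comment does not spell out what optimizing `getAggregateFunctions` directly should look like. Below is a minimal sketch of one possible reading. It assumes the goal is to move the collection logic into the plan node's own getter, so that every rule calling `getAggregateFunctions` benefits rather than only this rule through a local `collectAggregateFunctions` helper. It further assumes the node is immutable, so the result can be computed once and cached. `SimpleExpr` and `SimpleAggregate` are hypothetical stand-ins, not Doris classes, and the caching is an assumption rather than something the review asks for.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical, simplified stand-in for an expression; not a Doris class.
final class SimpleExpr {
    final String text;
    final boolean isAggregateFunction;

    SimpleExpr(String text, boolean isAggregateFunction) {
        this.text = text;
        this.isAggregateFunction = isAggregateFunction;
    }
}

// Hypothetical, simplified stand-in for an aggregate plan node; not the Doris
// LogicalAggregate. The collection logic lives in the node's own getter and,
// assuming the node is immutable, runs at most once per node.
final class SimpleAggregate {
    private final List<SimpleExpr> outputExpressions;
    private final Supplier<List<SimpleExpr>> aggregateFunctions;
    private int collectCount = 0; // how many times the collection actually ran

    SimpleAggregate(List<SimpleExpr> outputExpressions) {
        this.outputExpressions = List.copyOf(outputExpressions);
        this.aggregateFunctions = memoize(this::computeAggregateFunctions);
    }

    // Every caller shares the same cached, unmodifiable result.
    List<SimpleExpr> getAggregateFunctions() {
        return aggregateFunctions.get();
    }

    int collectCount() {
        return collectCount;
    }

    private List<SimpleExpr> computeAggregateFunctions() {
        collectCount++;
        List<SimpleExpr> result = new ArrayList<>();
        for (SimpleExpr e : outputExpressions) {
            if (e.isAggregateFunction) {
                result.add(e);
            }
        }
        return Collections.unmodifiableList(result);
    }

    // Minimal thread-safe memoizing supplier (same idea as Guava's Suppliers.memoize).
    private static <T> Supplier<T> memoize(Supplier<T> delegate) {
        return new Supplier<T>() {
            private T value;
            private boolean computed;

            @Override
            public synchronized T get() {
                if (!computed) {
                    value = delegate.get();
                    computed = true;
                }
                return value;
            }
        };
    }
}
```

With this shape, a rule such as `InferAggNotNull` would just call the getter and would not need its own collection helper.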



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java:
##########
@@ -47,19 +49,10 @@ public class InferAggNotNull extends OneRewriteRuleFactory {
     public Rule build() {
         return logicalAggregate()
                 .when(agg -> agg.getGroupByExpressions().size() == 0)
-                .when(agg -> agg.getAggregateFunctions().size() == 1)
-                .when(agg -> {
-                    Set<AggregateFunction> funcs = agg.getAggregateFunctions();
-                    return funcs.stream().allMatch(f -> f instanceof Count)
-                            || funcs.stream().allMatch(f -> f instanceof Avg)
-                            || funcs.stream().allMatch(f -> f instanceof Sum)
-                            || funcs.stream().allMatch(f -> f instanceof Max)
-                            || funcs.stream().allMatch(f -> f instanceof Min);
-                }).thenApply(ctx -> {
+                .thenApply(ctx -> {
                     LogicalAggregate<Plan> agg = ctx.root;
-                    Set<Expression> exprs = agg.getAggregateFunctions().stream().flatMap(f -> f.children().stream())
-                            .collect(Collectors.toSet());
-                    Set<Expression> isNotNulls = ExpressionUtils.inferNotNull(exprs, ctx.cascadesContext);
+                    Set<AggregateFunction> aggregateFunctions = collectAggregateFunctions(agg);
+                    Set<Expression> isNotNulls = inferCommonNotNulls(aggregateFunctions, ctx.cascadesContext);

Review Comment:
   Since `ExpressionUtils.inferNotNull` has been updated, is it still necessary to use `inferCommonNotNulls`?
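For context on the question: the removed code fed the children of all aggregate functions into a single `inferNotNull` call. That was only safe while the rule required exactly one aggregate function. With several functions, a row can be dropped only if every function ignores it. Judging by its name, that is presumably what `inferCommonNotNulls` computes. Below is a minimal sketch of that intersection. It uses hypothetical simplified types (`AggCall`, `CommonNotNulls`) that model arguments as plain column names. It is not the PR's implementation and says nothing about what the updated `ExpressionUtils.inferNotNull` does.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical, simplified aggregate call: a function name plus the column
// names it reads. Not a Doris class.
final class AggCall {
    final String name;
    final List<String> argumentColumns;

    AggCall(String name, List<String> argumentColumns) {
        this.name = name;
        this.argumentColumns = List.copyOf(argumentColumns);
    }
}

final class CommonNotNulls {
    // Functions that skip rows whose argument is NULL (the ones the old rule accepted).
    private static final Set<String> NULL_IGNORING = Set.of("count", "sum", "avg", "min", "max");

    // Columns whose NULL value makes this one call ignore the row.
    static Set<String> notNullColumnsOf(AggCall call) {
        if (!NULL_IGNORING.contains(call.name)) {
            return Set.of(); // nothing can be inferred for other functions
        }
        // count(*) has no arguments, so it yields an empty set.
        return new LinkedHashSet<>(call.argumentColumns);
    }

    // A row may be filtered out only if every call ignores it, so keep only
    // the columns implied by all calls (set intersection).
    static Set<String> inferCommon(List<AggCall> calls) {
        if (calls.isEmpty()) {
            return Set.of();
        }
        Set<String> common = new LinkedHashSet<>(notNullColumnsOf(calls.get(0)));
        for (AggCall call : calls.subList(1, calls.size())) {
            common.retainAll(notNullColumnsOf(call));
        }
        return common;
    }
}
```

For `count(a), sum(a)` this yields `{a}`. For `count(a), sum(b)` it yields an empty set, because a row where `a` is NULL still contributes to `sum(b)`. If the updated `inferNotNull` already applies such an intersection, the extra helper would indeed be redundant.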



-- 
This is an automated message from the Apache Git Service.