morrySnow commented on code in PR #60757: URL: https://github.com/apache/doris/pull/60757#discussion_r3387234408
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import 
org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } Review Comment: 这个检查可能必要性不大,即使这里绕过去了,后面的 cascades 也会报错 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { Review Comment: 所有的 fedebug,可以换用更精细的 
org.apache.doris.common.util.DebugPointUtil ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import 
org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. 
+ if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } Review Comment: 可能遇到 ifnull(x, y) 等函数依然有问题 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java: ########## @@ -92,6 +97,38 @@ public LogicalSetOperation(PlanType planType, Qualifier qualifier, List<NamedExp this.qualifier = qualifier; this.outputs = ImmutableList.copyOf(outputs); this.regularChildrenOutputs = ImmutableList.copyOf(regularChildrenOutputs); + //if (SessionVariable.isFeDebug()) { + // checkOutputs(outputs, regularChildrenOutputs) + // .ifPresent(msg -> SessionVariable.throwAnalysisExceptionWhenFeDebug(msg)); + //} + } + + // check every slot in outputs has its counterpart in regularChildrenOutputs, and they have the same data type. + private Optional<String> checkOutputs(List<NamedExpression> outputs, Review Comment: 看起来这个函数只在注释里面用了 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java: ########## @@ -176,10 +176,13 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, O } //4. union all joins and put producers to context List<List<SlotReference>> childrenOutputs = joins.stream() - .map(j -> j.getOutput().stream() + .map(j -> j.getOutput().stream() //.map(j -> j.getOutput().stream().distinct() Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalUnion.java: ########## @@ -424,4 +424,12 @@ private static List<NamedExpression> castToCommonType(List<NamedExpression> row, } return changed ? castedRow.build() : row; } + + public LogicalSetOperation withChildrenAndOutputs(List<Plan> children, List<NamedExpression> newOuptuts, Review Comment: 为什么不是返回 LogicalUnion? 
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java: ########## @@ -221,28 +219,21 @@ public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) { } LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context); // start prune children of union - List<Slot> originOutput = union.getOutput(); Review Comment: 这个修改是因为 union 的列裁剪有bug? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java: ########## @@ -0,0 +1,625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.StatsDerive; +import org.apache.doris.nereids.stats.ExpressionEstimation; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.qe.SessionVariable; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; + +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import 
java.util.stream.Collectors; + +/** + * eager aggregation + * agg[sum(t1.A) group by t1.B] + * ->join(t1.C=t2.D) + * ->T1(A, B, C) + * ->T2(D) + * + * => + * agg[sum(x) group by t1.B] + * ->join(t1.C=t2.D) + * ->agg[sum(A) as x, group by B] + * ->T1(A, B, C) + * ->T2(D) + */ +public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> { + private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; + private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; + private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; + private final StatsDerive derive = new StatsDerive(false); + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, PushDownAggContext context) { + boolean toLeft = false; + boolean toRight = false; + boolean pushHere = false; + if (context.getAggFunctions().isEmpty()) { + // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v + // if no agg function, try to push agg to the child which contains all group keys + // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, t2.v).ndv to determine push target + if (join.left().getOutputSet().containsAll(context.getGroupKeys())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(context.getGroupKeys())) { + toRight = true; + } else { + pushHere = true; + } + } else { + for (AggregateFunction aggFunc : context.getAggFunctions()) { + if (join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toRight = true; + } else { + pushHere = true; + } + } + } + + if (pushHere || (toLeft && toRight)) { Review Comment: toLeft and toRight 永远是 false? 
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java: ########## @@ -0,0 +1,625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.StatsDerive; +import org.apache.doris.nereids.stats.ExpressionEstimation; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.qe.SessionVariable; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; + +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * eager aggregation + * agg[sum(t1.A) group by t1.B] + * ->join(t1.C=t2.D) + * ->T1(A, B, C) + * ->T2(D) + * + * => + * agg[sum(x) group by t1.B] + * ->join(t1.C=t2.D) + * ->agg[sum(A) as x, group by B] + * ->T1(A, B, C) + * ->T2(D) + */ +public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> { + private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; + private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; + private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; + private final StatsDerive derive = new StatsDerive(false); + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? 
extends Plan> join, PushDownAggContext context) { + boolean toLeft = false; + boolean toRight = false; + boolean pushHere = false; + if (context.getAggFunctions().isEmpty()) { + // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v + // if no agg function, try to push agg to the child which contains all group keys + // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, t2.v).ndv to determine push target + if (join.left().getOutputSet().containsAll(context.getGroupKeys())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(context.getGroupKeys())) { + toRight = true; + } else { + pushHere = true; + } + } else { + for (AggregateFunction aggFunc : context.getAggFunctions()) { + if (join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toRight = true; + } else { + pushHere = true; + } Review Comment: 这里为什么不检查 group by key 了? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggContext.java: ########## @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * PushDownAggContext + */ +public class PushDownAggContext { + public static final int BIG_JOIN_BUILD_SIZE = 400_000; + // count(if(...)): if(...) push down as a whole + // sum/min/max(if(truePart, elsePart)): if(...) can be split to sum(truePart) and sum(elsePart) + public final boolean hasDecomposedAggIf; + // When aggFunc(if(...)) is present, pushing down the null-supplemented side of the outer join is avoided. + // This is because null values are highly error-prone, + // so the push-down operation is not performed during hashCaseWhen. 
+ public final boolean hasCaseWhen; + private final List<AggregateFunction> aggFunctions; + private final List<SlotReference> groupKeys; + private final HashMap<AggregateFunction, Alias> aliasMap; + private final Set<Slot> aggFunctionsInputSlots; + + // cascadesContext is used for normalizeAgg + private final CascadesContext cascadesContext; + + private final boolean passThroughBigJoin; + + /** + * constructor + */ + public PushDownAggContext(List<AggregateFunction> aggFunctions, + List<SlotReference> groupKeys, Map<AggregateFunction, Alias> aliasMap, CascadesContext cascadesContext, + boolean passThroughBigJoin, boolean hasDecomposedAggIf, boolean hasCaseWhen) { + this.groupKeys = groupKeys.stream().distinct().collect(Collectors.toList()); + this.aggFunctions = ImmutableList.copyOf(aggFunctions); + this.cascadesContext = cascadesContext; + + HashMap<AggregateFunction, Alias> builtAliasMap = new HashMap<>(); + if (aliasMap == null) { + for (AggregateFunction aggFunction : this.aggFunctions) { + builtAliasMap.put(aggFunction, new Alias(aggFunction, aggFunction.getName())); + } + } else { + for (AggregateFunction aggFunction : this.aggFunctions) { + Alias alias = aliasMap.get(aggFunction); + if (alias == null) { + alias = new Alias(aggFunction, aggFunction.getName()); + } + builtAliasMap.put(aggFunction, alias); + } + } + this.aliasMap = builtAliasMap; + + this.aggFunctionsInputSlots = aggFunctions.stream() + .flatMap(aggFunction -> aggFunction.getInputSlots().stream()) + .filter(Slot.class::isInstance) + .collect(ImmutableSet.toImmutableSet()); + this.passThroughBigJoin = passThroughBigJoin; + this.hasDecomposedAggIf = hasDecomposedAggIf; + this.hasCaseWhen = hasCaseWhen; + } + + /** + * check validation + * @return true, if groupKeys is not empty and no group by key is in aggFunctionsInputSlots + */ + public boolean isValid() { + return !groupKeys.isEmpty() + && !groupKeys.stream().anyMatch(s -> aggFunctionsInputSlots.contains(s)); Review Comment: 使用 
noneMatch。另外,这里为什么要加这个限制?sum(a) group by a 看起来也没有问题? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSetOperation.java: ########## @@ -92,6 +97,38 @@ public LogicalSetOperation(PlanType planType, Qualifier qualifier, List<NamedExp this.qualifier = qualifier; this.outputs = ImmutableList.copyOf(outputs); this.regularChildrenOutputs = ImmutableList.copyOf(regularChildrenOutputs); + //if (SessionVariable.isFeDebug()) { + // checkOutputs(outputs, regularChildrenOutputs) + // .ifPresent(msg -> SessionVariable.throwAnalysisExceptionWhenFeDebug(msg)); + //} Review Comment: 多余的注释需要删掉 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new 
AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. 
+ if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } + if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If + && !(aggFunction instanceof Count)) { + // Decompose Sum/Max/Min(If(cond, a, b)) into separate agg functions. + // Count(If(...)) is NOT decomposed here because the top-level + // replacement (Count->Sum rollup) cannot match the decomposed + // Count(a)/Count(b) as sub-expressions of the original Count(If(cond,a,b)). + // Count(If(...)) is pushed down as-is and rolled up normally. + If body = (If) (aggFunction).child(0); + Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots()); + valueSlots.addAll(body.getFalseValue().getInputSlots()); + if (body.getCondition().getInputSlots().stream().anyMatch(s -> valueSlots.contains(s))) { + // do not push down sum(if a then a else b) + return agg; + } + AggregateFunction aggTrue = (AggregateFunction) aggFunction.withChildren(body.getTrueValue()); + aggFunctions.add(aggTrue); + funcs.add(aggTrue); + if (!(body.getFalseValue() instanceof NullLiteral)) { + AggregateFunction aggFalse = + (AggregateFunction) aggFunction.withChildren(body.getFalseValue()); + aggFunctions.add(aggFalse); + funcs.add(aggFalse); + } + groupKeys.addAll(body.getCondition().getInputSlots() + .stream().map(slot -> (SlotReference) slot).collect(Collectors.toList())); + hasDecomposedAggIf = true; + } else { + aggFunctions.add(aggFunction); + funcs.add(aggFunction); + } + + } else { + return agg; + } + } + } + + groupKeys = groupKeys.stream().distinct().collect(Collectors.toList()); + if (!checkSubTreePattern(agg.child())) { + return agg; + } + + PushDownAggContext pushDownContext = new PushDownAggContext(new ArrayList<>(aggFunctions), + groupKeys, null, context.getCascadesContext(), false, hasDecomposedAggIf, hasCaseWhen); + if (!pushDownContext.isValid()) { + return agg; + } + try { + Plan child = agg.child().accept(writer, pushDownContext); + if (child 
!= agg.child()) { + // agg has been pushed down, rewrite agg output expressions + // before: agg[sum(A), by (B)] + // ->join(C=D) + // ->scan(T1[A...]) + // ->scan(T2) + // after: agg[sum(x), by(B)] + // ->join(C=D) + // ->agg[sum(A) as x, by(B,C)] + // ->scan(T1[A...]) + // ->scan(T2) + List<NamedExpression> newOutputExpressions = new ArrayList<>(); + //Map<AggregateFunction, AggregateFunction> replaceMap = new HashMap<>(); + //for (AggregateFunction aggFunc : pushDownContext.getAliasMap().keySet()) { + // Alias alias = pushDownContext.getAliasMap().get(aggFunc); + // replaceMap.put(aggFunc, (AggregateFunction) aggFunc.withChildren((Expression) alias.toSlot())); + //} + + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof SlotReference) { + newOutputExpressions.add(ne); + } else { + // every expression has its own replaceMap + // aggregation(output=[min(A), sum(A)]) + // --> join + // -> T1 [A ...] + // -> T2 [...] + // => + // aggregation(output=[min(minA), sum(sumA)]) + // --> join + // -> agg(output=[min(A) as minA, sum(A) as sumA]) + // -> T1 [A ...] + // -> T2 [...] + // for min(A), replaceMap: A->minA + // for sum(A), replaceMap: A->sumA + // for count(A), replaceMap: count(A)->sum(countA), because count needs rollup to sum + Map<Expression, Expression> replaceMap = new HashMap<>(); + List<AggregateFunction> relatedAggFunc = aggFunctionsForOutputExpressions.get(ne); + for (AggregateFunction func : relatedAggFunc) { + Slot pushedDownSlot = pushDownContext.getAliasMap().get(func).toSlot(); + if (func instanceof Count) { + // For count(A), after pushdown we have count(A) as x, + // and the top agg should use sum(x) instead of count(x). + // Wrap with ifnull(..., 0) because COUNT never returns NULL, + // but after pushdown across an outer join, the intermediate count + // slot can be NULL (null-extended), making sum(NULL) = NULL. 
+ Function rollUpFunc = ((RollUpTrait) func).constructRollUp(pushedDownSlot); + replaceMap.put(func, new Nvl(rollUpFunc, new BigIntLiteral(0))); + } else if (func.arity() > 0) { + // For sum/max/min, replace the child expression with the pushed down slot + replaceMap.put(func.child(0), pushedDownSlot); + } + } + NamedExpression replaceAliasExpr = (NamedExpression) ExpressionUtils.replace(ne, replaceMap); + replaceAliasExpr = (NamedExpression) ExpressionUtils.rebuildSignature(replaceAliasExpr); + newOutputExpressions.add(replaceAliasExpr); + } + } + LogicalAggregate<Plan> eagerAgg = + agg.withAggOutputChild(newOutputExpressions, child); + NormalizeAggregate normalizeAggregate = new NormalizeAggregate(); + return normalizeAggregate.normalizeAgg(eagerAgg, Optional.empty(), + context.getCascadesContext()); + } + } catch (RuntimeException e) { + String msg = "PushDownAggregation failed: " + e.getMessage() + "\n" + agg.treeString(); + LOG.info(msg, e); + SessionVariable.throwAnalysisExceptionWhenFeDebug(msg); + } + return agg; + } + + private boolean checkSubTreePattern(Plan root) { + return containsPushDownJoin(root) + && checkPlanNodeType(root); + } + + private boolean containsPushDownJoin(Plan root) { + if (root instanceof LogicalJoin && !((LogicalJoin) root).isMarkJoin()) { + return true; + } Review Comment: 这里可以对mark join 和 asof join 等做短路 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import 
org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. 
+ if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } + if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If + && !(aggFunction instanceof Count)) { + // Decompose Sum/Max/Min(If(cond, a, b)) into separate agg functions. + // Count(If(...)) is NOT decomposed here because the top-level + // replacement (Count->Sum rollup) cannot match the decomposed + // Count(a)/Count(b) as sub-expressions of the original Count(If(cond,a,b)). + // Count(If(...)) is pushed down as-is and rolled up normally. + If body = (If) (aggFunction).child(0); + Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots()); + valueSlots.addAll(body.getFalseValue().getInputSlots()); + if (body.getCondition().getInputSlots().stream().anyMatch(s -> valueSlots.contains(s))) { + // do not push down sum(if a then a else b) + return agg; + } Review Comment: 这块整段没看懂,为什么 if(x,y,z) 要拆开,以及为什么 if(x, x, y) 不能work?需要更清晰的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java: ########## @@ -155,4 +155,8 @@ private Supplier<FunctionSignature> buildSignatureCache(Supplier<FunctionSignatu }); } } + + //public void rebuildSignature() { + // this.signatureCache = buildSignatureCache(null); + //} Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import 
org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. 
+ if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } + if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If + && !(aggFunction instanceof Count)) { + // Decompose Sum/Max/Min(If(cond, a, b)) into separate agg functions. + // Count(If(...)) is NOT decomposed here because the top-level + // replacement (Count->Sum rollup) cannot match the decomposed + // Count(a)/Count(b) as sub-expressions of the original Count(If(cond,a,b)). + // Count(If(...)) is pushed down as-is and rolled up normally. + If body = (If) (aggFunction).child(0); + Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots()); + valueSlots.addAll(body.getFalseValue().getInputSlots()); + if (body.getCondition().getInputSlots().stream().anyMatch(s -> valueSlots.contains(s))) { + // do not push down sum(if a then a else b) + return agg; + } + AggregateFunction aggTrue = (AggregateFunction) aggFunction.withChildren(body.getTrueValue()); + aggFunctions.add(aggTrue); + funcs.add(aggTrue); + if (!(body.getFalseValue() instanceof NullLiteral)) { + AggregateFunction aggFalse = + (AggregateFunction) aggFunction.withChildren(body.getFalseValue()); + aggFunctions.add(aggFalse); + funcs.add(aggFalse); + } + groupKeys.addAll(body.getCondition().getInputSlots() + .stream().map(slot -> (SlotReference) slot).collect(Collectors.toList())); + hasDecomposedAggIf = true; + } else { + aggFunctions.add(aggFunction); + funcs.add(aggFunction); + } + + } else { + return agg; + } + } + } + + groupKeys = groupKeys.stream().distinct().collect(Collectors.toList()); Review Comment: 这个应该移动到check 之后执行 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java: ########## @@ -221,28 +219,21 @@ public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) { } LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context); // start prune children of union - List<Slot> originOutput = 
union.getOutput(); - Set<Slot> prunedOutput = prunedOutputUnion.getOutputSet(); - List<Integer> prunedOutputIndexes = IntStream.range(0, originOutput.size()) - .filter(index -> prunedOutput.contains(originOutput.get(index))) - .boxed() - .collect(ImmutableList.toImmutableList()); - ImmutableList.Builder<Plan> prunedChildren = ImmutableList.builder(); ImmutableList.Builder<List<SlotReference>> prunedChildrenOutputs = ImmutableList.builder(); for (int i = 0; i < prunedOutputUnion.arity(); i++) { List<SlotReference> regularChildOutputs = prunedOutputUnion.getRegularChildOutput(i); RoaringBitmap prunedChildOutputExprIds = new RoaringBitmap(); - Builder<SlotReference> prunedChildOutputBuilder - = ImmutableList.builderWithExpectedSize(regularChildOutputs.size()); - for (Integer index : prunedOutputIndexes) { - SlotReference slot = regularChildOutputs.get(index); - prunedChildOutputBuilder.add(slot); - prunedChildOutputExprIds.add(slot.getExprId().asInt()); - } - - List<SlotReference> prunedChildOutput = prunedChildOutputBuilder.build(); + //Builder<SlotReference> prunedChildOutputBuilder + // = ImmutableList.builderWithExpectedSize(regularChildOutputs.size()); + //for (Integer index : prunedOutputIndexes) { + // SlotReference slot = regularChildOutputs.get(index); + // prunedChildOutputBuilder.add(slot); + // prunedChildOutputExprIds.add(slot.getExprId().asInt()); + //} + regularChildOutputs.forEach(col -> prunedChildOutputExprIds.add(col.getExprId().asInt())); + List<SlotReference> prunedChildOutput = regularChildOutputs; //prunedChildOutputBuilder.build(); Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java: ########## @@ -176,10 +176,13 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, O } //4. 
union all joins and put producers to context List<List<SlotReference>> childrenOutputs = joins.stream() - .map(j -> j.getOutput().stream() + .map(j -> j.getOutput().stream() //.map(j -> j.getOutput().stream().distinct() .map(SlotReference.class::cast) .collect(ImmutableList.toImmutableList())) .collect(ImmutableList.toImmutableList()); + //LogicalUnion union = new LogicalUnion(Qualifier.ALL, + // new ArrayList<>(join.getOutput().stream().distinct().collect(Collectors.toList())), + // childrenOutputs, ImmutableList.of(), false, joins); Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java: ########## @@ -319,8 +322,26 @@ private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>, LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), leftProducer.getCteId(), "", leftProducer); + List<NamedExpression> leftOutput = new ArrayList<>(); Review Comment: 这里是因为 or expansion 有bug? 
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java: ########## @@ -221,28 +219,21 @@ public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) { } LogicalUnion prunedOutputUnion = pruneUnionOutput(union, context); // start prune children of union - List<Slot> originOutput = union.getOutput(); - Set<Slot> prunedOutput = prunedOutputUnion.getOutputSet(); - List<Integer> prunedOutputIndexes = IntStream.range(0, originOutput.size()) - .filter(index -> prunedOutput.contains(originOutput.get(index))) - .boxed() - .collect(ImmutableList.toImmutableList()); - ImmutableList.Builder<Plan> prunedChildren = ImmutableList.builder(); ImmutableList.Builder<List<SlotReference>> prunedChildrenOutputs = ImmutableList.builder(); for (int i = 0; i < prunedOutputUnion.arity(); i++) { List<SlotReference> regularChildOutputs = prunedOutputUnion.getRegularChildOutput(i); RoaringBitmap prunedChildOutputExprIds = new RoaringBitmap(); - Builder<SlotReference> prunedChildOutputBuilder - = ImmutableList.builderWithExpectedSize(regularChildOutputs.size()); - for (Integer index : prunedOutputIndexes) { - SlotReference slot = regularChildOutputs.get(index); - prunedChildOutputBuilder.add(slot); - prunedChildOutputExprIds.add(slot.getExprId().asInt()); - } - - List<SlotReference> prunedChildOutput = prunedChildOutputBuilder.build(); + //Builder<SlotReference> prunedChildOutputBuilder + // = ImmutableList.builderWithExpectedSize(regularChildOutputs.size()); + //for (Integer index : prunedOutputIndexes) { + // SlotReference slot = regularChildOutputs.get(index); + // prunedChildOutputBuilder.add(slot); + // prunedChildOutputExprIds.add(slot.getExprId().asInt()); + //} Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import 
org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. 
+ if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } + if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If + && !(aggFunction instanceof Count)) { + // Decompose Sum/Max/Min(If(cond, a, b)) into separate agg functions. + // Count(If(...)) is NOT decomposed here because the top-level + // replacement (Count->Sum rollup) cannot match the decomposed + // Count(a)/Count(b) as sub-expressions of the original Count(If(cond,a,b)). + // Count(If(...)) is pushed down as-is and rolled up normally. + If body = (If) (aggFunction).child(0); + Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots()); + valueSlots.addAll(body.getFalseValue().getInputSlots()); + if (body.getCondition().getInputSlots().stream().anyMatch(s -> valueSlots.contains(s))) { + // do not push down sum(if a then a else b) + return agg; + } + AggregateFunction aggTrue = (AggregateFunction) aggFunction.withChildren(body.getTrueValue()); + aggFunctions.add(aggTrue); + funcs.add(aggTrue); + if (!(body.getFalseValue() instanceof NullLiteral)) { + AggregateFunction aggFalse = + (AggregateFunction) aggFunction.withChildren(body.getFalseValue()); + aggFunctions.add(aggFalse); + funcs.add(aggFalse); + } + groupKeys.addAll(body.getCondition().getInputSlots() + .stream().map(slot -> (SlotReference) slot).collect(Collectors.toList())); + hasDecomposedAggIf = true; + } else { + aggFunctions.add(aggFunction); + funcs.add(aggFunction); + } + + } else { + return agg; + } + } + } + + groupKeys = groupKeys.stream().distinct().collect(Collectors.toList()); + if (!checkSubTreePattern(agg.child())) { + return agg; + } + + PushDownAggContext pushDownContext = new PushDownAggContext(new ArrayList<>(aggFunctions), + groupKeys, null, context.getCascadesContext(), false, hasDecomposedAggIf, hasCaseWhen); + if (!pushDownContext.isValid()) { + return agg; + } + try { + Plan child = agg.child().accept(writer, pushDownContext); + if (child 
!= agg.child()) { + // agg has been pushed down, rewrite agg output expressions + // before: agg[sum(A), by (B)] + // ->join(C=D) + // ->scan(T1[A...]) + // ->scan(T2) + // after: agg[sum(x), by(B)] + // ->join(C=D) + // ->agg[sum(A) as x, by(B,C)] + // ->scan(T1[A...]) + // ->scan(T2) + List<NamedExpression> newOutputExpressions = new ArrayList<>(); + //Map<AggregateFunction, AggregateFunction> replaceMap = new HashMap<>(); + //for (AggregateFunction aggFunc : pushDownContext.getAliasMap().keySet()) { + // Alias alias = pushDownContext.getAliasMap().get(aggFunc); + // replaceMap.put(aggFunc, (AggregateFunction) aggFunc.withChildren((Expression) alias.toSlot())); + //} Review Comment: 多余的注释 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java: ########## @@ -0,0 +1,625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.StatsDerive; +import org.apache.doris.nereids.stats.ExpressionEstimation; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.qe.SessionVariable; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; + +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import 
java.util.stream.Collectors; + +/** + * eager aggregation + * agg[sum(t1.A) group by t1.B] + * ->join(t1.C=t2.D) + * ->T1(A, B, C) + * ->T2(D) + * + * => + * agg[sum(x) group by t1.B] + * ->join(t1.C=t2.D) + * ->agg[sum(A) as x, group by B] + * ->T1(A, B, C) + * ->T2(D) + */ +public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> { + private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; + private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; + private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; + private final StatsDerive derive = new StatsDerive(false); + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, PushDownAggContext context) { + boolean toLeft = false; + boolean toRight = false; + boolean pushHere = false; + if (context.getAggFunctions().isEmpty()) { + // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v + // if no agg function, try to push agg to the child which contains all group keys + // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, t2.v).ndv to determine push target + if (join.left().getOutputSet().containsAll(context.getGroupKeys())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(context.getGroupKeys())) { + toRight = true; + } else { + pushHere = true; + } + } else { + for (AggregateFunction aggFunc : context.getAggFunctions()) { + if (join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toLeft = true; + } else if (join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) { + toRight = true; + } else { + pushHere = true; + } + } + } + + if (pushHere || (toLeft && toRight)) { + if (SessionVariable.isEagerAggregationOnJoin()) { + return genAggregate(join, context); + } else { + return join; + } + } + // Do not push aggregation to the nullable side of outer joins when agg function contains case-when. 
+ // CaseWhen expressions may produce non-null values from null-padded rows (e.g., WHEN col IS NULL THEN -54), + // so pre-aggregation before the join loses those contributions. + if (context.hasDecomposedAggIf || context.hasCaseWhen) { + JoinType joinType = join.getJoinType(); + if (joinType.isFullOuterJoin()) { + return join; + } + if (joinType.isRightOuterJoin()) { + toLeft = false; + } + if (joinType.isLeftOuterJoin()) { + toRight = false; + } + if (!toLeft && !toRight && !pushHere) { + return join; + } + } + + // Do not push count(*)/count(literal)/count(preserved_side_col) to the nullable side of outer joins. Review Comment: 为什么 count(preserved_side_col) 能推给另一边? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; +import org.apache.doris.nereids.rules.rewrite.AdjustNullable; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import 
org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * push down aggregation + */ +public class PushDownAggregation extends DefaultPlanRewriter<JobContext> implements CustomRewriter { + private static final Logger LOG = LoggerFactory.getLogger(PushDownAggregation.class); + + public final EagerAggRewriter writer = new EagerAggRewriter(); + + private final Set<Class> pushDownAggFunctionSet = Sets.newHashSet( + Count.class, + Sum.class, + Max.class, + Min.class); + + private final Set<Class> acceptNodeType = Sets.newHashSet( + LogicalUnion.class, + LogicalProject.class, + LogicalFilter.class, + LogicalRelation.class, + LogicalJoin.class); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + if (SessionVariable.isFeDebug()) { + try { + new AdjustNullable(false).rewriteRoot(plan, null); + } catch (Exception e) { + LOG.warn("(PushDownAggregation) input plan has nullable problem", e); + return plan; + } + } + int mode = 
SessionVariable.getEagerAggregationMode(); + if (mode < 0) { + return plan; + } else { + Plan result = plan.accept(this, jobContext); + if (SessionVariable.isFeDebug()) { + result = new AdjustNullable(true).rewriteRoot(result, null); + } + return result; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, JobContext context) { + Plan newChild = agg.child().accept(this, context); + if (newChild != agg.child()) { + return agg.withChildren(newChild); + } + + if (agg.getSourceRepeat().isPresent()) { + return agg; + } + + List<SlotReference> groupKeys = new ArrayList<>(); + for (Expression groupKey : agg.getGroupByExpressions()) { + if (groupKey instanceof SlotReference) { + groupKeys.add((SlotReference) groupKey); + } else { + SessionVariable.throwAnalysisExceptionWhenFeDebug( + "PushDownAggregation failed: agg is not normalized\n " + + agg.treeString()); + return agg; + } + } + + Set<AggregateFunction> aggFunctions = Sets.newHashSet(); + boolean hasDecomposedAggIf = false; + boolean hasCaseWhen = false; + Map<NamedExpression, List<AggregateFunction>> aggFunctionsForOutputExpressions = Maps.newHashMap(); + for (NamedExpression aggOutput : agg.getOutputExpressions()) { + List<AggregateFunction> funcs = Lists.newArrayList(); + aggFunctionsForOutputExpressions.put(aggOutput, funcs); + for (Object obj : aggOutput.collect(AggregateFunction.class::isInstance)) { + AggregateFunction aggFunction = (AggregateFunction) obj; + if (aggFunction.isDistinct()) { + return agg; + } + if (pushDownAggFunctionSet.contains(aggFunction.getClass())) { + // CaseWhen and If (which CASE WHEN is normalized into) must both be checked. + // When an agg function contains an If/CaseWhen whose condition tests IS NULL + // (e.g. 
count(if(col IS NULL, value, NULL))), pushing it to the nullable side + // of an outer join produces wrong results: null-extended rows make "col IS NULL" + // TRUE at the top level, but the pre-aggregated count slot becomes NULL after + // null-extension, and ifnull(sum(NULL), 0) = 0 instead of the correct 1. + if (!hasCaseWhen && aggFunction.anyMatch(e -> e instanceof CaseWhen || e instanceof If)) { + hasCaseWhen = true; + } + if (aggFunction.arity() > 0 && aggFunction.child(0) instanceof If + && !(aggFunction instanceof Count)) { + // Decompose Sum/Max/Min(If(cond, a, b)) into separate agg functions. + // Count(If(...)) is NOT decomposed here because the top-level + // replacement (Count->Sum rollup) cannot match the decomposed + // Count(a)/Count(b) as sub-expressions of the original Count(If(cond,a,b)). + // Count(If(...)) is pushed down as-is and rolled up normally. + If body = (If) (aggFunction).child(0); + Set<Slot> valueSlots = Sets.newHashSet(body.getTrueValue().getInputSlots()); + valueSlots.addAll(body.getFalseValue().getInputSlots()); + if (body.getCondition().getInputSlots().stream().anyMatch(s -> valueSlots.contains(s))) { + // do not push down sum(if a then a else b) + return agg; + } + AggregateFunction aggTrue = (AggregateFunction) aggFunction.withChildren(body.getTrueValue()); + aggFunctions.add(aggTrue); + funcs.add(aggTrue); + if (!(body.getFalseValue() instanceof NullLiteral)) { + AggregateFunction aggFalse = + (AggregateFunction) aggFunction.withChildren(body.getFalseValue()); + aggFunctions.add(aggFalse); + funcs.add(aggFalse); + } + groupKeys.addAll(body.getCondition().getInputSlots() + .stream().map(slot -> (SlotReference) slot).collect(Collectors.toList())); + hasDecomposedAggIf = true; + } else { + aggFunctions.add(aggFunction); + funcs.add(aggFunction); + } + + } else { + return agg; + } + } + } + + groupKeys = groupKeys.stream().distinct().collect(Collectors.toList()); + if (!checkSubTreePattern(agg.child())) { + return agg; + } + + 
PushDownAggContext pushDownContext = new PushDownAggContext(new ArrayList<>(aggFunctions), + groupKeys, null, context.getCascadesContext(), false, hasDecomposedAggIf, hasCaseWhen); + if (!pushDownContext.isValid()) { + return agg; + } + try { + Plan child = agg.child().accept(writer, pushDownContext); + if (child != agg.child()) { + // agg has been pushed down, rewrite agg output expressions + // before: agg[sum(A), by (B)] + // ->join(C=D) + // ->scan(T1[A...]) + // ->scan(T2) + // after: agg[sum(x), by(B)] + // ->join(C=D) + // ->agg[sum(A) as x, by(B,C)] + // ->scan(T1[A...]) + // ->scan(T2) + List<NamedExpression> newOutputExpressions = new ArrayList<>(); + //Map<AggregateFunction, AggregateFunction> replaceMap = new HashMap<>(); + //for (AggregateFunction aggFunc : pushDownContext.getAliasMap().keySet()) { + // Alias alias = pushDownContext.getAliasMap().get(aggFunc); + // replaceMap.put(aggFunc, (AggregateFunction) aggFunc.withChildren((Expression) alias.toSlot())); + //} + + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof SlotReference) { + newOutputExpressions.add(ne); + } else { + // every expression has its own replaceMap + // aggregation(output=[min(A), sum(A)]) + // --> join + // -> T1 [A ...] + // -> T2 [...] + // => + // aggregation(output=[min(minA), sum(sumA)]) + // --> join + // -> agg(output=[min(A) as minA, sum(A) as sumA]) + // -> T1 [A ...] + // -> T2 [...] 
+ // for min(A), replaceMap: A->minA + // for sum(A), replaceMap: A->sumA + // for count(A), replaceMap: count(A)->sum(countA), because count needs rollup to sum + Map<Expression, Expression> replaceMap = new HashMap<>(); + List<AggregateFunction> relatedAggFunc = aggFunctionsForOutputExpressions.get(ne); + for (AggregateFunction func : relatedAggFunc) { + Slot pushedDownSlot = pushDownContext.getAliasMap().get(func).toSlot(); + if (func instanceof Count) { + // For count(A), after pushdown we have count(A) as x, + // and the top agg should use sum(x) instead of count(x). + // Wrap with ifnull(..., 0) because COUNT never returns NULL, + // but after pushdown across an outer join, the intermediate count + // slot can be NULL (null-extended), making sum(NULL) = NULL. + Function rollUpFunc = ((RollUpTrait) func).constructRollUp(pushedDownSlot); + replaceMap.put(func, new Nvl(rollUpFunc, new BigIntLiteral(0))); + } else if (func.arity() > 0) { + // For sum/max/min, replace the child expression with the pushed down slot + replaceMap.put(func.child(0), pushedDownSlot); + } + } + NamedExpression replaceAliasExpr = (NamedExpression) ExpressionUtils.replace(ne, replaceMap); + replaceAliasExpr = (NamedExpression) ExpressionUtils.rebuildSignature(replaceAliasExpr); + newOutputExpressions.add(replaceAliasExpr); + } + } + LogicalAggregate<Plan> eagerAgg = + agg.withAggOutputChild(newOutputExpressions, child); + NormalizeAggregate normalizeAggregate = new NormalizeAggregate(); + return normalizeAggregate.normalizeAgg(eagerAgg, Optional.empty(), + context.getCascadesContext()); Review Comment: 为什么需要重新normalize? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
