morrySnow commented on code in PR #63690: URL: https://github.com/apache/doris/pull/63690#discussion_r3386213052
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggHints.java: ########## @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.eageraggregation; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; + +import com.google.common.collect.ImmutableMap; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Branch-scoped eager aggregation push-down hints parsed from the session variable + * {@code force_eager_agg_hint}. + * + * <p>Format: semicolon-separated list of {@code <key>=<action>} entries, where + * {@code key = <funcName>:<argSig>} and {@code action ∈ {push, nopush}}. 
+ * + * <p>The key is matched per aggregate-function occurrence, but the effect is applied at the + * current candidate push-down branch/subtree instead of independently per function: + * if any matched aggregate in the branch is marked {@code nopush}, push-down is disabled for + * that branch; otherwise, if any matched aggregate in the branch is marked {@code push}, + * push-down may be forced for that branch. Other aggregates in the same branch follow that + * branch-level decision. + * + * <p>{@code argSig} rules: + * <ul> + * <li>{@code count(*)} → {@code "*"}</li> + * <li>single-arg agg over a SlotReference → {@code "<last-qualifier-segment>.<column>"} + * or {@code "<column>"} when the slot has no qualifier</li> + * <li>otherwise → {@code Expression#toSql()} lower-cased</li> + * </ul> + * + * <p>Examples: + * <pre> + * set force_eager_agg_hint = 'sum:t1.a=push; sum:t2.a=nopush; count:*=push'; + * </pre> + * + * <p>This feature is intended for tests/debugging of the eager-aggregation rewrite only; + * when unset, all decisions fall back to {@code eager_aggregation_mode} + statistics. + */ +public final class EagerAggHints { + + /** Matched hint action for a specific aggregate-function occurrence. */ + public enum Action { + PUSH, + NOPUSH + } + + private EagerAggHints() { + } + + /** + * Returns the matched hint action for the given aggregate function based on the current + * session's {@code force_eager_agg_hint}, or empty if no matching entry is configured. + */ + public static Optional<Action> decide(AggregateFunction aggFunction) { + Map<String, Action> hints = currentHints(); + if (hints.isEmpty()) { + return Optional.empty(); + } + Action action = hints.get(keyOf(aggFunction)); + return Optional.ofNullable(action); + } + + /** Builds the canonical hint key for the given aggregate function. 
*/ + public static String keyOf(AggregateFunction aggFunction) { + String fn = aggFunction.getName().toLowerCase(); + if (aggFunction instanceof Count && ((Count) aggFunction).isStar()) { + return fn + ":*"; + } + if (aggFunction.arity() == 1) { + Expression arg = aggFunction.child(0); + if (arg instanceof SlotReference) { + SlotReference slot = (SlotReference) arg; + List<String> qualifier = slot.getQualifier(); + String prefix = qualifier.isEmpty() + ? "" + : qualifier.get(qualifier.size() - 1).toLowerCase() + "."; + return fn + ":" + prefix + slot.getName().toLowerCase(); + } + } + return fn + ":" + aggFunction.child(0).toSql().toLowerCase(); + } + + private static Map<String, Action> currentHints() { + ConnectContext ctx = ConnectContext.get(); + if (ctx == null) { + return ImmutableMap.of(); + } + SessionVariable sv = ctx.getSessionVariable(); + String raw = sv.forceEagerAggHint; + if (raw == null || raw.isEmpty()) { + return ImmutableMap.of(); + } + return parse(raw); + } + + /** Parse a raw hint string into a map; malformed entries are silently ignored. */ + public static Map<String, Action> parse(String raw) { Review Comment: could we parse and check it when set session variable? so we could get a Map from `sv.forceEagerAggHint` directly ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java: ########## @@ -249,6 +250,11 @@ public static Expression rebuildSignature(Expression expr) { BoundFunction rebuilt = (BoundFunction) fn.withChildren(newChildren); rebuilt = (BoundFunction) TypeCoercionUtils.processBoundFunction(rebuilt); return rebuilt; + } else if (expr instanceof BinaryArithmetic) { + BinaryArithmetic binaryArithmetic = (BinaryArithmetic) expr; + BinaryArithmetic rebuilt = (BinaryArithmetic) binaryArithmetic.withChildren(newChildren); + rebuilt = (BinaryArithmetic) TypeCoercionUtils.processBinaryArithmetic(rebuilt); + return rebuilt; Review Comment: this is a bug? 
if it is, please commit a separate PR to fix it ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggContext.java: ########## @@ -87,20 +96,22 @@ public PushDownAggContext(List<AggregateFunction> aggFunctions, this.passThroughBigJoin = passThroughBigJoin; this.hasDecomposedAggIf = hasDecomposedAggIf; this.hasCaseWhen = hasCaseWhen; + this.bilateralState = Objects.requireNonNull(bilateralState); Review Comment: add an error message ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggContext.java: ########## @@ -117,7 +128,22 @@ public List<SlotReference> getGroupKeys() { public PushDownAggContext withGroupKeys(List<SlotReference> groupKeys) { return new PushDownAggContext(aggFunctions, groupKeys, aliasMap, - cascadesContext, passThroughBigJoin, hasDecomposedAggIf, hasCaseWhen); + cascadesContext, passThroughBigJoin, hasDecomposedAggIf, hasCaseWhen, + bilateralState); + } + + /** + * Derive a child context for one branch of a join during bilateral push-down. + */ + public PushDownAggContext forBilateralBranch(List<AggregateFunction> branchAggFunctions, Review Comment: How should the prefix `branch` in the parameter `branchAggFunctions` be understood? The context is used for bilateral push-down, so which side does `branch` represent? If the context is used for only one branch, maybe renaming the method to `forOneBranch` would be better? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/PushDownAggregation.java: ########## @@ -255,18 +256,18 @@ public Plan visitLogicalAggregate(LogicalAggregate<? 
extends Plan> agg, JobConte Map<Expression, Expression> replaceMap = new HashMap<>(); List<AggregateFunction> relatedAggFunc = aggFunctionsForOutputExpressions.get(ne); for (AggregateFunction func : relatedAggFunc) { - Slot pushedDownSlot = pushDownContext.getAliasMap().get(func).toSlot(); + Alias pushedAlias = pushDownContext.getAliasMap().get(func); + ExprId pushId = pushedAlias.getExprId(); + if (!state.hasAggFuncOutput(pushId)) { + continue; + } + Expression value = state.getPushedAggFuncSlot(pushId); if (func instanceof Count) { - // For count(A), after pushdown we have count(A) as x, - // and the top agg should use sum(x) instead of count(x). - // Wrap with ifnull(..., 0) because COUNT never returns NULL, - // but after pushdown across an outer join, the intermediate count - // slot can be NULL (null-extended), making sum(NULL) = NULL. - Function rollUpFunc = ((RollUpTrait) func).constructRollUp(pushedDownSlot); - replaceMap.put(func, new Nvl(rollUpFunc, new BigIntLiteral(0))); + replaceMap.put(func, new Sum0(value)); + } else if (func instanceof Max || func instanceof Min) { + replaceMap.put(func.child(0), value); } else if (func.arity() > 0) { - // For sum/max/min, replace the child expression with the pushed down slot - replaceMap.put(func.child(0), pushedDownSlot); + replaceMap.put(func.child(0), value); Review Comment: 两个分支一模一样 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
