feiniaofeiafei commented on code in PR #63690:
URL: https://github.com/apache/doris/pull/63690#discussion_r3449660858
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -70,77 +84,200 @@
* ->T2(D)
*/
public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> {
+ public static final int BIG_JOIN_BUILD_SIZE = 400_000;
private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
+ private static final String JOIN_CNT = "joinCnt";
private final StatsDerive derive = new StatsDerive(false);
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
join, PushDownAggContext context) {
- boolean toLeft = false;
- boolean toRight = false;
- boolean pushHere = false;
- if (join.getJoinType().isAsofJoin()) {
- // do nothing for asof join
- return join;
+ Pair<Boolean, Boolean> pushSide = decideJoinPushSide(join, context);
+ boolean toLeft = pushSide.first;
+ boolean toRight = pushSide.second;
+ if (!toLeft && !toRight) {
+ if (SessionVariable.isEagerAggregationOnJoin()) {
+ return genAggregate(join, context);
+ } else {
+ return join;
+ }
}
- if (context.getAggFunctions().isEmpty()) {
- // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v
- // if no agg function, try to push agg to the child which contains all group keys
- // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, t2.v).ndv to determine push target
- if
(join.left().getOutputSet().containsAll(context.getGroupKeys())) {
- toLeft = true;
- } else if
(join.right().getOutputSet().containsAll(context.getGroupKeys())) {
- toRight = true;
+
+ // construct left and right group by keys
+ List<SlotReference> leftChildGroupByKeys = new ArrayList<>();
+ List<SlotReference> rightChildGroupByKeys = new ArrayList<>();
+ if (toLeft) {
+ fillGroupByKeys(join, join.left(), context, leftChildGroupByKeys);
+ }
+ if (toRight) {
+ fillGroupByKeys(join, join.right(), context,
rightChildGroupByKeys);
+ }
+ // construct left and right aggFuncs and aliasMap
+ List<AggregateFunction> leftFuncs = new ArrayList<>();
+ List<AggregateFunction> rightFuncs = new ArrayList<>();
+ Map<AggregateFunction, Alias> leftAliasMap = new IdentityHashMap<>();
+ Map<AggregateFunction, Alias> rightAliasMap = new IdentityHashMap<>();
+ for (AggregateFunction f : context.getAggFunctions()) {
+ Set<Slot> inputs = f.getInputSlots();
+ Alias a = context.getAliasMap().get(f);
+ if (inputs.isEmpty()) {
+ if (join.getJoinType().isRightSemiOrAntiJoin()) {
+ rightFuncs.add(f);
+ rightAliasMap.put(f, a);
+ } else {
+ leftFuncs.add(f);
+ leftAliasMap.put(f, a);
+ }
+ continue;
+ }
+ if (join.left().getOutputSet().containsAll(inputs)) {
+ leftFuncs.add(f);
+ leftAliasMap.put(f, a);
+ } else if (join.right().getOutputSet().containsAll(inputs)) {
+ rightFuncs.add(f);
+ rightAliasMap.put(f, a);
} else {
- pushHere = true;
+ return join;
}
+ }
+
+ boolean passThroughBigJoin = isPassThroughBigJoin(join, context);
+ boolean leftNeedOutputCount = needOutputCountForJoinChild(join,
toLeft, toRight,
+ context.needOutputCount(), rightFuncs);
+ boolean rightNeedOutputCount = needOutputCountForJoinChild(join,
toRight, toLeft,
+ context.needOutputCount(), leftFuncs);
+ Optional<PushDownAggContext> leftChildContext = toLeft ?
Optional.of(context.forOneBranch(leftFuncs,
+ leftAliasMap, leftChildGroupByKeys, passThroughBigJoin,
leftNeedOutputCount)) : Optional.empty();
+ Optional<PushDownAggContext> rightChildContext = toRight ?
Optional.of(context.forOneBranch(rightFuncs,
+ rightAliasMap, rightChildGroupByKeys, passThroughBigJoin,
rightNeedOutputCount)) : Optional.empty();
+
+ Plan newLeft = join.left();
+ Plan newRight = join.right();
+ if (leftChildContext.isPresent() &&
!leftChildContext.get().noGroupKeyAndNoAggFunc()) {
+ newLeft = join.left().accept(this, leftChildContext.get());
+ }
+ if (rightChildContext.isPresent() &&
!rightChildContext.get().noGroupKeyAndNoAggFunc()) {
+ newRight = join.right().accept(this, rightChildContext.get());
+ }
+
+ if (newLeft == join.left() && newRight == join.right()) {
+ context.getBilateralState().registerNoCountSlot(join);
+ return join;
+ }
+ Optional<Slot> leftChildCountSlot =
context.getBilateralState().getCountSlot(newLeft);
+ Optional<Slot> rightChildCountSlot =
context.getBilateralState().getCountSlot(newRight);
+ LogicalJoin<? extends Plan, ? extends Plan> newJoin = (LogicalJoin<?
extends Plan, ? extends Plan>)
+ join.withChildren(newLeft, newRight);
+
+ if (leftChildCountSlot.isPresent() || rightChildCountSlot.isPresent())
{
+ return buildCanonicalJoinProject(newJoin, context,
leftChildContext, rightChildContext,
+ leftChildCountSlot, rightChildCountSlot);
+ }
+ context.getBilateralState().registerNoCountSlot(newJoin);
+ return newJoin;
+ }
+
+ private Pair<Boolean, Boolean> decideJoinPushSide(
+ LogicalJoin<? extends Plan, ? extends Plan> join,
PushDownAggContext context) {
+ if (join.getJoinType().isAsofJoin() || join.isMarkJoin()) {
+ // do nothing for asof join and mark join
+ return Pair.of(false, false);
+ }
+
+ boolean deduplicateOnly = context.getAggFunctions().isEmpty();
+ boolean toLeft = false;
+ boolean toRight = false;
+ if (deduplicateOnly) {
+ toLeft = true;
+ toRight = true;
} else {
for (AggregateFunction aggFunc : context.getAggFunctions()) {
- if
(join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) {
+ Set<Slot> inputs = aggFunc.getInputSlots();
+ if (inputs.isEmpty()) {
+ if (join.getJoinType().isRightSemiOrAntiJoin()) {
+ toRight = true;
+ } else {
+ toLeft = true;
+ }
+ continue;
+ }
+ if (join.left().getOutputSet().containsAll(inputs)) {
toLeft = true;
- } else if
(join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) {
+ } else if (join.right().getOutputSet().containsAll(inputs)) {
toRight = true;
} else {
- pushHere = true;
+ toLeft = false;
+ toRight = false;
+ break;
}
}
}
+ if (!toLeft && !toRight) {
+ return Pair.of(false, false);
+ }
+ if (deduplicateOnly) {
+ return adjustPushSideForCaseWhen(join, context, toLeft, toRight);
+ }
+ if (toLeft && toRight) {
+ return join.getJoinType().isInnerOrCrossJoin()
+ ? Pair.of(true, true)
Review Comment:
fixed
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]