github-actions[bot] commented on code in PR #63690:
URL: https://github.com/apache/doris/pull/63690#discussion_r3316996970


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -77,70 +91,154 @@ public class EagerAggRewriter extends 
DefaultPlanRewriter<PushDownAggContext> {
 
     @Override
     public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, PushDownAggContext context) {
-        boolean toLeft = false;
-        boolean toRight = false;
-        boolean pushHere = false;
-        if (join.getJoinType().isAsofJoin()) {
-            // do nothing for asof join
-            return join;
+        Pair<Boolean, Boolean> pushSide = decideJoinPushSide(join, context);
+        boolean toLeft = pushSide.first;
+        boolean toRight = pushSide.second;
+        if (!toLeft && !toRight) {
+            if (SessionVariable.isEagerAggregationOnJoin()) {
+                return genAggregate(join, context);
+            } else {
+                return join;
+            }
+        }
+
+        // construct left and right group by keys
+        List<SlotReference> leftChildGroupByKeys = new ArrayList<>();
+        List<SlotReference> rightChildGroupByKeys = new ArrayList<>();
+        if (toLeft) {
+            fillGroupByKeys(join, join.left(), context, leftChildGroupByKeys);
+        }
+        if (toRight) {
+            fillGroupByKeys(join, join.right(), context, 
rightChildGroupByKeys);
         }
-        if (context.getAggFunctions().isEmpty()) {
-            // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v
-            // if no agg function, try to push agg to the child which contains 
all group keys
-            // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, 
t2.v).ndv to determine push target
-            if 
(join.left().getOutputSet().containsAll(context.getGroupKeys())) {
-                toLeft = true;
-            } else if 
(join.right().getOutputSet().containsAll(context.getGroupKeys())) {
-                toRight = true;
+        // construct left and right aggFuncs and aliasMap
+        List<AggregateFunction> leftFuncs = new ArrayList<>();
+        List<AggregateFunction> rightFuncs = new ArrayList<>();
+        Map<AggregateFunction, Alias> leftAliasMap = new IdentityHashMap<>();
+        Map<AggregateFunction, Alias> rightAliasMap = new IdentityHashMap<>();
+        for (AggregateFunction f : context.getAggFunctions()) {
+            Set<Slot> inputs = f.getInputSlots();
+            Alias a = context.getAliasMap().get(f);
+            if (join.left().getOutputSet().containsAll(inputs)) {
+                leftFuncs.add(f);
+                leftAliasMap.put(f, a);
+            } else if (join.right().getOutputSet().containsAll(inputs)) {
+                rightFuncs.add(f);
+                rightAliasMap.put(f, a);
             } else {
-                pushHere = true;
+                return join;
             }
+        }
+
+        boolean passThroughBigJoin = isPassThroughBigJoin(join, context);
+        Optional<PushDownAggContext> leftChildContext = toLeft ? 
Optional.of(context.forBilateralBranch(leftFuncs,
+                leftAliasMap, leftChildGroupByKeys, passThroughBigJoin)) : 
Optional.empty();
+        Optional<PushDownAggContext> rightChildContext = toRight ? 
Optional.of(context.forBilateralBranch(rightFuncs,
+                rightAliasMap, rightChildGroupByKeys, passThroughBigJoin)) : 
Optional.empty();
+
+        Plan newLeft = join.left();
+        Plan newRight = join.right();
+        if (leftChildContext.isPresent()) {
+            newLeft = join.left().accept(this, leftChildContext.get());
+        }
+        if (rightChildContext.isPresent()) {
+            newRight = join.right().accept(this, rightChildContext.get());
+        }
+
+        if (newLeft == join.left() && newRight == join.right()) {
+            context.getBilateralState().registerNoCountSlot(join);
+            return join;
+        }
+        Optional<Slot> leftChildCountSlot = 
context.getBilateralState().getCountSlot(newLeft);
+        Optional<Slot> rightChildCountSlot = 
context.getBilateralState().getCountSlot(newRight);
+        LogicalJoin<? extends Plan, ? extends Plan> newJoin = (LogicalJoin<? 
extends Plan, ? extends Plan>)
+                join.withChildren(newLeft, newRight);
+
+        if (leftChildCountSlot.isPresent() || rightChildCountSlot.isPresent()) 
{
+            return buildCanonicalJoinProject(newJoin, context, 
leftChildContext, rightChildContext,
+                    leftChildCountSlot, rightChildCountSlot);
+        }
+        context.getBilateralState().registerNoCountSlot(newJoin);
+        return newJoin;
+    }
+
+    private Pair<Boolean, Boolean> decideJoinPushSide(
+            LogicalJoin<? extends Plan, ? extends Plan> join, 
PushDownAggContext context) {
+        if (join.getJoinType().isAsofJoin()) {
+            // do nothing for asof join
+            return Pair.of(false, false);
+        }
+
+        boolean deduplicateOnly = context.getAggFunctions().isEmpty();
+        boolean toLeft = false;
+        boolean toRight = false;
+        if (deduplicateOnly) {
+            toLeft = true;
+            toRight = true;
         } else {
             for (AggregateFunction aggFunc : context.getAggFunctions()) {
                 if 
(join.left().getOutputSet().containsAll(aggFunc.getInputSlots())) {
                     toLeft = true;
                 } else if 
(join.right().getOutputSet().containsAll(aggFunc.getInputSlots())) {
                     toRight = true;
                 } else {
-                    pushHere = true;
+                    toLeft = false;
+                    toRight = false;
                 }
             }
         }
-
-        if (pushHere || (toLeft && toRight)) {
-            if (SessionVariable.isEagerAggregationOnJoin()) {
-                return genAggregate(join, context);
-            } else {
-                return join;
-            }
+        if (!toLeft && !toRight) {
+            return Pair.of(false, false);
+        }
+        if (deduplicateOnly) {
+            return adjustPushSideForCaseWhen(join, context, toLeft, toRight);
         }
+        if (toLeft && toRight) {
+            return join.getJoinType().isInnerOrCrossJoin()
+                    ? Pair.of(true, true)

Review Comment:
   This permits bilateral pushdown for `CROSS_JOIN`, but the project 
reconstruction below only restores multiplicity when `joinType.isInnerJoin()` 
is true. For a query like `SELECT t1.k, sum(t1.v), sum(t2.v) FROM t1 CROSS JOIN 
t2 GROUP BY t1.k` with mode 1, both sides can be pre-aggregated here; then 
`shouldUseJoinOppositeCntAdjustAggOutput()` returns false and 
`computeJoinCount()` returns empty for `CROSS_JOIN`, so the top aggregate sees 
each pre-aggregated value once instead of multiplying `sum(t1.v)` by 
`count(t2)` and `sum(t2.v)` by the left-side count. Please either exclude 
`CROSS_JOIN` from this bilateral path or handle it with the same 
multiplier/count semantics as inner joins.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to