This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch groupjoin
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/groupjoin by this push:
     new 26f40d4694a fe-part (#64763)
26f40d4694a is described below

commit 26f40d4694a098ca09c54ca9ccf31afd94b94525
Author: minghong <[email protected]>
AuthorDate: Wed Jun 24 13:02:55 2026 +0800

    fe-part (#64763)
    
    ### What problem does this PR solve?
    
    Issue Number: close #xxx
    
    Related PR: #xxx
    
    Problem Summary:
    
    ### Release note
    
    None
    
    ### Check List (For Author)
    
    - Test <!-- At least one of them must be included. -->
        - [ ] Regression test
        - [ ] Unit Test
        - [ ] Manual test (add detailed scripts or steps below)
        - [ ] No need to test or manual test. Explain why:
    - [ ] This is a refactor/code format and no logic has been changed.
            - [ ] Previous test can cover this change.
            - [ ] No code files have been changed.
            - [ ] Other reason <!-- Add your reason?  -->
    
    - Behavior changed:
        - [ ] No.
        - [ ] Yes. <!-- Explain the behavior change -->
    
    - Does this need documentation?
        - [ ] No.
    - [ ] Yes. <!-- Add document PR link here. eg:
    https://github.com/apache/doris-website/pull/1214 -->
    
    ### Check List (For Reviewer who merge this PR)
    
    - [ ] Confirm the release note
    - [ ] Confirm test cases
    - [ ] Confirm document
    - [ ] Add branch pick label <!-- Add branch pick label that this PR
    should merge into -->
---
 .../glue/translator/PhysicalPlanTranslator.java    | 319 +++++++++++++++++++++
 .../org/apache/doris/planner/GroupJoinNode.java    | 245 ++++++++++++++++
 .../java/org/apache/doris/qe/SessionVariable.java  |  16 ++
 3 files changed, 580 insertions(+)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 75c65bc120d..13ac708b0fe 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.glue.translator;
 import org.apache.doris.analysis.AggregateInfo;
 import org.apache.doris.analysis.AnalyticWindow;
 import org.apache.doris.analysis.AssertNumRowsElement;
+import org.apache.doris.analysis.BinaryPredicate;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.FunctionCallExpr;
 import org.apache.doris.analysis.GroupingInfo;
@@ -199,6 +200,7 @@ import org.apache.doris.planner.EmptySetNode;
 import org.apache.doris.planner.ExceptNode;
 import org.apache.doris.planner.ExchangeNode;
 import org.apache.doris.planner.GroupCommitBlockSink;
+import org.apache.doris.planner.GroupJoinNode;
 import org.apache.doris.planner.HashJoinNode;
 import org.apache.doris.planner.HiveTableSink;
 import org.apache.doris.planner.IcebergDeleteSink;
@@ -235,6 +237,8 @@ import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.SessionVariable;
 import org.apache.doris.statistics.StatisticConstants;
 import org.apache.doris.tablefunction.TableValuedFunctionIf;
+import org.apache.doris.thrift.TGroupJoinAggOutputMode;
+import org.apache.doris.thrift.TGroupJoinAggSide;
 import org.apache.doris.thrift.TPartitionType;
 import org.apache.doris.thrift.TPushAggOp;
 import org.apache.doris.thrift.TResultSinkType;
@@ -1140,6 +1144,12 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
             PhysicalHashAggregate<? extends Plan> aggregate,
             PlanTranslatorContext context) {
 
+        // V2: try GroupJoin fusion directly in translator
+        PlanFragment groupJoinFragment = maybeTranslateToGroupJoin(aggregate, 
context);
+        if (groupJoinFragment != null) {
+            return groupJoinFragment;
+        }
+
         PlanFragment inputPlanFragment = aggregate.child(0).accept(this, 
context);
         List<List<Expr>> distributeExprLists = 
getDistributeExprs(aggregate.child(0));
 
@@ -3043,6 +3053,315 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
         return leftFragment;
     }
 
+    /**
+     * Wires a freshly built GroupJoinNode on top of its two child fragments.
+     * <p>
+     * The left and right plan roots become children 0 and 1 of the node, the node becomes
+     * the plan root of the left fragment, the right fragment is merged into the left one
+     * through the translator context, and every child fragment of the right fragment is
+     * re-attached to the left fragment.
+     *
+     * @param groupJoinNode the node to place on top of both fragments
+     * @param leftFragment fragment of the left (probe) child; also the returned fragment
+     * @param rightFragment fragment of the right (build) child
+     * @param context translator context used to merge the fragments
+     * @param groupJoin the Nereids plan passed on to setPlanRoot
+     * @return {@code leftFragment}, now rooted at {@code groupJoinNode}
+     */
+    private PlanFragment connectGroupJoinNode(GroupJoinNode groupJoinNode, PlanFragment leftFragment,
+            PlanFragment rightFragment, PlanTranslatorContext context, AbstractPlan groupJoin) {
+        PlanNode probeRoot = leftFragment.getPlanRoot();
+        groupJoinNode.setChild(0, probeRoot);
+        PlanNode buildRoot = rightFragment.getPlanRoot();
+        groupJoinNode.setChild(1, buildRoot);
+
+        setPlanRoot(leftFragment, groupJoinNode, groupJoin);
+        context.mergePlanFragment(rightFragment, leftFragment);
+        rightFragment.getChildren().forEach(leftFragment::addChild);
+        return leftFragment;
+    }
+
+    /**
+     * V2: Try to fuse HashAggregate(HashJoin) into GroupJoin directly in the 
translator stage.
+     * <p>
+     * When the aggregate's child is an INNER hash join with compatible 
group-by/join-key,
+     * generate a GroupJoinNode instead of separate AggregationNode + 
HashJoinNode.
+     * Returns null if conditions are not met — caller falls through to normal 
translation.
+     */
+    private PlanFragment maybeTranslateToGroupJoin(
+            Plan aggregate,
+            PlanTranslatorContext context) {
+        // Gate: session variable
+        ConnectContext connectContext = ConnectContext.get();
+        if (connectContext == null
+                || 
!connectContext.getSessionVariable().isEnableGroupJoinFusion()) {
+            return null;
+        }
+        // Gate: spill not supported
+        if (connectContext.getSessionVariable().enableSpill) {
+            return null;
+        }
+
+        // Child must be PhysicalHashJoin (optionally through PhysicalProject)
+        Plan child = aggregate.child(0);
+        if (child instanceof PhysicalProject) {
+            child = child.child(0);
+        }
+        if (!(child instanceof PhysicalHashJoin)) {
+            return null;
+        }
+        PhysicalHashJoin<?, ?> join = (PhysicalHashJoin<?, ?>) child;
+
+        // Only INNER_JOIN for V1
+        if (join.getJoinType() != JoinType.INNER_JOIN && 
!join.getJoinType().isCrossJoin()) {
+            return null;
+        }
+        // Not mark join
+        if (join.isMarkJoin()) {
+            return null;
+        }
+        // Not broadcast join
+        if (join.isBroadCastJoin()) {
+            return null;
+        }
+
+        Aggregate<?> agg = (Aggregate<?>) aggregate;
+
+        // Group-by keys must be equivalent to hash-join keys
+        if (!isGroupKeyEquivalentToJoinKey(agg, join)) {
+            return null;
+        }
+        // Aggregate functions must not reference columns from both sides
+        if (!aggFunctionsSingleSide(agg, join)) {
+            return null;
+        }
+
+        // All checks passed — generate GroupJoinNode
+        return translateToGroupJoinNode(agg, join, context);
+    }
+
+    /**
+     * Check that the group-by expressions are exactly the equi-join keys of one join side.
+     * <p>
+     * Every hash conjunct must be an EqualPredicate whose two operands can be attributed to
+     * opposite children of the join. The operand expressions (not merely the slots they
+     * read) are collected per side, and the set of group-by expressions must equal the
+     * left-side key set or the right-side key set. Comparing whole expressions avoids
+     * accepting e.g. {@code GROUP BY f(a)} for a join on {@code a}, where both only share
+     * the input slot {@code a} but define different groupings.
+     *
+     * @param aggregate aggregate whose group-by expressions are checked
+     * @param join hash join providing the equi-join conjuncts
+     * @return true when the group-by keys coincide with one side's join keys; false when
+     *         either list is empty, their sizes differ, or any conjunct cannot be matched
+     */
+    private boolean isGroupKeyEquivalentToJoinKey(
+            Aggregate<?> aggregate, PhysicalHashJoin<?, ?> join) {
+        List<Expression> groupByExprs = aggregate.getGroupByExpressions();
+        List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
+        if (groupByExprs.isEmpty() || hashJoinConjuncts.isEmpty()
+                || groupByExprs.size() != hashJoinConjuncts.size()) {
+            return false;
+        }
+        Set<Slot> leftOutput = join.left().getOutputSet();
+        Set<Slot> rightOutput = join.right().getOutputSet();
+        Set<Expression> joinLeftKeys = new HashSet<>();
+        Set<Expression> joinRightKeys = new HashSet<>();
+        for (Expression conjunct : hashJoinConjuncts) {
+            if (!(conjunct instanceof EqualPredicate)) {
+                return false;
+            }
+            EqualPredicate eq = (EqualPredicate) conjunct;
+            Expression leftOperand = eq.left();
+            Expression rightOperand = eq.right();
+            Set<Slot> leftSide = leftOperand.getInputSlots();
+            Set<Slot> rightSide = rightOperand.getInputSlots();
+            if (leftOutput.containsAll(leftSide) && rightOutput.containsAll(rightSide)) {
+                joinLeftKeys.add(leftOperand);
+                joinRightKeys.add(rightOperand);
+            } else if (leftOutput.containsAll(rightSide) && rightOutput.containsAll(leftSide)) {
+                // Operands are written in the opposite order of the join children.
+                joinLeftKeys.add(rightOperand);
+                joinRightKeys.add(leftOperand);
+            } else {
+                return false;
+            }
+        }
+        Set<Expression> groupByKeys = new HashSet<>(groupByExprs);
+        return groupByKeys.equals(joinLeftKeys) || groupByKeys.equals(joinRightKeys);
+    }
+
+    /**
+     * Verify that each aggregate expression reads slots from at most one join child.
+     * <p>
+     * All AggregateExpression nodes found inside the aggregate's output expressions are
+     * inspected. A slot counts for the left side when the left child outputs it, and for
+     * the right side only when the left child does not output it but the right child does.
+     *
+     * @param aggregate aggregate whose output expressions are scanned
+     * @param join hash join whose children define the two sides
+     * @return false as soon as one aggregate expression touches both sides, true otherwise
+     */
+    private boolean aggFunctionsSingleSide(
+            Aggregate<?> aggregate, PhysicalHashJoin<?, ?> join) {
+        Set<Slot> probeSlots = join.left().getOutputSet();
+        Set<Slot> buildSlots = join.right().getOutputSet();
+        for (NamedExpression output : aggregate.getOutputExpressions()) {
+            List<AggregateExpression> nestedAggs = output
+                    .collect(AggregateExpression.class::isInstance).stream()
+                    .map(AggregateExpression.class::cast)
+                    .collect(Collectors.toList());
+            for (AggregateExpression nestedAgg : nestedAggs) {
+                Set<Slot> inputs = nestedAgg.getInputSlots();
+                boolean touchesProbe = inputs.stream().anyMatch(probeSlots::contains);
+                boolean touchesBuild = inputs.stream()
+                        .anyMatch(slot -> !probeSlots.contains(slot) && buildSlots.contains(slot));
+                if (touchesProbe && touchesBuild) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Translate Aggregate(HashJoin) pattern into a GroupJoinNode fragment.
+     * <p>
+     * Steps, in order: translate the right child and then the left child into fragments,
+     * create the GroupJoinNode and register the join's Nereids id for it, copy the join
+     * operator, distribute expressions and equi-join conjuncts, translate the group-by
+     * expressions and the distinct aggregate expressions (each tagged BUILD or PROBE),
+     * connect the fragments, pick the distribution mode, create legacy runtime filters,
+     * build the output tuple and finally set the cardinality.
+     *
+     * @param aggregate aggregate supplying group-by and output expressions
+     * @param join hash join supplying children, conjuncts, runtime filters and stats
+     * @param context translator context
+     * @return the fragment returned by connectGroupJoinNode (the left child's fragment)
+     */
+    private PlanFragment translateToGroupJoinNode(
+            Aggregate<?> aggregate,
+            PhysicalHashJoin<?, ?> join,
+            PlanTranslatorContext context) {
+        PhysicalHashJoin<PhysicalPlan, PhysicalPlan> physicalJoin
+                = (PhysicalHashJoin<PhysicalPlan, PhysicalPlan>) join;
+
+        // Visit children right-to-left (right = build, left = probe)
+        PlanFragment rightFragment = join.child(1).accept(this, context);
+        PlanFragment leftFragment = join.child(0).accept(this, context);
+        PlanNode leftPlanRoot = leftFragment.getPlanRoot();
+        PlanNode rightPlanRoot = rightFragment.getPlanRoot();
+
+        // Create GroupJoinNode and map the join's Nereids id to the new plan node id
+        GroupJoinNode groupJoinNode = new GroupJoinNode(
+                context.nextPlanNodeId(), leftPlanRoot, rightPlanRoot);
+        groupJoinNode.setNereidsId(join.getId());
+        context.getNereidsIdToPlanNodeIdMap().put(join.getId(), 
groupJoinNode.getId());
+
+        // Join operator
+        groupJoinNode.setJoinOp(JoinType.toJoinOperator(join.getJoinType()));
+
+        // Distribute expr lists
+        List<List<Expr>> distributeExprLists = getDistributeExprs(
+                physicalJoin.left(), physicalJoin.right());
+        groupJoinNode.setChildrenDistributeExprLists(distributeExprLists);
+
+        // Equi-join conjuncts; presumably swapped so the left operand refers to the left child — confirm
+        List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
+        for (Expression hashConjunct : hashJoinConjuncts) {
+            EqualPredicate equalTo = JoinUtils.swapEqualToForChildrenOrder(
+                    (EqualPredicate) hashConjunct, join.left().getOutputSet());
+            groupJoinNode.addEqJoinConjunct(
+                    (BinaryPredicate) ExpressionTranslator.translate(equalTo, 
context));
+        }
+
+        // Group-by expressions
+        List<Expr> groupingExprs = new ArrayList<>();
+        for (Expression e : aggregate.getGroupByExpressions()) {
+            groupingExprs.add(ExpressionTranslator.translate(e, context));
+        }
+        groupJoinNode.setGroupingExprs(groupingExprs);
+
+        // Aggregate functions with side annotations: each distinct AggregateExpression is translated once
+        List<Expr> aggFuncExprs = new ArrayList<>();
+        List<TGroupJoinAggSide> aggSides = new ArrayList<>();
+        Set<Slot> rightOutput = join.right().getOutputSet();
+        Set<AggregateExpression> seen = new HashSet<>();
+        for (NamedExpression outputExpr : aggregate.getOutputExpressions()) {
+            for (AggregateExpression aggExpr : outputExpr
+                    .collect(AggregateExpression.class::isInstance).stream()
+                    .map(AggregateExpression.class::cast)
+                    .collect(Collectors.toList())) {
+                if (seen.add(aggExpr)) {
+                    aggFuncExprs.add(ExpressionTranslator.translate(aggExpr, 
context));
+                    boolean fromBuild = aggExpr.getInputSlots().stream()
+                            .anyMatch(rightOutput::contains);
+                    aggSides.add(fromBuild
+                            ? TGroupJoinAggSide.BUILD : 
TGroupJoinAggSide.PROBE);
+                }
+            }
+        }
+        groupJoinNode.setAggregateFunctions(aggFuncExprs);
+        groupJoinNode.setAggSides(aggSides);
+        groupJoinNode.setAggOutputMode(TGroupJoinAggOutputMode.FINAL_RESULT);
+
+        // Connect fragments
+        PlanFragment currentFragment = connectGroupJoinNode(
+                groupJoinNode, leftFragment, rightFragment, context, join);
+
+        // Distribution mode. NOTE(review): maybeTranslateToGroupJoin already rejects isBroadCastJoin(), so the broadcast branch may be unreachable — confirm
+        if (JoinUtils.shouldColocateJoin(physicalJoin)) {
+            groupJoinNode.setColocate(true);
+            leftFragment.setHasColocatePlanNode(true);
+        } else if (JoinUtils.shouldBroadcastJoin(physicalJoin)) {
+            Preconditions.checkState(rightPlanRoot instanceof ExchangeNode,
+                    "right child of broadcast GroupJoin must be ExchangeNode");
+            ((ExchangeNode) 
rightPlanRoot).setRightChildOfBroadcastHashJoin(true);
+            groupJoinNode.setDistributionMode(DistributionMode.BROADCAST);
+        } else if (JoinUtils.shouldBucketShuffleJoin(physicalJoin)) {
+            groupJoinNode.setDistributionMode(DistributionMode.BUCKET_SHUFFLE);
+        } else {
+            groupJoinNode.setDistributionMode(DistributionMode.PARTITIONED);
+        }
+
+        // Runtime filters (taken from the join, targeting the fused node)
+        context.getRuntimeTranslator().ifPresent(rt ->
+                rt.createLegacyRuntimeFilters(
+                        physicalJoin.getRuntimeFilters(), groupJoinNode, 
context));
+
+        // Output tuple descriptor
+        createGroupJoinOutputTuple(aggregate, join, groupJoinNode, context);
+
+        if (join.getStats() != null) { // NOTE(review): uses the join's row count although the node emits aggregated rows — confirm
+            groupJoinNode.setCardinality((long) join.getStats().getRowCount());
+        }
+        updateLegacyPlanIdToPhysicalPlan(currentFragment.getPlanRoot(), join);
+        return currentFragment;
+    }
+
+    /**
+     * Build output tuple descriptor for GroupJoinNode.
+     * <p>
+     * First gathers the slot descriptors of both children. For every child slot whose
+     * expression id maps to a slot of the corresponding join child and also appears in the
+     * join's output, a slot is created in a freshly generated intermediate tuple and the
+     * child slot id is registered as a hash output slot on the node. Then a second tuple
+     * is generated as the node's output: one slot per group-by expression that is a
+     * SlotReference, followed by one slot per aggregate output expression that contains an
+     * AggregateExpression.
+     * <p>
+     * NOTE(review): the intermediate tuple descriptor is populated here but is not stored
+     * on the node in this method — confirm whether setvIntermediateTupleDescList should be
+     * called. Group-by expressions that are not SlotReference receive no output slot.
+     *
+     * @param aggregate aggregate supplying group-by and output expressions
+     * @param join hash join supplying the output slots of itself and its children
+     * @param groupJoinNode node whose children are read and whose output tuple is set
+     * @param context translator context used to create tuples and slots
+     */
+    private void createGroupJoinOutputTuple(
+            Aggregate<?> aggregate,
+            PhysicalHashJoin<?, ?> join,
+            GroupJoinNode groupJoinNode,
+            PlanTranslatorContext context) {
+        PlanNode leftNode = groupJoinNode.getChild(0);
+        PlanNode rightNode = groupJoinNode.getChild(1);
+
+        List<TupleDescriptor> leftTuples = context.getTupleDesc(leftNode);
+        List<SlotDescriptor> leftSlotDescriptors = leftTuples.stream()
+                .map(TupleDescriptor::getSlots)
+                .flatMap(Collection::stream)
+                .collect(Collectors.toList());
+        List<TupleDescriptor> rightTuples = context.getTupleDesc(rightNode);
+        List<SlotDescriptor> rightSlotDescriptors = rightTuples.stream()
+                .map(TupleDescriptor::getSlots)
+                .flatMap(Collection::stream)
+                .collect(Collectors.toList());
+
+        Map<ExprId, SlotReference> outputSlotRefMap = new HashMap<>();
+        for (Slot slot : join.getOutput()) {
+            SlotReference sf = (SlotReference) slot;
+            outputSlotRefMap.putIfAbsent(sf.getExprId(), sf);
+        }
+
+        // Intermediate tuple for hash output slots (left child's slots first, then right child's)
+        TupleDescriptor intermediateDescriptor = context.generateTupleDesc();
+        Map<ExprId, SlotReference> leftOutputMap = 
join.left().getOutput().stream()
+                .map(SlotReference.class::cast)
+                .collect(Collectors.toMap(Slot::getExprId, s -> s,
+                        (existing, replacement) -> existing));
+
+        for (SlotDescriptor leftSlotDesc : leftSlotDescriptors) {
+            SlotReference sf = leftOutputMap.get(
+                    context.findExprId(leftSlotDesc.getId()));
+            if (sf != null && outputSlotRefMap.get(sf.getExprId()) != null) {
+                context.createSlotDesc(intermediateDescriptor, sf);
+                
groupJoinNode.addSlotIdToHashOutputSlotIds(leftSlotDesc.getId());
+            }
+        }
+        Map<ExprId, SlotReference> rightOutputMap = 
join.right().getOutput().stream()
+                .map(SlotReference.class::cast)
+                .collect(Collectors.toMap(Slot::getExprId, s -> s,
+                        (existing, replacement) -> existing));
+        for (SlotDescriptor rightSlotDesc : rightSlotDescriptors) {
+            SlotReference sf = rightOutputMap.get(
+                    context.findExprId(rightSlotDesc.getId()));
+            if (sf != null && outputSlotRefMap.get(sf.getExprId()) != null) {
+                context.createSlotDesc(intermediateDescriptor, sf);
+                
groupJoinNode.addSlotIdToHashOutputSlotIds(rightSlotDesc.getId());
+            }
+        }
+
+        // Output tuple: group-by slots first, then one slot per aggregate-bearing output expression
+        TupleDescriptor outputTupleDesc = context.generateTupleDesc();
+        for (Expression groupByExpr : aggregate.getGroupByExpressions()) {
+            if (groupByExpr instanceof SlotReference) {
+                context.createSlotDesc(outputTupleDesc, (SlotReference) 
groupByExpr);
+            }
+        }
+        for (NamedExpression outputExpr : aggregate.getOutputExpressions()) {
+            if (outputExpr.containsType(AggregateExpression.class)) {
+                Slot slot = outputExpr.toSlot();
+                if (slot instanceof SlotReference) {
+                    context.createSlotDesc(outputTupleDesc, (SlotReference) 
slot);
+                }
+            }
+        }
+        groupJoinNode.setOutputTupleDesc(outputTupleDesc);
+    }
+
     /**
      * Translate group-by expressions from Nereids Expression to legacy Expr.
      * Shared by visitPhysicalHashAggregate and 
visitPhysicalBucketedHashAggregate.
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/GroupJoinNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/GroupJoinNode.java
new file mode 100644
index 00000000000..a6ffa8a2da9
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/GroupJoinNode.java
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.planner;
+
+import org.apache.doris.analysis.BinaryPredicate;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.ExprToThriftVisitor;
+import org.apache.doris.analysis.JoinOperator;
+import org.apache.doris.analysis.SlotId;
+import org.apache.doris.analysis.TupleDescriptor;
+import org.apache.doris.analysis.TupleId;
+import org.apache.doris.thrift.TEqJoinCondition;
+import org.apache.doris.thrift.TExplainLevel;
+import org.apache.doris.thrift.TGroupJoinAggFunction;
+import org.apache.doris.thrift.TGroupJoinAggOutputMode;
+import org.apache.doris.thrift.TGroupJoinAggSide;
+import org.apache.doris.thrift.TGroupJoinNode;
+import org.apache.doris.thrift.TJoinDistributionType;
+import org.apache.doris.thrift.TPlanNode;
+import org.apache.doris.thrift.TPlanNodeType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * GroupJoin operator that fuses a hash join with a hash aggregation.
+ * <p>
+ * The hash table is shared between join probing and aggregation state storage.
+ * Because the join shuffle already distributes data by the join/group key,
+ * only single-stage aggregation (FINAL_RESULT) is needed.
+ * <p>
+ * The node is configured through setters after construction. Serialization via
+ * {@code toThrift} requires the join operator, grouping expressions, aggregate functions
+ * and aggregate sides to have been set.
+ */
+public class GroupJoinNode extends PlanNode {
+
+    private JoinOperator joinOp;
+    private final List<BinaryPredicate> eqJoinConjuncts = Lists.newArrayList();
+    private List<Expr> groupingExprs;
+    private List<Expr> aggregateFunctions;
+    // Parallel to aggregateFunctions: which join side feeds the i-th aggregate function.
+    private List<TGroupJoinAggSide> aggSides;
+    private TGroupJoinAggOutputMode aggOutputMode;
+    private TupleDescriptor outputTupleDesc;
+    private DistributionMode distrMode;
+    private boolean isColocate = false;
+
+    // Intermediate tuple descriptors for left and right child outputs
+    private List<TupleDescriptor> vIntermediateTupleDescList = Lists.newArrayList();
+    private final Set<SlotId> hashOutputSlotIds = Sets.newHashSet();
+
+    /**
+     * Creates a group join over the two given children.
+     *
+     * @param id plan node id
+     * @param leftChild left child, added as child 0
+     * @param rightChild right child, added as child 1
+     */
+    public GroupJoinNode(PlanNodeId id, PlanNode leftChild, PlanNode rightChild) {
+        super(id, "GROUP JOIN");
+        // Add both children's output tuple ids for row_tuples thrift field
+        tupleIds.addAll(leftChild.getOutputTupleIds());
+        tupleIds.addAll(rightChild.getOutputTupleIds());
+        children.add(leftChild);
+        children.add(rightChild);
+    }
+
+    /**
+     * Returns the id of the output tuple when one has been set, otherwise the node's
+     * tuple id list.
+     */
+    @Override
+    public ArrayList<TupleId> getOutputTupleIds() {
+        if (outputTupleDesc != null) {
+            return Lists.newArrayList(outputTupleDesc.getId());
+        }
+        return tupleIds;
+    }
+
+    public void setJoinOp(JoinOperator joinOp) {
+        this.joinOp = joinOp;
+    }
+
+    public JoinOperator getJoinOp() {
+        return joinOp;
+    }
+
+    public void addEqJoinConjunct(BinaryPredicate conjunct) {
+        eqJoinConjuncts.add(conjunct);
+    }
+
+    public List<BinaryPredicate> getEqJoinConjuncts() {
+        return eqJoinConjuncts;
+    }
+
+    public void setGroupingExprs(List<Expr> groupingExprs) {
+        this.groupingExprs = groupingExprs;
+    }
+
+    public List<Expr> getGroupingExprs() {
+        return groupingExprs;
+    }
+
+    public void setAggregateFunctions(List<Expr> aggregateFunctions) {
+        this.aggregateFunctions = aggregateFunctions;
+    }
+
+    public List<Expr> getAggregateFunctions() {
+        return aggregateFunctions;
+    }
+
+    public void setAggSides(List<TGroupJoinAggSide> aggSides) {
+        this.aggSides = aggSides;
+    }
+
+    public List<TGroupJoinAggSide> getAggSides() {
+        return aggSides;
+    }
+
+    public void setAggOutputMode(TGroupJoinAggOutputMode aggOutputMode) {
+        this.aggOutputMode = aggOutputMode;
+    }
+
+    public TGroupJoinAggOutputMode getAggOutputMode() {
+        return aggOutputMode;
+    }
+
+    /**
+     * Sets the output tuple and records its id in the node's tuple id list.
+     * The id is appended at most once, so calling this repeatedly with the same
+     * descriptor does not create duplicate entries.
+     */
+    public void setOutputTupleDesc(TupleDescriptor outputTupleDesc) {
+        this.outputTupleDesc = outputTupleDesc;
+        if (outputTupleDesc != null && !tupleIds.contains(outputTupleDesc.getId())) {
+            tupleIds.add(outputTupleDesc.getId());
+        }
+    }
+
+    @Override
+    public TupleDescriptor getOutputTupleDesc() {
+        return outputTupleDesc;
+    }
+
+    public void setDistributionMode(DistributionMode distrMode) {
+        this.distrMode = distrMode;
+    }
+
+    public DistributionMode getDistributionMode() {
+        return distrMode;
+    }
+
+    public void setColocate(boolean colocate) {
+        this.isColocate = colocate;
+    }
+
+    public boolean isColocate() {
+        return isColocate;
+    }
+
+    public void setvIntermediateTupleDescList(List<TupleDescriptor> vIntermediateTupleDescList) {
+        this.vIntermediateTupleDescList = vIntermediateTupleDescList;
+    }
+
+    public List<TupleDescriptor> getvIntermediateTupleDescList() {
+        return vIntermediateTupleDescList;
+    }
+
+    public void addSlotIdToHashOutputSlotIds(SlotId slotId) {
+        hashOutputSlotIds.add(slotId);
+    }
+
+    public Set<SlotId> getHashOutputSlotIds() {
+        return hashOutputSlotIds;
+    }
+
+    /**
+     * Serializes this node into {@code msg} as a GROUP_JOIN_NODE.
+     *
+     * @throws NullPointerException if the join operator, grouping expressions, aggregate
+     *         functions or aggregate sides have not been set
+     * @throws IllegalStateException if aggregate functions and sides differ in size
+     */
+    @Override
+    protected void toThrift(TPlanNode msg) {
+        // Fail with a descriptive message instead of a bare NPE deep inside serialization.
+        Preconditions.checkNotNull(joinOp, "joinOp must be set before toThrift");
+        Preconditions.checkNotNull(groupingExprs, "groupingExprs must be set before toThrift");
+        Preconditions.checkNotNull(aggregateFunctions, "aggregateFunctions must be set before toThrift");
+        Preconditions.checkNotNull(aggSides, "aggSides must be set before toThrift");
+
+        msg.node_type = TPlanNodeType.GROUP_JOIN_NODE;
+        msg.group_join_node = new TGroupJoinNode();
+
+        // Join info
+        msg.group_join_node.join_op = joinOp.toThrift();
+        for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) {
+            TEqJoinCondition eqJoinCondition = new TEqJoinCondition(
+                    ExprToThriftVisitor.treeToThrift(eqJoinPredicate.getChild(0)),
+                    ExprToThriftVisitor.treeToThrift(eqJoinPredicate.getChild(1)));
+            eqJoinCondition.setOpcode(ExprToThriftVisitor.toThriftOpcode(eqJoinPredicate.getOp()));
+            msg.group_join_node.addToEqJoinConjuncts(eqJoinCondition);
+        }
+        // Colocate wins over the explicit mode; PARTITIONED is the fallback when none is set.
+        TJoinDistributionType distType;
+        if (isColocate) {
+            distType = TJoinDistributionType.COLOCATE;
+        } else if (distrMode != null) {
+            distType = distrMode.toThrift();
+        } else {
+            distType = TJoinDistributionType.PARTITIONED;
+        }
+        msg.group_join_node.setDistType(distType);
+
+        // Aggregation info
+        for (Expr groupingExpr : groupingExprs) {
+            msg.group_join_node.addToGroupingExprs(ExprToThriftVisitor.treeToThrift(groupingExpr));
+        }
+        // Build aggregate functions with side annotation
+        Preconditions.checkState(aggregateFunctions.size() == aggSides.size(),
+                "aggregateFunctions and aggSides must have same size");
+        for (int i = 0; i < aggregateFunctions.size(); i++) {
+            TGroupJoinAggFunction aggFunc = new TGroupJoinAggFunction();
+            aggFunc.setAggregateFunction(ExprToThriftVisitor.treeToThrift(aggregateFunctions.get(i)));
+            aggFunc.setInputSide(aggSides.get(i));
+            msg.group_join_node.addToAggregateFunctions(aggFunc);
+        }
+        msg.group_join_node.setAggOutputMode(aggOutputMode);
+        if (outputTupleDesc != null) {
+            msg.group_join_node.setOutputTupleId(outputTupleDesc.getId().asInt());
+        }
+    }
+
+    /**
+     * Renders the explain text: join operator, equi conjuncts, grouping expressions,
+     * aggregate functions and distribution, one per line, each prefixed by
+     * {@code detailPrefix}.
+     */
+    @Override
+    public String getNodeExplainString(String detailPrefix, TExplainLevel detailLevel) {
+        StringBuilder output = new StringBuilder();
+        output.append(detailPrefix).append("group join: ");
+        output.append(joinOp.toString()).append("\n");
+
+        output.append(detailPrefix).append("  join op: ").append(joinOp).append("\n");
+        output.append(detailPrefix).append("  equi conjuncts: ").append(eqJoinConjuncts).append("\n");
+        output.append(detailPrefix).append("  grouping exprs: ").append(groupingExprs).append("\n");
+        output.append(detailPrefix).append("  agg functions: ").append(aggregateFunctions).append("\n");
+        output.append(detailPrefix).append("  distribution: ")
+                .append(isColocate ? "COLOCATE" : distrMode).append("\n");
+
+        return output.toString();
+    }
+
+    @Override
+    public int getNumInstances() {
+        // Use the larger instance count of the two children, regardless of distribution mode.
+        return Math.max(children.get(0).getNumInstances(), children.get(1).getNumInstances());
+    }
+
+    /**
+     * Returns a copy of the node's tuple ids that includes the output tuple id exactly once.
+     * The id may already be present because {@link #setOutputTupleDesc} records it, so it
+     * is only appended when missing.
+     */
+    @Override
+    public ArrayList<TupleId> getTupleIds() {
+        ArrayList<TupleId> allTupleIds = Lists.newArrayList(super.getTupleIds());
+        if (outputTupleDesc != null && !allTupleIds.contains(outputTupleDesc.getId())) {
+            allTupleIds.add(outputTupleDesc.getId());
+        }
+        return allTupleIds;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index bc4e80ae401..91abbc980c8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -1058,6 +1058,22 @@ public class SessionVariable implements Serializable, 
Writable {
     @VarAttrDef.VarAttr(name = "enable_aggregate_cse", needForward = true)
     public boolean enableAggregateCse = true;
 
+    // Enable GroupJoin fusion: fuse an INNER hash join and a hash aggregation into a single
+    // GroupJoin operator when the GROUP BY keys are equivalent to the equi-join keys. The hash
+    // table is reused for both join and aggregation, eliminating the need for a separate
+    // aggregation pass.
+    @VarAttrDef.VarAttr(name = "enable_group_join_fusion", needForward = true,
+            varType = VariableAnnotation.EXPERIMENTAL,
+            description = {"是否启用 GroupJoin 融合算子(实验特性)。"
+                    + "当 GROUP BY 键与 equi-join 键等价时,将 INNER hash join 与 hash 
agg 融合为单个算子",
+                    "Enable GroupJoin fusion (experimental). "
+                    + "Fuse INNER hash join + hash agg into a single operator 
when GROUP BY keys "
+                    + "are equivalent to equi-join keys"})
+    public boolean enableGroupJoinFusion = false; // Experimental switch, off by default.
+
+    public boolean isEnableGroupJoinFusion() { // Returns the current value of the enable_group_join_fusion switch.
+        return enableGroupJoinFusion;
+    }
+
     // Experimental: enable pushing down virtual slots (common 
sub-expressions) into OlapScan.
     // When false (default), the optimizer rule 
PushDownVirtualColumnsIntoOlapScan will not apply.
     @VarAttrDef.VarAttr(name = "enable_virtual_slot_for_cse", needForward = 
true,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to