This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch pick_4.0_58397 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 8ef96279ff02e364819cd96093ab729c3bca8ce6 Author: morrySnow <[email protected]> AuthorDate: Wed Dec 3 10:54:00 2025 +0800 branch-4.0: [refactor](grouping set) remove virtual slot reference #58397 picked from #58397 --- .../org/apache/doris/analysis/GroupingInfo.java | 8 +- .../org/apache/doris/analysis/VirtualSlotRef.java | 3 - .../glue/translator/ExpressionTranslator.java | 6 - .../glue/translator/PhysicalPlanTranslator.java | 32 ++-- .../glue/translator/PlanTranslatorContext.java | 25 +-- .../nereids/rules/analysis/CheckAfterRewrite.java | 32 ++-- .../nereids/rules/analysis/NormalizeRepeat.java | 51 +++---- .../mv/AbstractMaterializedViewAggregateRule.java | 44 ++---- .../nereids/rules/exploration/mv/StructInfo.java | 52 ++++--- .../LogicalRepeatToPhysicalRepeat.java | 1 + .../nereids/rules/rewrite/ExprIdRewriter.java | 28 +--- .../nereids/rules/rewrite/SetPreAggStatus.java | 35 +++-- .../doris/nereids/stats/ExpressionEstimation.java | 12 +- .../org/apache/doris/nereids/trees/TreeNode.java | 27 ++++ .../nereids/trees/copier/ExpressionDeepCopier.java | 29 ---- .../trees/copier/LogicalPlanDeepCopier.java | 4 +- .../trees/expressions/VirtualSlotReference.java | 170 --------------------- .../expressions/functions/scalar/Grouping.java | 2 +- .../expressions/functions/scalar/GroupingId.java | 2 +- .../functions/scalar/GroupingScalarFunction.java | 2 +- .../trees/expressions/functions/udf/JavaUdaf.java | 12 +- .../trees/expressions/functions/udf/JavaUdf.java | 12 +- .../trees/expressions/functions/udf/JavaUdtf.java | 12 +- .../expressions/visitor/ExpressionVisitor.java | 5 - .../doris/nereids/trees/plans/algebra/Repeat.java | 54 ++----- .../nereids/trees/plans/logical/LogicalRepeat.java | 85 +++++++---- .../plans/physical/PhysicalHashAggregate.java | 10 -- .../trees/plans/physical/PhysicalRepeat.java | 28 ++-- .../apache/doris/nereids/util/ExpressionUtils.java | 11 +- .../org/apache/doris/planner/ResultFileSink.java | 27 ---- .../properties/ChildOutputPropertyDeriverTest.java | 9 ++ .../trees/copier/LogicalPlanDeepCopierTest.java | 4 + 32 files changed, 275 insertions(+), 559 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/GroupingInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/GroupingInfo.java index 2e17b2c8c94..83f2d8b1215 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/GroupingInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/GroupingInfo.java @@ -21,20 +21,14 @@ import java.util.List; import java.util.Objects; public class GroupingInfo { - public static final String GROUPING_PREFIX = "GROUPING_PREFIX_"; - private TupleDescriptor virtualTuple; private TupleDescriptor outputTupleDesc; - private GroupByClause.GroupingType groupingType; private List<Expr> preRepeatExprs; /** * Used by new optimizer. */ - public GroupingInfo(GroupByClause.GroupingType groupingType, TupleDescriptor virtualTuple, - TupleDescriptor outputTupleDesc, List<Expr> preRepeatExprs) { - this.groupingType = groupingType; - this.virtualTuple = Objects.requireNonNull(virtualTuple, "virtualTuple can not be null"); + public GroupingInfo(TupleDescriptor outputTupleDesc, List<Expr> preRepeatExprs) { this.outputTupleDesc = Objects.requireNonNull(outputTupleDesc, "outputTupleDesc can not be null"); this.preRepeatExprs = Objects.requireNonNull(preRepeatExprs, "preRepeatExprs can not be null"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/VirtualSlotRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/VirtualSlotRef.java index b580aa63446..6080c6c666e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/VirtualSlotRef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/VirtualSlotRef.java @@ -25,8 +25,6 @@ import java.util.Set; * It like a SlotRef except that it is not a real column exist in table. */ public class VirtualSlotRef extends SlotRef { - // results of analysis slot - private TupleDescriptor tupleDescriptor; private List<Expr> realSlots; protected VirtualSlotRef(VirtualSlotRef other) { @@ -34,7 +32,6 @@ public class VirtualSlotRef extends SlotRef { if (other.realSlots != null) { realSlots = Expr.cloneList(other.realSlots); } - tupleDescriptor = other.tupleDescriptor; } public VirtualSlotRef(SlotDescriptor desc) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 121b7a05aa3..02e49373b3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -80,7 +80,6 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.TryCast; import org.apache.doris.nereids.trees.expressions.UnaryArithmetic; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; @@ -764,11 +763,6 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra return timestampArithmeticExpr; } - @Override - public Expr visitVirtualReference(VirtualSlotReference virtualSlotReference, PlanTranslatorContext context) { - return context.findSlotRef(virtualSlotReference.getExprId()); - } - @Override public Expr visitIsNull(IsNull isNull, PlanTranslatorContext context) { IsNullPredicate isNullPredicate = new IsNullPredicate(isNull.child().accept(this, context), false, true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index c7b79c80c72..787a34701b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -105,10 +105,10 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WindowFrame; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.plans.AbstractPlan; import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.AggPhase; @@ -2499,17 +2499,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla PlanFragment inputPlanFragment = repeat.child(0).accept(this, context); List<List<Expr>> distributeExprLists = getDistributeExprs(repeat.child(0)); - Set<VirtualSlotReference> sortedVirtualSlots = repeat.getSortedVirtualSlots(); - TupleDescriptor virtualSlotsTuple = - generateTupleDesc(ImmutableList.copyOf(sortedVirtualSlots), null, context); - ImmutableSet<Expression> flattenGroupingSetExprs = ImmutableSet.copyOf( ExpressionUtils.flatExpressions(repeat.getGroupingSets())); List<Slot> aggregateFunctionUsedSlots = repeat.getOutputExpressions() .stream() - .filter(output -> !(output instanceof VirtualSlotReference)) .filter(output -> !flattenGroupingSetExprs.contains(output)) + .filter(output -> !output.containsType(GroupingScalarFunction.class)) .distinct() .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); @@ -2519,11 +2515,18 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla .map(expr -> ExpressionTranslator.translate(expr, context)).collect(ImmutableList.toImmutableList()); // outputSlots's order need same with preRepeatExprs - List<Slot> outputSlots = Stream + List<Slot> outputSlots = Stream.concat(Stream .concat(repeat.getOutputExpressions().stream() .filter(output -> flattenGroupingSetExprs.contains(output)), repeat.getOutputExpressions().stream() - .filter(output -> !flattenGroupingSetExprs.contains(output)).distinct()) + .filter(output -> !flattenGroupingSetExprs.contains(output)) + .filter(output -> !output.containsType(GroupingScalarFunction.class)) + .distinct() + ), + Stream.concat(Stream.of(repeat.getGroupingId().toSlot()), + repeat.getOutputExpressions().stream() + .filter(output -> output.containsType(GroupingScalarFunction.class))) + ) .map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList()); // NOTE: we should first translate preRepeatExprs, then generate output tuple, @@ -2532,8 +2535,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla TupleDescriptor outputTuple = generateTupleDesc(outputSlots, null, context); // cube and rollup already convert to grouping sets in LogicalPlanBuilder.withAggregate() - GroupingInfo groupingInfo = new GroupingInfo( - GroupingType.GROUPING_SETS, virtualSlotsTuple, outputTuple, preRepeatExprs); + GroupingInfo groupingInfo = new GroupingInfo(outputTuple, preRepeatExprs); List<Set<Integer>> repeatSlotIdList = repeat.computeRepeatSlotIdList(getSlotIds(outputTuple)); Set<Integer> allSlotId = repeatSlotIdList.stream() @@ -2542,7 +2544,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla RepeatNode repeatNode = new RepeatNode(context.nextPlanNodeId(), inputPlanFragment.getPlanRoot(), groupingInfo, repeatSlotIdList, - allSlotId, repeat.computeVirtualSlotValues(sortedVirtualSlots)); + allSlotId, repeat.computeGroupingFunctionsValues()); repeatNode.setNereidsId(repeat.getId()); context.getNereidsIdToPlanNodeIdMap().put(repeat.getId(), repeatNode.getId()); repeatNode.setChildrenDistributeExprLists(distributeExprLists); @@ -2938,17 +2940,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla private List<SlotReference> collectGroupBySlots(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions) { List<SlotReference> groupSlots = Lists.newArrayList(); - Set<VirtualSlotReference> virtualSlotReferences = groupByExpressions.stream() - .filter(VirtualSlotReference.class::isInstance) - .map(VirtualSlotReference.class::cast) - .collect(Collectors.toSet()); for (Expression e : groupByExpressions) { if (e instanceof SlotReference && outputExpressions.stream().anyMatch(o -> o.anyMatch(e::equals))) { groupSlots.add((SlotReference) e); - } else if (e instanceof SlotReference && !virtualSlotReferences.isEmpty()) { - // When there is a virtualSlot, it is a groupingSets scenario, - // and the original exprId should be retained at this time. - groupSlots.add((SlotReference) e); } else { groupSlots.add(new SlotReference(e.toSql(), e.getDataType(), e.nullable(), ImmutableList.of())); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java index 3a28faba94a..7e1c2b0c2cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java @@ -24,7 +24,6 @@ import org.apache.doris.analysis.SlotId; import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.TupleDescriptor; import org.apache.doris.analysis.TupleId; -import org.apache.doris.analysis.VirtualSlotRef; import org.apache.doris.catalog.Column; import org.apache.doris.catalog.TableIf; import org.apache.doris.common.IdGenerator; @@ -34,7 +33,6 @@ import org.apache.doris.nereids.processor.post.runtimefilterv2.RuntimeFilterCont import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer; import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer; @@ -299,23 +297,16 @@ public class PlanTranslatorContext { // Only the SlotDesc that in the tuple generated for scan node would have corresponding column. Optional<Column> column = slotReference.getOriginalColumn(); column.ifPresent(slotDescriptor::setColumn); + slotDescriptor.setLabel(slotReference.getName()); slotDescriptor.setType(slotReference.getDataType().toCatalogDataType()); SlotRef slotRef; - if (slotReference instanceof VirtualSlotReference) { - slotRef = new VirtualSlotRef(slotDescriptor); - VirtualSlotReference virtualSlot = (VirtualSlotReference) slotReference; - slotDescriptor.setColumn(new Column( - virtualSlot.getName(), virtualSlot.getDataType().toCatalogDataType())); - slotDescriptor.setLabel(slotReference.getName()); - } else { - slotRef = new SlotRef(slotDescriptor); - if (slotReference.hasSubColPath() && slotReference.getOriginalColumn().isPresent()) { - slotDescriptor.setSubColLables(slotReference.getSubPath()); - // use lower case name for variant's root, since backend treat parent column as lower case - // see issue: https://github.com/apache/doris/pull/32999/commits - slotDescriptor.setMaterializedColumnName(slotRef.getColumnName().toLowerCase() - + "." + String.join(".", slotReference.getSubPath())); - } + slotRef = new SlotRef(slotDescriptor); + if (slotReference.hasSubColPath() && slotReference.getOriginalColumn().isPresent()) { + slotDescriptor.setSubColLables(slotReference.getSubPath()); + // use lower case name for variant's root, since backend treat parent column as lower case + // see issue: https://github.com/apache/doris/pull/32999/commits + slotDescriptor.setMaterializedColumnName(slotRef.getColumnName().toLowerCase() + + "." + String.join(".", slotReference.getSubPath())); } slotRef.setTable(table); slotRef.setLabel(slotReference.getName()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java index d923444e84f..cbec9deb0eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java @@ -29,7 +29,6 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren; import org.apache.doris.nereids.trees.expressions.SubqueryExpr; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -43,15 +42,15 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalDeferMaterializeOlapS import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; -import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import org.apache.commons.lang3.StringUtils; import org.roaringbitmap.RoaringBitmap; -import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -74,6 +73,7 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory { private void checkUnexpectedExpression(Plan plan) { boolean isGenerate = plan instanceof Generate; boolean isAgg = plan instanceof LogicalAggregate; + boolean isRepeat = plan instanceof LogicalRepeat; boolean isWindow = plan instanceof LogicalWindow; boolean notAggAndWindow = !isAgg && !isWindow; @@ -85,7 +85,7 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory { throw new AnalysisException("table generating function is not allowed in " + plan.getType()); } else if (notAggAndWindow && expr instanceof AggregateFunction) { throw new AnalysisException("aggregate function is not allowed in " + plan.getType()); - } else if (!isAgg && expr instanceof GroupingScalarFunction) { + } else if (!isRepeat && expr instanceof GroupingScalarFunction) { throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType()); } else if (!isWindow && (expr instanceof WindowExpression || expr instanceof WindowFunction)) { throw new AnalysisException("analytic function is not allowed in " + plan.getType()); @@ -98,16 +98,20 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory { Set<Slot> inputSlots = plan.getInputSlots(); RoaringBitmap childrenOutput = plan.getChildrenOutputExprIdBitSet(); - ImmutableSet.Builder<Slot> notFromChildrenBuilder = ImmutableSet.builderWithExpectedSize(inputSlots.size()); + Set<Slot> notFromChildren = Sets.newHashSet(); for (Slot inputSlot : inputSlots) { if (!childrenOutput.contains(inputSlot.getExprId().asInt())) { - notFromChildrenBuilder.add(inputSlot); + notFromChildren.add(inputSlot); } } - Set<Slot> notFromChildren = notFromChildrenBuilder.build(); + if (notFromChildren.isEmpty()) { return; } + if (plan instanceof LogicalRepeat) { + LogicalRepeat repeat = (LogicalRepeat) plan; + notFromChildren.remove(repeat.getGroupingId().get()); + } notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput); if (!notFromChildren.isEmpty()) { if (plan.arity() != 0 && plan.child(0) instanceof LogicalAggregate) { @@ -130,19 +134,7 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory { private Set<Slot> removeValidSlotsNotFromChildren(Set<Slot> slots, RoaringBitmap childrenOutput) { return slots.stream() .filter(expr -> { - if (expr instanceof VirtualSlotReference) { - List<Expression> realExpressions = ((VirtualSlotReference) expr).getRealExpressions(); - if (realExpressions.isEmpty()) { - // valid - return false; - } - return realExpressions.stream() - .map(Expression::getInputSlots) - .flatMap(Set::stream) - .anyMatch(realUsedExpr -> !childrenOutput.contains(realUsedExpr.getExprId().asInt())); - } else { - return !(expr instanceof SlotNotFromChildren); - } + return !(expr instanceof SlotNotFromChildren); }) .collect(Collectors.toSet()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 36bcbea1f12..319e8209771 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; @@ -38,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.algebra.Repeat; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import org.apache.doris.qe.SqlModeHelper; @@ -86,7 +86,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { @Override public Rule build() { return RuleType.NORMALIZE_REPEAT.build( - logicalRepeat(any()).when(LogicalRepeat::canBindVirtualSlot).then(repeat -> { + logicalRepeat(any()).whenNot(r -> r.getGroupingId().isPresent()).then(repeat -> { if (repeat.getGroupingSets().size() == 1 && ExpressionUtils.collect(repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance).isEmpty()) { @@ -154,38 +154,30 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { // rewrite grouping scalar function to virtual slots // rewrite the arguments of agg function to slots List<NamedExpression> normalizedAggOutput = Lists.newArrayList(); + List<NamedExpression> groupingFunctions = Lists.newArrayList(); for (Expression expr : repeat.getOutputExpressions()) { Expression rewrittenExpr = expr.rewriteDownShortCircuit( - e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e)); + e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e, groupingFunctions)); normalizedAggOutput.add((NamedExpression) rewrittenExpr); } // use groupingExprContext rewrite the normalizedAggOutput normalizedAggOutput = groupingExprContext.normalizeToUseSlotRef(normalizedAggOutput); - Set<VirtualSlotReference> virtualSlotsInFunction = - ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); - - List<VirtualSlotReference> allVirtualSlots = ImmutableList.<VirtualSlotReference>builder() - // add the virtual grouping id slot - .add(Repeat.generateVirtualGroupingIdSlot()) - // add other virtual slots in the grouping scalar functions - .addAll(virtualSlotsInFunction) - .build(); - - Set<SlotReference> aggUsedNonVirtualSlots = ExpressionUtils.collect( + Set<SlotReference> aggUsedSlots = ExpressionUtils.collect( normalizedAggOutput, expr -> expr.getClass().equals(SlotReference.class)); Set<Slot> groupingSetsUsedSlot = ImmutableSet.copyOf( ExpressionUtils.flatExpressions(normalizedGroupingSets)); SetView<SlotReference> aggUsedSlotNotInGroupBy - = Sets.difference(aggUsedNonVirtualSlots, groupingSetsUsedSlot); + = Sets.difference(Sets.difference(aggUsedSlots, groupingFunctions.stream() + .map(NamedExpression::toSlot).collect(Collectors.toSet())), groupingSetsUsedSlot); - List<Slot> normalizedRepeatOutput = ImmutableList.<Slot>builder() + List<NamedExpression> normalizedRepeatOutput = ImmutableList.<NamedExpression>builder() .addAll(groupingSetsUsedSlot) .addAll(aggUsedSlotNotInGroupBy) - .addAll(allVirtualSlots) + .addAll(groupingFunctions) .build(); // 3 parts need push down: @@ -206,12 +198,14 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); + SlotReference groupingId = new SlotReference(Repeat.COL_GROUPING_ID, BigIntType.INSTANCE, false); LogicalRepeat<Plan> normalizedRepeat = repeat.withNormalizedExpr( - (List) normalizedGroupingSets, (List) normalizedRepeatOutput, normalizedChild); + (List) normalizedGroupingSets, normalizedRepeatOutput, groupingId, normalizedChild); List<Expression> normalizedAggGroupBy = ImmutableList.<Expression>builder() .addAll(groupingSetsUsedSlot) - .addAll(allVirtualSlots) + .addAll(groupingFunctions.stream().map(NamedExpression::toSlot).collect(Collectors.toList())) + .add(groupingId) .build(); normalizedAggOutput = getExprIdUnchangedNormalizedAggOutput(normalizedAggOutput, repeat.getOutputExpressions()); @@ -314,7 +308,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } private static Expression normalizeAggFuncChildrenAndGroupingScalarFunc(NormalizeToSlotContext context, - Expression expr) { + Expression expr, List<NamedExpression> groupingSetExpressions) { if (expr instanceof AggregateFunction) { AggregateFunction function = (AggregateFunction) expr; List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); @@ -325,7 +319,9 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { List<Expression> normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); function = function.withChildren(normalizedRealExpressions); // eliminate GroupingScalarFunction and replace to VirtualSlotReference - return Repeat.generateVirtualSlotByFunction(function); + Alias alias = new Alias(function, Repeat.generateVirtualSlotName(function)); + groupingSetExpressions.add(alias); + return alias.toSlot(); } else { return expr; } @@ -379,24 +375,13 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { repeat = repeat.withChildren(ImmutableList.of(newLogicalProject)); // modify repeat outputs - List<Slot> originRepeatSlots = repeat.getOutput(); - List<Slot> virtualSlots = Lists.newArrayList(); - List<Slot> nonVirtualSlots = Lists.newArrayList(); - for (Slot slot : originRepeatSlots) { - if (slot instanceof VirtualSlotReference) { - virtualSlots.add(slot); - } else { - nonVirtualSlots.add(slot); - } - } List<Slot> newSlots = Lists.newArrayList(); for (Alias alias : newAliases) { newSlots.add(alias.toSlot()); } repeat = repeat.withAggOutput(ImmutableList.<NamedExpression>builder() - .addAll(nonVirtualSlots) + .addAll(repeat.getOutputExpressions()) .addAll(newSlots) - .addAll(virtualSlots) .build()); aggregate = aggregate.withChildren(ImmutableList.of(repeat)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index b4e867e305b..6f85ea76ad3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -43,13 +43,10 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.functions.Function; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.algebra.Repeat; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -176,7 +173,6 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate MaterializationContext materializationContext, ExpressionRewriteMode groupByMode, ExpressionRewriteMode aggregateFunctionMode) { - // try to roll up. // split the query top plan expressions to group expressions and functions, if can not, bail out. Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair @@ -320,7 +316,7 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate queryStructInfo.getTableBitSet()); AggregateExpressionRewriteContext expressionRewriteContext = new AggregateExpressionRewriteContext( rewriteMode, mvShuttledExprToMvScanExprQueryBased, queryStructInfo.getTopPlan(), - queryStructInfo.getTableBitSet()); + queryStructInfo.getTableBitSet(), queryStructInfo.getGroupingId()); Expression rewrittenExpression = queryFunctionShuttled.accept(AGGREGATE_EXPRESSION_REWRITER, expressionRewriteContext); if (!expressionRewriteContext.isValid()) { @@ -740,39 +736,13 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return aggregateFunction; } - @Override - public Expression visitGroupingScalarFunction(GroupingScalarFunction groupingScalarFunction, - AggregateExpressionRewriteContext context) { - List<Expression> children = groupingScalarFunction.children(); - List<Expression> rewrittenChildren = new ArrayList<>(); - for (Expression child : children) { - Expression rewrittenChild = child.accept(this, context); - if (!context.isValid()) { - return groupingScalarFunction; - } - rewrittenChildren.add(rewrittenChild); - } - return groupingScalarFunction.withChildren(rewrittenChildren); - } - @Override public Expression visitSlot(Slot slot, AggregateExpressionRewriteContext rewriteContext) { if (!rewriteContext.isValid()) { return slot; } - if (slot instanceof VirtualSlotReference) { - Optional<GroupingScalarFunction> originExpression = ((VirtualSlotReference) slot).getOriginExpression(); - if (!originExpression.isPresent()) { - return Repeat.generateVirtualGroupingIdSlot(); - } else { - GroupingScalarFunction groupingScalarFunction = originExpression.get(); - groupingScalarFunction = - (GroupingScalarFunction) groupingScalarFunction.accept(this, rewriteContext); - if (!rewriteContext.isValid()) { - return slot; - } - return Repeat.generateVirtualSlotByFunction(groupingScalarFunction); - } + if (rewriteContext.getGroupingId().isPresent() && slot.equals(rewriteContext.getGroupingId().get())) { + return slot; } if (rewriteContext.getMvExprToMvScanExprQueryBasedMapping().containsKey(slot)) { return rewriteContext.getMvExprToMvScanExprQueryBasedMapping().get(slot); @@ -817,14 +787,16 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate private final Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping; private final Plan queryTopPlan; private final BitSet queryTableBitSet; + private final Optional<SlotReference> groupingId; public AggregateExpressionRewriteContext(ExpressionRewriteMode expressionRewriteMode, Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan, - BitSet queryTableBitSet) { + BitSet queryTableBitSet, Optional<SlotReference> groupingId) { this.expressionRewriteMode = expressionRewriteMode; this.mvExprToMvScanExprQueryBasedMapping = mvExprToMvScanExprQueryBasedMapping; this.queryTopPlan = queryTopPlan; this.queryTableBitSet = queryTableBitSet; + this.groupingId = groupingId; } public boolean isValid() { @@ -851,6 +823,10 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate return queryTableBitSet; } + public Optional<SlotReference> getGroupingId() { + return groupingId; + } + /** * The expression rewrite mode, which decide how the expression in query is rewritten by mv */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java index eeaf80f4b79..79c6d2ee119 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java @@ -78,6 +78,7 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -112,6 +113,8 @@ public class StructInfo { private final Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap; // this recorde the predicates which can pull up, not shuttled private final Predicates predicates; + // this record the grouping id generated by repeat node + private final Optional<SlotReference> groupingId; // split predicates is shuttled private SplitPredicate splitPredicate; private EquivalenceClass equivalenceClass; @@ -141,6 +144,7 @@ public class StructInfo { Plan bottomPlan, List<CatalogRelation> relations, Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap, @Nullable Predicates predicates, + Optional<SlotReference> groupingId, Map<ExpressionPosition, Multimap<Expression, Pair<Expression, HyperElement>>> shuttledExpressionsToExpressionsMap, Map<ExpressionPosition, Map<Expression, Expression>> expressionToShuttledExpressionToMap, @@ -158,6 +162,7 @@ public class StructInfo { this.tableBitSet = tableIdSet; this.relationIdStructInfoNodeMap = relationIdStructInfoNodeMap; this.predicates = predicates; + this.groupingId = groupingId; this.splitPredicate = splitPredicate; this.equivalenceClass = equivalenceClass; this.shuttledExpressionsToExpressionsMap = shuttledExpressionsToExpressionsMap; @@ -170,7 +175,7 @@ public class StructInfo { */ public StructInfo withPredicates(Predicates predicates) { return new StructInfo(this.originalPlan, this.originalPlanId, this.hyperGraph, this.valid, this.topPlan, - this.bottomPlan, this.relations, this.relationIdStructInfoNodeMap, predicates, + this.bottomPlan, this.relations, this.relationIdStructInfoNodeMap, predicates, this.groupingId, this.shuttledExpressionsToExpressionsMap, this.expressionToShuttledExpressionToMap, this.tableBitSet, null, null, this.planOutputShuttledExpressions); } @@ -296,27 +301,11 @@ public class StructInfo { PlanSplitContext planSplitContext = new PlanSplitContext(set); // if single table without join, the bottom is derivedPlan.accept(PLAN_SPLITTER, planSplitContext); - return StructInfo.of(originalPlan, planSplitContext.getTopPlan(), planSplitContext.getBottomPlan(), - HyperGraph.builderForMv(planSplitContext.getBottomPlan()).build(), cascadesContext); - } - - /** - * The construct method for init StructInfo - */ - public static StructInfo of(Plan originalPlan, @Nullable Plan topPlan, @Nullable Plan bottomPlan, - HyperGraph hyperGraph, - CascadesContext cascadesContext) { + Plan topPlan = planSplitContext.getTopPlan(); + Plan bottomPlan = planSplitContext.getBottomPlan(); + HyperGraph hyperGraph = HyperGraph.builderForMv(planSplitContext.getBottomPlan()).build(); ObjectId originalPlanId = originalPlan.getGroupExpression() .map(GroupExpression::getId).orElseGet(() -> new ObjectId(-1)); - // if any of topPlan or bottomPlan is null, split the top plan to two parts by join node - if (topPlan == null || bottomPlan == null) { - Set<Class<? extends Plan>> set = Sets.newLinkedHashSet(); - set.add(LogicalJoin.class); - PlanSplitContext planSplitContext = new PlanSplitContext(set); - originalPlan.accept(PLAN_SPLITTER, planSplitContext); - bottomPlan = planSplitContext.getBottomPlan(); - topPlan = planSplitContext.getTopPlan(); - } // collect struct info fromGraph List<CatalogRelation> relationList = new ArrayList<>(); Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap = new LinkedHashMap<>(); @@ -346,10 +335,9 @@ public class StructInfo { List<? extends Expression> planOutputShuttledExpressions = ExpressionUtils.shuttleExpressionWithLineage(originalPlan.getOutput(), originalPlan, new BitSet()); return new StructInfo(originalPlan, originalPlanId, hyperGraph, valid, topPlan, bottomPlan, - relationList, relationIdStructInfoNodeMap, predicates, shuttledHashConjunctsToConjunctsMap, - expressionToShuttledExpressionToMap, - tableBitSet, null, null, - planOutputShuttledExpressions); + relationList, relationIdStructInfoNodeMap, predicates, planSplitContext.getGroupingId(), + shuttledHashConjunctsToConjunctsMap, expressionToShuttledExpressionToMap, + tableBitSet, null, null, planOutputShuttledExpressions); } public List<CatalogRelation> getRelations() { @@ -360,6 +348,10 @@ public class StructInfo { return predicates; } + public Optional<SlotReference> getGroupingId() { + return groupingId; + } + public Plan getOriginalPlan() { return originalPlan; } @@ -557,6 +549,9 @@ public class StructInfo { if (context.getTopPlan() == null) { context.setTopPlan(plan); } + if (plan instanceof LogicalRepeat) { + context.setGroupingId(((LogicalRepeat<?>) plan).getGroupingId()); + } if (plan.children().isEmpty() && context.getBottomPlan() == null) { context.setBottomPlan(plan); return null; @@ -592,6 +587,7 @@ public class StructInfo { private Plan bottomPlan; private Plan topPlan; private Set<Class<? extends Plan>> boundaryPlanClazzSet; + private Optional<SlotReference> groupingId = Optional.empty(); public PlanSplitContext(Set<Class<? extends Plan>> boundaryPlanClazzSet) { this.boundaryPlanClazzSet = boundaryPlanClazzSet; @@ -613,6 +609,14 @@ public class StructInfo { this.topPlan = topPlan; } + public Optional<SlotReference> getGroupingId() { + return groupingId; + } + + public void setGroupingId(Optional<SlotReference> groupingId) { + this.groupingId = groupingId; + } + /** * isBoundary */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalRepeatToPhysicalRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalRepeatToPhysicalRepeat.java index a299eb9eb6b..00d89034327 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalRepeatToPhysicalRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalRepeatToPhysicalRepeat.java @@ -31,6 +31,7 @@ public class LogicalRepeatToPhysicalRepeat extends OneImplementationRuleFactory new PhysicalRepeat<>( repeat.getGroupingSets(), repeat.getOutputExpressions(), + repeat.getGroupingId().get(), repeat.getLogicalProperties(), repeat.child() ) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java index d3c0343a8c3..093abaadd13 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java @@ -29,8 +29,6 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; -import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; @@ -38,7 +36,6 @@ import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Map; -import java.util.Optional; /** replace SlotReference ExprId in logical plans */ public class ExprIdRewriter extends ExpressionRewrite { @@ -110,30 +107,7 @@ public class ExprIdRewriter extends ExpressionRewrite { matchesType(SlotReference.class).thenApply(ctx -> { Slot slot = ctx.expr; return slot.accept(SLOT_REPLACER, replaceMap); - }).toRule(ExpressionRuleType.EXPR_ID_REWRITE_REPLACE), - matchesType(VirtualSlotReference.class).thenApply(ctx -> { - VirtualSlotReference virtualSlot = ctx.expr; - return virtualSlot.accept(new DefaultExpressionRewriter<Map<ExprId, ExprId>>() { - @Override - public Expression visitVirtualReference(VirtualSlotReference virtualSlot, - Map<ExprId, ExprId> replaceMap) { - Optional<GroupingScalarFunction> originExpression = virtualSlot.getOriginExpression(); - if (!originExpression.isPresent()) { - return virtualSlot; - } - GroupingScalarFunction groupingScalarFunction = originExpression.get(); - GroupingScalarFunction rewrittenFunction = - (GroupingScalarFunction) groupingScalarFunction.accept( - SLOT_REPLACER, replaceMap); - if (!rewrittenFunction.children().equals(groupingScalarFunction.children())) { - return virtualSlot.withOriginExpressionAndComputeLongValueMethod( - Optional.of(rewrittenFunction), - rewrittenFunction::computeVirtualSlotValue); - } - return virtualSlot; - } - }, replaceMap); - }).toRule(ExpressionRuleType.VIRTUAL_EXPR_ID_REWRITE_REPLACE) + }).toRule(ExpressionRuleType.EXPR_ID_REWRITE_REPLACE) ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java index fad3e02c0cf..2f7743cb86f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SetPreAggStatus.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; @@ -38,6 +37,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -55,7 +55,6 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -66,6 +65,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.Stack; +import java.util.stream.Collectors; /** * SetPreAggStatus @@ -83,6 +83,7 @@ public class SetPreAggStatus extends DefaultPlanRewriter<Stack<SetPreAggStatus.P private List<Expression> filterConjuncts = new ArrayList<>(); private List<Expression> joinConjuncts = new ArrayList<>(); private List<Expression> groupByExpresssions = new ArrayList<>(); + private List<Expression> groupingScalarFunctionExpresssions = new ArrayList<>(); private Set<AggregateFunction> aggregateFunctions = new HashSet<>(); private Set<RelationId> olapScanIds = new HashSet<>(); @@ -108,9 +109,18 @@ public class SetPreAggStatus extends DefaultPlanRewriter<Stack<SetPreAggStatus.P private void addGroupByExpresssions(List<Expression> expressions) { groupByExpresssions.addAll(expressions); + groupByExpresssions.removeAll(groupingScalarFunctionExpresssions); groupByExpresssions = Lists.newArrayList(ExpressionUtils.replace(groupByExpresssions, replaceMap)); } + private void addGroupingScalarFunctionExpresssions(List<Expression> expressions) { + groupingScalarFunctionExpresssions.addAll(expressions); + } + + private void addGroupingScalarFunctionExpresssion(Expression expression) { + groupingScalarFunctionExpresssions.add(expression); + } + private void addAggregateFunctions(Set<AggregateFunction> functions) { aggregateFunctions.addAll(functions); Set<AggregateFunction> newAggregateFunctions = Sets.newHashSet(); @@ -191,7 +201,7 @@ public class SetPreAggStatus extends DefaultPlanRewriter<Stack<SetPreAggStatus.P if (!context.isEmpty()) { PreAggInfoContext preAggInfoContext = context.pop(); preAggInfoContext.addAggregateFunctions(logicalAggregate.getAggregateFunctions()); - preAggInfoContext.addGroupByExpresssions(nonVirtualGroupByExprs(logicalAggregate)); + preAggInfoContext.addGroupByExpresssions(logicalAggregate.getGroupByExpressions()); for (RelationId id : preAggInfoContext.olapScanIds) { olapScanPreAggContexts.put(id, preAggInfoContext); } @@ -200,14 +210,17 @@ public class SetPreAggStatus extends DefaultPlanRewriter<Stack<SetPreAggStatus.P } @Override - public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> logicalRepeat, Stack<PreAggInfoContext> context) { - return super.visit(logicalRepeat, context); - } - - private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends Plan> agg) { - return agg.getGroupByExpressions().stream() - .filter(expr -> !(expr instanceof VirtualSlotReference)) - .collect(ImmutableList.toImmutableList()); + public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Stack<PreAggInfoContext> context) { + repeat = (LogicalRepeat<? extends Plan>) super.visit(repeat, context); + if (!context.isEmpty()) { + context.peek().addGroupingScalarFunctionExpresssion(repeat.getGroupingId().get()); + context.peek().addGroupingScalarFunctionExpresssions( + repeat.getOutputExpressions().stream() + .filter(e -> e.containsType(GroupingScalarFunction.class)) + .map(e -> e.toSlot()) + .collect(Collectors.toList())); + } + return repeat; } private static class SetOlapScanPreAgg extends DefaultPlanRewriter<Map<RelationId, PreAggInfoContext>> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index 2eb3c3f88c2..cb3234461b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; @@ -300,7 +299,11 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta @Override public ColumnStatistic visitSlotReference(SlotReference slotReference, Statistics context) { - return context.findColumnStatistics(slotReference); + ColumnStatistic columnStatistic = context.findColumnStatistics(slotReference); + if (columnStatistic == null) { + return ColumnStatistic.UNKNOWN; + } + return columnStatistic; } @Override @@ -485,11 +488,6 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta return alias.child().accept(this, context); } - @Override - public ColumnStatistic visitVirtualReference(VirtualSlotReference virtualSlotReference, Statistics context) { - return ColumnStatistic.UNKNOWN; - } - @Override public ColumnStatistic visitBoundFunction(BoundFunction boundFunction, Statistics context) { return ColumnStatistic.UNKNOWN; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index a5e3ee9c5e1..b16ac1d954c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -222,6 +222,20 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> { } } + /** + * Foreach treeNode. Top-down traverse implicitly. + * @param func foreach function + */ + default void foreachWithTest(Consumer<TreeNode<NODE_TYPE>> func, Predicate<TreeNode<NODE_TYPE>> predicate) { + if (!predicate.test(this)) { + return; + } + func.accept(this); + for (NODE_TYPE child : children()) { + child.foreach(func); + } + } + /** foreachBreath */ default void foreachBreath(Predicate<TreeNode<NODE_TYPE>> func) { LinkedList<TreeNode<NODE_TYPE>> queue = new LinkedList<>(); @@ -288,6 +302,19 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> { return (Set<T>) result.build(); } + /** + * Collect the nodes that satisfied the predicate. + */ + default <T> Set<T> collectWithTest(Predicate<TreeNode<NODE_TYPE>> predicate, Predicate<TreeNode<NODE_TYPE>> test) { + ImmutableSet.Builder<TreeNode<NODE_TYPE>> result = ImmutableSet.builder(); + foreachWithTest(node -> { + if (predicate.test(node)) { + result.add(node); + } + }, test); + return (Set<T>) result.build(); + } + /** * Collect the nodes that satisfied the predicate to list. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/ExpressionDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/ExpressionDeepCopier.java index db9f4306f16..59a82046ba0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/ExpressionDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/ExpressionDeepCopier.java @@ -27,14 +27,9 @@ import org.apache.doris.nereids.trees.expressions.ScalarSubquery; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; -import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import com.google.common.base.Function; - import java.util.List; import java.util.Map; import java.util.Optional; @@ -80,30 +75,6 @@ public class ExpressionDeepCopier extends DefaultExpressionRewriter<DeepCopierCo return slotReference.withExprId(newExprId); } - @Override - public Expression visitVirtualReference(VirtualSlotReference virtualSlotReference, DeepCopierContext context) { - Map<ExprId, ExprId> exprIdReplaceMap = context.exprIdReplaceMap; - ExprId newExprId; - if (exprIdReplaceMap.containsKey(virtualSlotReference.getExprId())) { - newExprId = exprIdReplaceMap.get(virtualSlotReference.getExprId()); - } else { - newExprId = StatementScopeIdGenerator.newExprId(); - } - // according to VirtualReference generating logic in Repeat.java - // generateVirtualGroupingIdSlot and generateVirtualSlotByFunction - Optional<GroupingScalarFunction> newOriginExpression = virtualSlotReference.getOriginExpression() - .map(func -> (GroupingScalarFunction) func.accept(this, context)); - Function<GroupingSetShapes, List<Long>> newFunction = newOriginExpression - .<Function<GroupingSetShapes, List<Long>>>map(f -> f::computeVirtualSlotValue) - .orElseGet(() -> GroupingSetShapes::computeVirtualGroupingIdValue); - VirtualSlotReference newOne = new VirtualSlotReference(newExprId, - virtualSlotReference.getName(), virtualSlotReference.getDataType(), - virtualSlotReference.nullable(), virtualSlotReference.getQualifier(), - newOriginExpression, newFunction); - exprIdReplaceMap.put(virtualSlotReference.getExprId(), newOne.getExprId()); - return newOne; - } - @Override public Expression visitArrayItemReference(ArrayItemReference arrayItemSlot, DeepCopierContext context) { Expression arrayExpression = arrayItemSlot.getArrayExpression().accept(this, context); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index 71984291d9c..33e32fe92e7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -206,7 +206,9 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext List<NamedExpression> outputExpressions = repeat.getOutputExpressions().stream() .map(e -> (NamedExpression) ExpressionDeepCopier.INSTANCE.deepCopy(e, context)) .collect(ImmutableList.toImmutableList()); - return new LogicalRepeat<>(groupingSets, outputExpressions, child); + SlotReference groupingId = (SlotReference) ExpressionDeepCopier.INSTANCE + .deepCopy(repeat.getGroupingId().get(), context); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, child); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java deleted file mode 100644 index bac559f407d..00000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java +++ /dev/null @@ -1,170 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.trees.expressions; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; -import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.trees.plans.algebra.Repeat.GroupingSetShapes; -import org.apache.doris.nereids.types.DataType; - -import com.google.common.base.Function; -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** - * it is not a real column exist in table. - */ -public class VirtualSlotReference extends SlotReference implements SlotNotFromChildren { - // arguments of GroupingScalarFunction - private final List<Expression> realExpressions; - - // if this VirtualSlotReference come from the GroupingScalarFunction, we will save it. - private final Optional<GroupingScalarFunction> originExpression; - - // save the method to compute the long value list, and then backend can fill the long - // value result for this VirtualSlotReference. - // this long values can compute by the shape of grouping sets. - private final Function<GroupingSetShapes, List<Long>> computeLongValueMethod; - - public VirtualSlotReference(String name, DataType dataType, Optional<GroupingScalarFunction> originExpression, - Function<GroupingSetShapes, List<Long>> computeLongValueMethod) { - this(StatementScopeIdGenerator.newExprId(), name, dataType, false, ImmutableList.of(), - originExpression, computeLongValueMethod); - } - - /** VirtualSlotReference */ - public VirtualSlotReference(ExprId exprId, String name, DataType dataType, - boolean nullable, List<String> qualifier, Optional<GroupingScalarFunction> originExpression, - Function<GroupingSetShapes, List<Long>> computeLongValueMethod) { - super(exprId, name, dataType, nullable, qualifier); - this.originExpression = Objects.requireNonNull(originExpression, "originExpression can not be null"); - this.realExpressions = originExpression.isPresent() - ? ImmutableList.copyOf(originExpression.get().getArguments()) - : ImmutableList.of(); - this.computeLongValueMethod = - Objects.requireNonNull(computeLongValueMethod, "computeLongValueMethod can not be null"); - } - - public List<Expression> getRealExpressions() { - return realExpressions; - } - - public Optional<GroupingScalarFunction> getOriginExpression() { - return originExpression; - } - - public Function<GroupingSetShapes, List<Long>> getComputeLongValueMethod() { - return computeLongValueMethod; - } - - @Override - public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { - return visitor.visitVirtualReference(this, context); - } - - @Override - public String computeToSql() { - return getName(); - } - - @Override - public String toString() { - // Just return name and exprId, add another method to show fully qualified name when it's necessary. - String str = getName() + "#" + getExprId(); - - if (originExpression.isPresent()) { - str += " originExpression=" + originExpression.get(); - } - return str; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - VirtualSlotReference that = (VirtualSlotReference) o; - return Objects.equals(realExpressions, that.realExpressions) - && Objects.equals(originExpression, that.originExpression) - && super.equals(that); - } - - @Override - public int computeHashCode() { - return Objects.hash(realExpressions, originExpression, getExprId()); - } - - @Override - public boolean nullable() { - return false; - } - - public VirtualSlotReference withNullable(boolean nullable) { - if (this.nullable == nullable) { - return this; - } - return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - @Override - public Slot withNullableAndDataType(boolean nullable, DataType dataType) { - if (this.nullable == nullable && this.dataType.equals(dataType)) { - return this; - } - return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - @Override - public VirtualSlotReference withQualifier(List<String> qualifier) { - return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - @Override - public VirtualSlotReference withName(String name) { - return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - @Override - public VirtualSlotReference withExprId(ExprId exprId) { - return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - public VirtualSlotReference withOriginExpressionAndComputeLongValueMethod( - Optional<GroupingScalarFunction> originExpression, - Function<GroupingSetShapes, List<Long>> computeLongValueMethod) { - return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, - originExpression, computeLongValueMethod); - } - - @Override - public Slot withIndexInSql(Pair<Integer, Integer> index) { - return this; - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java index 0592e9f6bb6..2655009a2c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Grouping.java @@ -51,7 +51,7 @@ public class Grouping extends GroupingScalarFunction implements UnaryExpression, } @Override - public List<Long> computeVirtualSlotValue(GroupingSetShapes shapes) { + public List<Long> computeValue(GroupingSetShapes shapes) { int index = shapes.indexOf(child()); return shapes.shapes.stream() .map(groupingSetShape -> computeLongValue(groupingSetShape, index)) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java index aa982801c9b..d3d8398b034 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingId.java @@ -57,7 +57,7 @@ public class GroupingId extends GroupingScalarFunction implements CustomSignatur } @Override - public List<Long> computeVirtualSlotValue(GroupingSetShapes shapes) { + public List<Long> computeValue(GroupingSetShapes shapes) { List<Expression> arguments = getArguments(); List<Integer> argumentIndexes = arguments.stream() .map(shapes::indexOf) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingScalarFunction.java index 4f810d04265..7844547c576 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/GroupingScalarFunction.java @@ -45,7 +45,7 @@ public abstract class GroupingScalarFunction extends ScalarFunction implements A /** * compute a long value that backend need to fill to the VirtualSlotRef */ - public abstract List<Long> computeVirtualSlotValue(GroupingSetShapes shapes); + public abstract List<Long> computeValue(GroupingSetShapes shapes); @Override public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdaf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdaf.java index 85143d2aca6..80cebf1c2fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdaf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdaf.java @@ -26,7 +26,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.util.URI; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.Udf; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; @@ -39,7 +39,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; /** @@ -142,10 +141,9 @@ public class JavaUdaf extends AggregateFunction implements ExplicitlyCastableSig ? sigBuilder.varArgs(argTypes.toArray(new DataType[0])) : sigBuilder.args(argTypes.toArray(new DataType[0])); - VirtualSlotReference[] virtualSlots = argTypes.stream() - .map(type -> new VirtualSlotReference(type.toString(), type, Optional.empty(), - (shape) -> ImmutableList.of())) - .toArray(VirtualSlotReference[]::new); + SlotReference[] arguments = argTypes.stream() + .map(type -> new SlotReference(type.toString(), type)) + .toArray(SlotReference[]::new); DataType intermediateType = null; if (aggregate.getIntermediateType() != null) { @@ -168,7 +166,7 @@ public class JavaUdaf extends AggregateFunction implements ExplicitlyCastableSig aggregate.getChecksum(), aggregate.isStaticLoad(), aggregate.getExpirationTime(), - virtualSlots); + arguments); JavaUdafBuilder builder = new JavaUdafBuilder(udaf); Env.getCurrentEnv().getFunctionRegistry().addUdf(dbName, fnName, builder); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdf.java index 10771985f1b..408fe7903cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdf.java @@ -26,7 +26,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.util.URI; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.Udf; import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; @@ -39,7 +39,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; /** @@ -127,10 +126,9 @@ public class JavaUdf extends ScalarFunction implements ExplicitlyCastableSignatu ? sigBuilder.varArgs(argTypes.toArray(new DataType[0])) : sigBuilder.args(argTypes.toArray(new DataType[0])); - VirtualSlotReference[] virtualSlots = argTypes.stream() - .map(type -> new VirtualSlotReference(type.toString(), type, Optional.empty(), - (shape) -> ImmutableList.of())) - .toArray(VirtualSlotReference[]::new); + SlotReference[] arguments = argTypes.stream() + .map(type -> new SlotReference(type.toString(), type)) + .toArray(SlotReference[]::new); JavaUdf udf = new JavaUdf(fnName, scalar.getId(), dbName, scalar.getBinaryType(), sig, scalar.getNullableMode(), @@ -139,7 +137,7 @@ public class JavaUdf extends ScalarFunction implements ExplicitlyCastableSignatu scalar.getPrepareFnSymbol(), scalar.getCloseFnSymbol(), scalar.getChecksum(), scalar.isStaticLoad(), scalar.getExpirationTime(), - virtualSlots); + arguments); JavaUdfBuilder builder = new JavaUdfBuilder(udf); Env.getCurrentEnv().getFunctionRegistry().addUdf(dbName, fnName, builder); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java index c90a8c343a3..287bdaec16c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java @@ -26,7 +26,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.util.URI; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.Udf; import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction; @@ -39,7 +39,6 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; /** @@ -148,10 +147,9 @@ public class JavaUdtf extends TableGeneratingFunction implements ExplicitlyCasta ? sigBuilder.varArgs(argTypes.toArray(new DataType[0])) : sigBuilder.args(argTypes.toArray(new DataType[0])); - VirtualSlotReference[] virtualSlots = argTypes.stream() - .map(type -> new VirtualSlotReference(type.toString(), type, Optional.empty(), - (shape) -> ImmutableList.of())) - .toArray(VirtualSlotReference[]::new); + SlotReference[] arguments = argTypes.stream() + .map(type -> new SlotReference(type.toString(), type)) + .toArray(SlotReference[]::new); JavaUdtf udf = new JavaUdtf(fnName, scalar.getId(), dbName, scalar.getBinaryType(), sig, scalar.getNullableMode(), @@ -162,7 +160,7 @@ public class JavaUdtf extends TableGeneratingFunction implements ExplicitlyCasta scalar.getChecksum(), scalar.isStaticLoad(), scalar.getExpirationTime(), - virtualSlots); + arguments); JavaUdtfBuilder builder = new JavaUdtfBuilder(udf); Env.getCurrentEnv().getFunctionRegistry().addUdf(dbName, fnName, builder); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java index 981e0e964ce..cf7478bba69 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java @@ -83,7 +83,6 @@ import org.apache.doris.nereids.trees.expressions.UnaryArithmetic; import org.apache.doris.nereids.trees.expressions.UnaryOperator; import org.apache.doris.nereids.trees.expressions.Variable; import org.apache.doris.nereids.trees.expressions.VariableDesc; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.WindowFrame; @@ -459,10 +458,6 @@ public abstract class ExpressionVisitor<R, C> return visit(groupingScalarFunction, context); } - public R visitVirtualReference(VirtualSlotReference virtualSlotReference, C context) { - return visitSlotReference(virtualSlotReference, context); - } - public R visitArrayItemReference(ArrayItemReference arrayItemReference, C context) { return visit(arrayItemReference, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java index 06a1c7d47ff..e35b48073b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java @@ -21,10 +21,8 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.BitUtils; import org.apache.doris.nereids.util.ExpressionUtils; @@ -38,7 +36,6 @@ import org.apache.commons.lang3.StringUtils; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -64,12 +61,6 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> { ImmutableList.Builder<NamedExpression> outputBuilder = ImmutableList.builderWithExpectedSize(prunedOutputs.size() + 1); outputBuilder.addAll(prunedOutputs); - for (NamedExpression output : getOutputExpressions()) { - Set<VirtualSlotReference> v = output.collect(VirtualSlotReference.class::isInstance); - if (v.stream().anyMatch(slot -> slot.getName().equals(COL_GROUPING_ID))) { - outputBuilder.add(output); - } - } // prune groupingSets, if parent operator do not need some exprs in grouping sets, we removed it. // this could not lead to wrong result because be repeat other columns by normal. ImmutableList.Builder<List<Expression>> groupingSetsBuilder @@ -90,17 +81,6 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> { Repeat<CHILD_PLAN> withGroupSetsAndOutput(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions); - static VirtualSlotReference generateVirtualGroupingIdSlot() { - return new VirtualSlotReference(COL_GROUPING_ID, BigIntType.INSTANCE, Optional.empty(), - GroupingSetShapes::computeVirtualGroupingIdValue); - } - - static VirtualSlotReference generateVirtualSlotByFunction(GroupingScalarFunction function) { - return new VirtualSlotReference( - generateVirtualSlotName(function), function.getDataType(), Optional.of(function), - function::computeVirtualSlotValue); - } - /** * get common grouping set expressions. * e.g. grouping sets((a, b, c), (b, c), (c)) @@ -120,33 +100,19 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> { return commonGroupingExpressions; } - /** - * getSortedVirtualSlots: order by virtual GROUPING_ID slot first. - */ - default Set<VirtualSlotReference> getSortedVirtualSlots() { - Set<VirtualSlotReference> virtualSlots = - ExpressionUtils.collect(getOutputExpressions(), VirtualSlotReference.class::isInstance); - - VirtualSlotReference virtualGroupingSetIdSlot = virtualSlots.stream() - .filter(slot -> slot.getName().equals(COL_GROUPING_ID)) - .findFirst() - .get(); - - return ImmutableSet.<VirtualSlotReference>builder() - .add(virtualGroupingSetIdSlot) - .addAll(Sets.difference(virtualSlots, ImmutableSet.of(virtualGroupingSetIdSlot))) - .build(); - } - /** * computeVirtualSlotValues. backend will fill this long value to the VirtualSlotRef */ - default List<List<Long>> computeVirtualSlotValues(Set<VirtualSlotReference> sortedVirtualSlots) { + default List<List<Long>> computeGroupingFunctionsValues() { GroupingSetShapes shapes = toShapes(); - - return sortedVirtualSlots.stream() - .map(virtualSlot -> virtualSlot.getComputeLongValueMethod().apply(shapes)) - .collect(ImmutableList.toImmutableList()); + List<GroupingScalarFunction> functions = ExpressionUtils.collectToList( + getOutputExpressions(), GroupingScalarFunction.class::isInstance); + List<List<Long>> groupingFunctionsValues = Lists.newArrayList(); + groupingFunctionsValues.add(shapes.computeGroupingIdValue()); + for (GroupingScalarFunction function : functions) { + groupingFunctionsValues.add(function.computeValue(shapes)); + } + return groupingFunctionsValues; } /** @@ -258,7 +224,7 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> { } /**compute a long value that backend need to fill to the GROUPING_ID slot*/ - public List<Long> computeVirtualGroupingIdValue() { + public List<Long> computeGroupingIdValue() { Set<Long> res = Sets.newLinkedHashSet(); long k = (long) Math.pow(2, flattenGroupingSetExpression.size()); for (GroupingSetShape shape : shapes) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java index f2b4a18e46b..5a099f5a4ad 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java @@ -23,7 +23,7 @@ import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Repeat; @@ -40,6 +40,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * LogicalRepeat. @@ -52,6 +53,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T private final List<List<Expression>> groupingSets; private final List<NamedExpression> outputExpressions; + private final Optional<SlotReference> groupingId; private final boolean withInProjection; /** @@ -61,16 +63,38 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, CHILD_TYPE child) { - this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), true, child); + this(groupingSets, outputExpressions, Optional.empty(), child); } /** * Desc: Constructor for LogicalRepeat. */ - public LogicalRepeat(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, - Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, - boolean withInProjection, + public LogicalRepeat( + List<List<Expression>> groupingSets, + List<NamedExpression> outputExpressions, + SlotReference groupingId, + CHILD_TYPE child) { + this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), + Optional.ofNullable(groupingId), true, child); + } + + /** + * Desc: Constructor for LogicalRepeat. + */ + private LogicalRepeat( + List<List<Expression>> groupingSets, + List<NamedExpression> outputExpressions, + Optional<SlotReference> groupingId, CHILD_TYPE child) { + this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), groupingId, true, child); + } + + /** + * Desc: Constructor for LogicalRepeat. + */ + private LogicalRepeat(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, + Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, + Optional<SlotReference> groupingId, boolean withInProjection, CHILD_TYPE child) { super(PlanType.LOGICAL_REPEAT, groupExpression, logicalProperties, child); this.groupingSets = Objects.requireNonNull(groupingSets, "groupingSets can not be null") .stream() @@ -78,6 +102,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T .collect(ImmutableList.toImmutableList()); this.outputExpressions = ImmutableList.copyOf( Objects.requireNonNull(outputExpressions, "outputExpressions can not be null")); + this.groupingId = groupingId; this.withInProjection = withInProjection; } @@ -91,6 +116,10 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T return outputExpressions; } + public Optional<SlotReference> getGroupingId() { + return groupingId; + } + @Override public List<NamedExpression> getOutputs() { return outputExpressions; @@ -100,7 +129,8 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T public String toString() { return Utils.toSqlString("LogicalRepeat", "groupingSets", groupingSets, - "outputExpressions", outputExpressions + "outputExpressions", outputExpressions, + "groupingId", groupingId ); } @@ -136,7 +166,7 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T @Override public List<Slot> computeOutput() { - return outputExpressions.stream() + return Stream.concat(outputExpressions.stream(), groupingId.map(Stream::of).orElse(Stream.empty())) .map(NamedExpression::toSlot) .collect(ImmutableList.toImmutableList()); } @@ -148,39 +178,37 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T @Override public List<? extends Expression> getExpressions() { - return new ImmutableList.Builder<Expression>() - .addAll(ExpressionUtils.flatExpressions(groupingSets)) - .addAll(outputExpressions) - .build(); + ImmutableList.Builder<Expression> builder = ImmutableList.builder(); + builder.addAll(ExpressionUtils.flatExpressions(groupingSets)).addAll(outputExpressions); + groupingId.ifPresent(builder::add); + return builder.build(); } @Override public boolean equals(Object o) { - if (this == o) { - return true; - } if (o == null || getClass() != o.getClass()) { return false; } LogicalRepeat<?> that = (LogicalRepeat<?>) o; - return groupingSets.equals(that.groupingSets) && outputExpressions.equals(that.outputExpressions); + return Objects.equals(groupingSets, that.groupingSets) && Objects.equals(outputExpressions, + that.outputExpressions) && Objects.equals(groupingId, that.groupingId); } @Override public int hashCode() { - return Objects.hash(groupingSets, outputExpressions); + return Objects.hash(groupingSets, outputExpressions, groupingId); } @Override public LogicalRepeat<Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 1); - return new LogicalRepeat<>(groupingSets, outputExpressions, children.get(0)); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, children.get(0)); } @Override public LogicalRepeat<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalRepeat<>(groupingSets, outputExpressions, groupExpression, - Optional.of(getLogicalProperties()), withInProjection, child()); + Optional.of(getLogicalProperties()), groupingId, withInProjection, child()); } @Override @@ -188,40 +216,35 @@ public class LogicalRepeat<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T Optional<LogicalProperties> logicalProperties, List<Plan> children) { Preconditions.checkArgument(children.size() == 1); return new LogicalRepeat<>(groupingSets, outputExpressions, groupExpression, logicalProperties, - withInProjection, children.get(0)); + groupingId, withInProjection, children.get(0)); } public LogicalRepeat<CHILD_TYPE> withGroupSets(List<List<Expression>> groupingSets) { - return new LogicalRepeat<>(groupingSets, outputExpressions, child()); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, child()); } public LogicalRepeat<CHILD_TYPE> withGroupSetsAndOutput(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressionList) { - return new LogicalRepeat<>(groupingSets, outputExpressionList, child()); + return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, child()); } @Override public LogicalRepeat<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) { - return new LogicalRepeat<>(groupingSets, newOutput, child()); + return new LogicalRepeat<>(groupingSets, newOutput, groupingId, child()); } public LogicalRepeat<Plan> withNormalizedExpr(List<List<Expression>> groupingSets, - List<NamedExpression> outputExpressionList, Plan child) { - return new LogicalRepeat<>(groupingSets, outputExpressionList, child); + List<NamedExpression> outputExpressionList, SlotReference groupingId, Plan child) { + return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, child); } public LogicalRepeat<Plan> withAggOutputAndChild(List<NamedExpression> newOutput, Plan child) { - return new LogicalRepeat<>(groupingSets, newOutput, child); + return new LogicalRepeat<>(groupingSets, newOutput, groupingId, child); } public LogicalRepeat<CHILD_TYPE> withInProjection(boolean withInProjection) { return new LogicalRepeat<>(groupingSets, outputExpressions, - Optional.empty(), Optional.empty(), withInProjection, child()); - } - - public boolean canBindVirtualSlot() { - return bound() && outputExpressions.stream() - .noneMatch(output -> output.containsType(VirtualSlotReference.class)); + Optional.empty(), Optional.empty(), groupingId, withInProjection, child()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java index a46c46e975c..95509540c43 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -403,10 +402,6 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar @Override public void computeUnique(DataTrait.Builder builder) { - if (groupByExpressions.stream().anyMatch(s -> s instanceof VirtualSlotReference)) { - // roll up may generate new data - return; - } DataTrait childFd = child(0).getLogicalProperties().getTrait(); ImmutableSet<Slot> groupByKeys = groupByExpressions.stream() .map(s -> (Slot) s) @@ -438,11 +433,6 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar // always propagate uniform DataTrait childFd = child(0).getLogicalProperties().getTrait(); builder.addUniformSlot(childFd); - - if (groupByExpressions.stream().anyMatch(s -> s instanceof VirtualSlotReference)) { - // roll up may generate new data - return; - } ImmutableSet<Slot> groupByKeys = groupByExpressions.stream() .map(s -> (Slot) s) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java index 8194667fb6d..7b841ac1fc6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Repeat; @@ -47,6 +48,7 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD private final List<List<Expression>> groupingSets; private final List<NamedExpression> outputExpressions; + private final SlotReference groupingId; /** * Desc: Constructor for PhysicalRepeat. @@ -54,6 +56,7 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD public PhysicalRepeat( List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, + SlotReference groupingId, LogicalProperties logicalProperties, CHILD_TYPE child) { super(PlanType.PHYSICAL_REPEAT, logicalProperties, child); @@ -63,12 +66,14 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD .collect(ImmutableList.toImmutableList()); this.outputExpressions = ImmutableList.copyOf( Objects.requireNonNull(outputExpressions, "outputExpressions can not be null")); + this.groupingId = Objects.requireNonNull(groupingId, "groupingId can not be null"); } /** * Desc: Constructor for PhysicalRepeat. */ - public PhysicalRepeat(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, + private PhysicalRepeat(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressions, + SlotReference groupingId, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, Statistics statistics, CHILD_TYPE child) { super(PlanType.PHYSICAL_REPEAT, groupExpression, logicalProperties, @@ -79,6 +84,7 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD .collect(ImmutableList.toImmutableList()); this.outputExpressions = ImmutableList.copyOf( Objects.requireNonNull(outputExpressions, "outputExpressions can not be null")); + this.groupingId = Objects.requireNonNull(groupingId, "groupingId can not be null"); } @Override @@ -91,6 +97,10 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD return outputExpressions; } + public SlotReference getGroupingId() { + return groupingId; + } + @Override public List<NamedExpression> getOutputs() { return outputExpressions; @@ -148,13 +158,13 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD @Override public PhysicalRepeat<Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 1); - return new PhysicalRepeat<>(groupingSets, outputExpressions, groupExpression, + return new PhysicalRepeat<>(groupingSets, outputExpressions, groupingId, groupExpression, getLogicalProperties(), physicalProperties, statistics, children.get(0)); } @Override public PhysicalRepeat<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) { - return new PhysicalRepeat<>(groupingSets, outputExpressions, groupExpression, + return new PhysicalRepeat<>(groupingSets, outputExpressions, groupingId, groupExpression, getLogicalProperties(), physicalProperties, statistics, child()); } @@ -162,33 +172,33 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, List<Plan> children) { Preconditions.checkArgument(children.size() == 1); - return new PhysicalRepeat<>(groupingSets, outputExpressions, groupExpression, + return new PhysicalRepeat<>(groupingSets, outputExpressions, groupingId, groupExpression, logicalProperties.get(), physicalProperties, statistics, children.get(0)); } @Override public PhysicalRepeat<CHILD_TYPE> withPhysicalPropertiesAndStats(PhysicalProperties physicalProperties, Statistics statistics) { - return new PhysicalRepeat<>(groupingSets, outputExpressions, groupExpression, + return new PhysicalRepeat<>(groupingSets, outputExpressions, groupingId, groupExpression, getLogicalProperties(), physicalProperties, statistics, child()); } @Override public PhysicalRepeat<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) { - return new PhysicalRepeat<>(groupingSets, newOutput, Optional.empty(), + return new PhysicalRepeat<>(groupingSets, newOutput, groupingId, Optional.empty(), getLogicalProperties(), physicalProperties, statistics, child()); } @Override public PhysicalRepeat<CHILD_TYPE> withGroupSetsAndOutput(List<List<Expression>> groupingSets, List<NamedExpression> outputExpressionList) { - return new PhysicalRepeat<>(groupingSets, outputExpressionList, Optional.empty(), + return new PhysicalRepeat<>(groupingSets, outputExpressionList, groupingId, Optional.empty(), getLogicalProperties(), physicalProperties, statistics, child()); } @Override public PhysicalRepeat<CHILD_TYPE> resetLogicalProperties() { - return new PhysicalRepeat<>(groupingSets, outputExpressions, groupExpression, + return new PhysicalRepeat<>(groupingSets, outputExpressions, groupingId, groupExpression, null, physicalProperties, statistics, child()); } @@ -199,7 +209,7 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD @Override public void computeUniform(DataTrait.Builder builder) { - builder.addUniformSlot(child(0).getLogicalProperties().getTrait()); + // roll up may generate new data } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 34afc3652dc..8f1ca1658de 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -913,7 +913,16 @@ public class ExpressionUtils { Predicate<TreeNode<Expression>> predicate) { ImmutableSet.Builder<E> set = ImmutableSet.builder(); for (Expression expr : expressions) { - set.addAll(expr.collectToList(predicate)); + set.addAll(expr.collect(predicate)); + } + return set.build(); + } + + public static <E> Set<E> collectWithTest(Collection<? extends Expression> expressions, + Predicate<TreeNode<Expression>> predicate, Predicate<TreeNode<Expression>> test) { + ImmutableSet.Builder<E> set = ImmutableSet.builder(); + for (Expression expr : expressions) { + set.addAll(expr.collectWithTest(predicate, test)); } return set.build(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/ResultFileSink.java b/fe/fe-core/src/main/java/org/apache/doris/planner/ResultFileSink.java index 9059d62e1d5..0ba8311097a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/ResultFileSink.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/ResultFileSink.java @@ -17,13 +17,9 @@ package org.apache.doris.planner; -import org.apache.doris.analysis.DescriptorTable; import org.apache.doris.analysis.OutFileClause; -import org.apache.doris.analysis.SlotDescriptor; import org.apache.doris.analysis.StorageBackend; -import org.apache.doris.analysis.TupleDescriptor; import org.apache.doris.analysis.TupleId; -import org.apache.doris.catalog.Column; import org.apache.doris.common.util.FileFormatConstants; import org.apache.doris.common.util.Util; import org.apache.doris.thrift.TDataSink; @@ -143,27 +139,4 @@ public class ResultFileSink extends DataSink { public DataPartition getOutputPartition() { return outputPartition; } - - /** - * Construct a tuple for file status, the tuple schema as following: - * | FileNumber | Int | - * | TotalRows | Bigint | - * | FileSize | Bigint | - * | URL | Varchar | - * | WriteTimeSec | Varchar | - * | WriteSpeedKB | Varchar | - */ - public static TupleDescriptor constructFileStatusTupleDesc(DescriptorTable descriptorTable) { - TupleDescriptor resultFileStatusTupleDesc = - descriptorTable.createTupleDescriptor("result_file_status"); - for (int i = 0; i < OutFileClause.RESULT_COL_NAMES.size(); ++i) { - SlotDescriptor slotDescriptor = descriptorTable.addSlotDescriptor(resultFileStatusTupleDesc); - slotDescriptor.setLabel(OutFileClause.RESULT_COL_NAMES.get(i)); - slotDescriptor.setType(OutFileClause.RESULT_COL_TYPES.get(i)); - slotDescriptor.setColumn(new Column(OutFileClause.RESULT_COL_NAMES.get(i), - OutFileClause.RESULT_COL_TYPES.get(i))); - slotDescriptor.setIsNullable(false); - } - return resultFileStatusTupleDesc; - } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java index e3f966e91ad..a68479421b7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java @@ -870,9 +870,12 @@ class ChildOutputPropertyDeriverTest { new ExprId(2), "c2", TinyIntType.INSTANCE, true, ImmutableList.of()); SlotReference c3 = new SlotReference( new ExprId(3), "c3", TinyIntType.INSTANCE, true, ImmutableList.of()); + SlotReference c4 = new SlotReference( + new ExprId(4), "c4", TinyIntType.INSTANCE, true, ImmutableList.of()); PhysicalRepeat<GroupPlan> repeat = new PhysicalRepeat<>( ImmutableList.of(ImmutableList.of(c1, c2), ImmutableList.of(c1), ImmutableList.of(c1, c3)), ImmutableList.of(c1, c2, c3), + c4, logicalProperties, groupPlan ); @@ -893,9 +896,12 @@ class ChildOutputPropertyDeriverTest { new ExprId(2), "c2", TinyIntType.INSTANCE, true, ImmutableList.of()); SlotReference c3 = new SlotReference( new ExprId(3), "c3", TinyIntType.INSTANCE, true, ImmutableList.of()); + SlotReference c4 = new SlotReference( + new ExprId(4), "c4", TinyIntType.INSTANCE, true, ImmutableList.of()); PhysicalRepeat<GroupPlan> repeat = new PhysicalRepeat<>( ImmutableList.of(ImmutableList.of(c1, c2), ImmutableList.of(c1), ImmutableList.of(c1, c3)), ImmutableList.of(c1, c2, c3), + c4, logicalProperties, groupPlan ); @@ -916,9 +922,12 @@ class ChildOutputPropertyDeriverTest { new ExprId(2), "c2", TinyIntType.INSTANCE, true, ImmutableList.of()); SlotReference c3 = new SlotReference( new ExprId(3), "c3", TinyIntType.INSTANCE, true, ImmutableList.of()); + SlotReference c4 = new SlotReference( + new ExprId(4), "c4", TinyIntType.INSTANCE, true, ImmutableList.of()); PhysicalRepeat<GroupPlan> repeat = new PhysicalRepeat<>( ImmutableList.of(ImmutableList.of(c1, c2, c3), ImmutableList.of(c1, c2), ImmutableList.of(c1, c2)), ImmutableList.of(c1, c2, c3), + c4, logicalProperties, groupPlan ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java index bc2dbe097f0..4ba8edfe34c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java @@ -20,10 +20,12 @@ package org.apache.doris.nereids.trees.copier; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; @@ -55,9 +57,11 @@ public class LogicalPlanDeepCopierTest { ImmutableList.of(groupingKeys.get(0)), ImmutableList.of() ); + SlotReference groupingId = new SlotReference("grouping_id", BigIntType.INSTANCE, false); LogicalRepeat<Plan> repeat = new LogicalRepeat<>( groupingSets, scan.getOutput().stream().map(NamedExpression.class::cast).collect(Collectors.toList()), + groupingId, scan ); List<? extends NamedExpression> groupByExprs = repeat.getOutput().subList(0, 1).stream() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
