924060929 commented on code in PR #14827:
URL: https://github.com/apache/doris/pull/14827#discussion_r1050647059
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java:
##########
@@ -56,100 +52,86 @@
* After rule:
* Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10)#6,
Alias(SR#11)#7, Alias(SR#10 + 1)#8)
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10,
Alias(SUM(v1#3 + 1))#11])
- * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
+ * +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
* <p>
* More examples can be found in the UT {NormalizeAggregateTest}
*/
-public class NormalizeAggregate extends OneRewriteRuleFactory {
+public class NormalizeAggregate extends OneRewriteRuleFactory implements
NormalizeToSlot {
@Override
public Rule build() {
return
logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
- // substitution map used to substitute expression in aggregate's
output to use it as top projections
- Map<Expression, Expression> substitutionMap = Maps.newHashMap();
- List<Expression> keys = aggregate.getGroupByExpressions();
- List<NamedExpression> newOutputs = Lists.newArrayList();
-
- // keys
- Map<Boolean, List<Expression>> partitionedKeys = keys.stream()
-
.collect(Collectors.groupingBy(SlotReference.class::isInstance));
- List<Expression> newKeys = Lists.newArrayList();
- List<NamedExpression> bottomProjections = Lists.newArrayList();
- if (partitionedKeys.containsKey(false)) {
- // process non-SlotReference keys
- newKeys.addAll(partitionedKeys.get(false).stream()
- .map(e -> new Alias(e, e.toSql()))
- .peek(a -> substitutionMap.put(a.child(), a.toSlot()))
- .peek(bottomProjections::add)
- .map(Alias::toSlot)
- .collect(Collectors.toList()));
- }
- if (partitionedKeys.containsKey(true)) {
- // process SlotReference keys
- partitionedKeys.get(true).stream()
- .map(SlotReference.class::cast)
- .peek(s -> substitutionMap.put(s, s))
- .peek(bottomProjections::add)
- .forEach(newKeys::add);
- }
- // add all necessary key to output
- substitutionMap.entrySet().stream()
- .filter(kv -> aggregate.getOutputExpressions().stream()
- .anyMatch(e -> e.anyMatch(kv.getKey()::equals)))
- .map(Entry::getValue)
- .map(NamedExpression.class::cast)
- .forEach(newOutputs::add);
-
- // if we generate bottom, we need to generate to project too.
- // output
- List<NamedExpression> outputs = aggregate.getOutputExpressions();
- Map<Boolean, List<NamedExpression>> partitionedOutputs =
outputs.stream()
- .collect(Collectors.groupingBy(e ->
e.anyMatch(AggregateFunction.class::isInstance)));
-
- boolean needBottomProjects = partitionedKeys.containsKey(false);
- if (partitionedOutputs.containsKey(true)) {
- // process expressions that contain aggregate function
- Set<AggregateFunction> aggregateFunctions =
partitionedOutputs.get(true).stream()
- .flatMap(e ->
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
- .collect(Collectors.toSet());
-
- // replace all non-slot expression in aggregate functions
children.
- for (AggregateFunction aggregateFunction : aggregateFunctions)
{
- List<Expression> newChildren = Lists.newArrayList();
- for (Expression child : aggregateFunction.getArguments()) {
- if (child instanceof SlotReference || child instanceof
Literal) {
- newChildren.add(child);
- if (child instanceof SlotReference) {
- bottomProjections.add((SlotReference) child);
- }
- } else {
- needBottomProjects = true;
- Alias alias = new Alias(child, child.toSql());
- bottomProjections.add(alias);
- newChildren.add(alias.toSlot());
- }
- }
- AggregateFunction newFunction = (AggregateFunction)
aggregateFunction.withChildren(newChildren);
- Alias alias = new Alias(newFunction, newFunction.toSql());
- newOutputs.add(alias);
- substitutionMap.put(aggregateFunction, alias.toSlot());
- }
- }
-
- // assemble
- LogicalPlan root = aggregate.child();
- if (needBottomProjects) {
- root = new LogicalProject<>(bottomProjections, root);
- }
- root = new LogicalAggregate<>(newKeys, newOutputs,
aggregate.isDisassembled(),
- true, aggregate.isFinalPhase(), aggregate.getAggPhase(),
- aggregate.getSourceRepeat(), root);
- List<NamedExpression> projections = outputs.stream()
- .map(e -> ExpressionUtils.replace(e, substitutionMap))
- .map(NamedExpression.class::cast)
- .collect(Collectors.toList());
- root = new LogicalProject<>(projections, root);
-
- return root;
+ // push expression to bottom project
+ Set<Alias> existsAliases = ExpressionUtils.collect(
Review Comment:
The result is:
```
LogicalAggregate(groupBy=(slot#1, slot#2, slot#3), output=[slot#1, slot#2,
slot#3, sum(slot#4)])
|
LogicalProject(projects=[(a + 1)#1, (a + 2)#2, (a + 3)#3, b#4])
```
and
```
LogicalAggregate(groupBy=(slot#0), output=[slot#0, sum(slot#1), sum(slot#2),
sum(slot#3)])
|
LogicalProject(projects=[a#0, (b + 1)#1, (b + 2)#2, (b + 3)#3])
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]