This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new a1da57c63e [opt](Nereids)(WIP) optimize agg and window normalization
step 2 #19305
a1da57c63e is described below
commit a1da57c63ecef45a02b914ed377742925b76ed92
Author: Zhang Wenxin <[email protected]>
AuthorDate: Fri May 12 14:00:13 2023 +0800
[opt](Nereids)(WIP) optimize agg and window normalization step 2 #19305
1. refactor aggregate normalization to avoid data amplification before
aggregation
2. remove useless aggregate processing in
ExtractAndNormalizeWindowExpression
3. only push down the children of distinct aggregate functions
TODO:
1. push down redundant expressions in aggregate functions
2. refactor normalize repeat rule
3. move expression normalization and optimization after plan normalization
to avoid unexpected expression optimization.
---
.../doris/nereids/jobs/batch/NereidsRewriter.java | 26 +-
.../rules/analysis/ProjectToGlobalAggregate.java | 6 +-
.../rules/expression/rules/FunctionBinder.java | 1 +
.../rules/implementation/AggregateStrategies.java | 22 +-
.../rewrite/logical/EliminateGroupByConstant.java | 8 +-
.../ExtractAndNormalizeWindowExpression.java | 62 ++---
.../rules/rewrite/logical/NormalizeAggregate.java | 289 +++++++++------------
.../rules/rewrite/logical/NormalizeToSlot.java | 112 ++++++--
.../org/apache/doris/nereids/trees/TreeNode.java | 27 --
.../trees/expressions/WindowExpression.java | 7 -
.../functions/agg/AggregateFunction.java | 12 +-
.../ExtractAndNormalizeWindowExpressionTest.java | 2 +-
.../rewrite/logical/NormalizeAggregateTest.java | 9 +-
.../suites/nereids_syntax_p0/explain.groovy | 1 -
14 files changed, 289 insertions(+), 295 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
index 07c8903334..b43621562c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
@@ -84,27 +84,27 @@ import java.util.List;
*/
public class NereidsRewriter extends BatchRewriteJob {
private static final List<RewriteJob> REWRITE_JOBS = jobs(
- topic("Normalization",
+ topic("Plan Normalization",
topDown(
new EliminateOrderByConstant(),
new EliminateGroupByConstant(),
-
// MergeProjects depends on this rule
new LogicalSubQueryAliasToLogicalProject(),
-
- // rewrite expressions, no depends
+ // TODO: we should do expression normalization after plan
normalization
+ // because some rewritten depends on sub expression tree
matching
+ // such as group by key matching and replaced
+ // but we need to do some normalization before subquery
unnesting,
+ // such as extract common expression.
new ExpressionNormalization(),
new ExpressionOptimization(),
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
-
new ExtractFilterFromCrossJoin()
),
-
- // ExtractSingleTableExpressionFromDisjunction conflict to
InPredicateToEqualToRule
- // in the ExpressionNormalization, so must invoke in another
job, or else run into
- // dead loop
topDown(
+ // ExtractSingleTableExpressionFromDisjunction conflict to
InPredicateToEqualToRule
+ // in the ExpressionNormalization, so must invoke in
another job, or else run into
+ // dead loop
new ExtractSingleTableExpressionFromDisjunction()
)
),
@@ -131,15 +131,15 @@ public class NereidsRewriter extends BatchRewriteJob {
)
),
+ // we should eliminate hint again because some hint maybe exist in
the CTE or subquery.
+ // so this rule should invoke after "Subquery unnesting"
+ custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new),
+
// please note: this rule must run before NormalizeAggregate
topDown(
new AdjustAggregateNullableForEmptySet()
),
- // we should eliminate hint again because some hint maybe exist in
the CTE or subquery.
- // so this rule should invoke after "Subquery unnesting"
- custom(RuleType.ELIMINATE_HINT, EliminateLogicalSelectHint::new),
-
// The rule modification needs to be done after the subquery is
unnested,
// because for scalarSubQuery, the connection condition is stored
in apply in the analyzer phase,
// but when normalizeAggregate/normalizeSort is performed, the
members in apply cannot be obtained,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
index a4cf1d1a8c..66371ae000 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
@@ -49,7 +49,7 @@ public class ProjectToGlobalAggregate extends
OneAnalysisRuleFactory {
logicalProject().then(project -> {
boolean needGlobalAggregate = project.getProjects()
.stream()
- .anyMatch(p -> p.accept(NeedAggregateChecker.INSTANCE,
null));
+ .anyMatch(p ->
p.accept(ContainsAggregateChecker.INSTANCE, null));
if (needGlobalAggregate) {
return new LogicalAggregate<>(ImmutableList.of(),
project.getProjects(), project.child());
@@ -60,9 +60,9 @@ public class ProjectToGlobalAggregate extends
OneAnalysisRuleFactory {
);
}
- private static class NeedAggregateChecker extends
DefaultExpressionVisitor<Boolean, Void> {
+ private static class ContainsAggregateChecker extends
DefaultExpressionVisitor<Boolean, Void> {
- private static final NeedAggregateChecker INSTANCE = new
NeedAggregateChecker();
+ private static final ContainsAggregateChecker INSTANCE = new
ContainsAggregateChecker();
@Override
public Boolean visit(Expression expr, Void context) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java
index cc64666e60..c2c48de051 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FunctionBinder.java
@@ -58,6 +58,7 @@ import java.util.stream.Collectors;
* function binder
*/
public class FunctionBinder extends AbstractExpressionRewriteRule {
+
public static final FunctionBinder INSTANCE = new FunctionBinder();
@Override
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index 50bbbc0b20..1a1f085c5c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -215,6 +216,25 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
return canNotPush;
}
+ // TODO: refactor this to process slot reference or expression together
+ boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
+ .map(ExpressionTrait::getArguments)
+ .flatMap(List::stream)
+ .allMatch(argument -> {
+ if (argument instanceof SlotReference) {
+ return true;
+ }
+ if (argument instanceof Cast) {
+ return argument.child(0) instanceof SlotReference
+ && argument.getDataType().isNumericType()
+ &&
argument.child(0).getDataType().isNumericType();
+ }
+ return false;
+ });
+ if (!onlyContainsSlotOrNumericCastSlot) {
+ return canNotPush;
+ }
+
// we already normalize the arguments to slotReference
List<Expression> argumentsOfAggregateFunction =
aggregateFunctions.stream()
.flatMap(aggregateFunction ->
aggregateFunction.getArguments().stream())
@@ -228,7 +248,7 @@ public class AggregateStrategies implements
ImplementationRuleFactory {
.collect(ImmutableList.toImmutableList());
}
- boolean onlyContainsSlotOrNumericCastSlot =
argumentsOfAggregateFunction
+ onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
.allMatch(argument -> {
if (argument instanceof SlotReference) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java
index a9f3650cb4..5ba3689edd 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateGroupByConstant.java
@@ -57,8 +57,12 @@ public class EliminateGroupByConstant extends
OneRewriteRuleFactory {
Set<Expression> slotGroupByExprs = Sets.newLinkedHashSet();
Expression lit = null;
for (Expression expression : groupByExprs) {
- expression = FoldConstantRule.INSTANCE.rewrite(expression,
context);
- if (!(expression instanceof Literal)) {
+ // NOTICE: we should not use the expression after fold as new
aggregate's output or group expr
+ // because we rely on expression matching to replace subtree
that same as group by expr in output
+ // if we do constant folding before normalize aggregate, the
subtree will change and matching fail
+ // such as: select a + 1 + 2 + 3, sum(b) from t group by a +
1 + 2
+ Expression foldExpression =
FoldConstantRule.INSTANCE.rewrite(expression, context);
+ if (!(foldExpression instanceof Literal)) {
slotGroupByExprs.add(expression);
} else {
lit = expression;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
index 9282ef3825..da972816e6 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
@@ -25,9 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;
@@ -41,7 +39,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
- * extract window expressions from LogicalProject.projects and Normalize
LogicalWindow
+ * extract window expressions from LogicalProject#projects and Normalize
LogicalWindow
*/
public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory
implements NormalizeToSlot {
@@ -60,15 +58,8 @@ public class ExtractAndNormalizeWindowExpression extends
OneRewriteRuleFactory i
if (bottomProjects.isEmpty()) {
normalizedChild = project.child();
} else {
- boolean needAggregate = bottomProjects.stream().anyMatch(expr
->
- expr.anyMatch(AggregateFunction.class::isInstance));
- if (needAggregate) {
- normalizedChild = new
LogicalAggregate<>(ImmutableList.of(),
- ImmutableList.copyOf(bottomProjects),
project.child());
- } else {
- normalizedChild = project.withProjectsAndChild(
- ImmutableList.copyOf(bottomProjects),
project.child());
- }
+ normalizedChild = project.withProjectsAndChild(
+ ImmutableList.copyOf(bottomProjects), project.child());
}
// 2. handle window's outputs and windowExprs
@@ -96,35 +87,32 @@ public class ExtractAndNormalizeWindowExpression extends
OneRewriteRuleFactory i
// bottomProjects includes:
// 1. expressions from function and WindowSpec's partitionKeys and
orderKeys
// 2. other slots of outputExpressions
- /*
- avg(c) / sum(a+1) over (order by avg(b)) group by a
- win(x/sum(z) over y)
- prj(x, y, a+1 as z)
- agg(avg(c) x, avg(b) y, a)
- proj(a b c)
- toBePushDown = {avg(c), a+1, avg(b)}
- */
+ //
+ // avg(c) / sum(a+1) over (order by avg(b)) group by a
+ // win(x/sum(z) over y)
+ // prj(x, y, a+1 as z)
+ // agg(avg(c) x, avg(b) y, a)
+ // proj(a b c)
+ // toBePushDown = {avg(c), a+1, avg(b)}
return expressions.stream()
.flatMap(expression -> {
if (expression.anyMatch(WindowExpression.class::isInstance)) {
- Set<Slot> inputSlots =
expression.getInputSlots().stream().collect(Collectors.toSet());
+ Set<Slot> inputSlots =
Sets.newHashSet(expression.getInputSlots());
Set<WindowExpression> collects =
expression.collect(WindowExpression.class::isInstance);
- Set<Slot> windowInputSlots = collects.stream().flatMap(
- win -> win.getInputSlots().stream()
- ).collect(Collectors.toSet());
- /*
- substr(
- ref_1.cp_type,
- max(
- cast(ref_1.`cp_catalog_page_number` as int)) over
(...)
- ),
- 1)
-
- in above case, ref_1.cp_type should be pushed down.
ref_1.cp_type is in
- substr.inputSlots, but not in windowExpression.inputSlots
-
- inputSlots= {ref_1.cp_type}
- */
+ Set<Slot> windowInputSlots = collects.stream()
+ .flatMap(win -> win.getInputSlots().stream())
+ .collect(Collectors.toSet());
+ // substr(
+ // ref_1.cp_type,
+ // max(
+ // cast(ref_1.`cp_catalog_page_number` as int)) over
(...)
+ // ),
+ // 1)
+ //
+ // in above case, ref_1.cp_type should be pushed down.
ref_1.cp_type is in
+ // substr.inputSlots, but not in
windowExpression.inputSlots
+ //
+ // inputSlots= {ref_1.cp_type}
inputSlots.removeAll(windowInputSlots);
return Stream.concat(
collects.stream().flatMap(windowExpression ->
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index fccd933094..b9859a14cf 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -20,14 +20,14 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
-import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@@ -36,12 +36,12 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
+import com.google.common.collect.Maps;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
-import java.util.stream.Stream;
/**
* normalize aggregate's group keys and AggregateFunction's child to
SlotReference
@@ -95,173 +95,144 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
@Override
public Rule build() {
return
logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
- // push expression to bottom project
- Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
- aggregate.getOutputExpressions(), Alias.class::isInstance);
- Set<AggregateFunction> aggregateFunctionsInWindow =
collectAggregateFunctionsInWindow(
- aggregate.getOutputExpressions());
- Set<Expression> existsAggAlias =
existsAliases.stream().map(UnaryNode::child)
- .filter(AggregateFunction.class::isInstance)
- .collect(Collectors.toSet());
-
- /*
- * agg-functions inside window function is regarded as an output
of aggregate.
- * select sum(avg(c)) over ...
- * is regarded as
- * select avg(c), sum(avg(c)) over ...
- *
- * the plan:
- * project(sum(y) over)
- * Aggregate(avg(c) as y)
- *
- * after Aggregate, the 'y' is removed by upper project.
- *
- * aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
- */
- List<Alias> aliasOfAggFunInWindowUsedAsAggOutput =
Lists.newArrayList();
- for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
- if (!existsAggAlias.contains(aggFun)) {
- Alias alias = new Alias(aggFun, aggFun.toSql());
- existsAliases.add(alias);
- aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
+ List<NamedExpression> aggregateOutput =
aggregate.getOutputExpressions();
+ Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+ Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+ NormalizeToSlotContext groupByToSlotContext =
+ NormalizeToSlotContext.buildContext(existsAlias,
groupingByExprs);
+ Set<NamedExpression> bottomGroupByProjects =
+
groupByToSlotContext.pushDownToNamedExpression(groupingByExprs);
+
+ List<AggregateFunction> aggFuncs = Lists.newArrayList();
+ aggregateOutput.forEach(o ->
o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
+ // use group by context to normalize agg functions to process
+ // sql like: select sum(a + 1) from t group by a + 1
+ //
+ // before normalize:
+ // agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
+ // +-- project(a[#0], (a[#0] + 1)[#1])
+ //
+ // after normalize:
+ // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a +
1)[#1])
+ // +-- project((a[#0] + 1)[#1])
+ List<AggregateFunction> normalizedAggFuncs =
groupByToSlotContext.normalizeToUseSlotRef(aggFuncs);
+ List<NamedExpression> bottomProjects =
Lists.newArrayList(bottomGroupByProjects);
+ // TODO: if we have distinct agg, we must push down its children,
+ // because need use it to generate distribution enforce
+ // step 1: split agg functions into 2 parts: distinct and not
distinct
+ List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
+ List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
+ for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
+ if (aggregateFunction.isDistinct()) {
+ distinctAggFuncs.add(aggregateFunction);
+ } else {
+ nonDistinctAggFuncs.add(aggregateFunction);
}
}
- Set<Expression> needToSlots =
collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
- NormalizeToSlotContext groupByAndArgumentToSlotContext =
- NormalizeToSlotContext.buildContext(existsAliases,
needToSlots);
- Set<NamedExpression> bottomProjects =
-
groupByAndArgumentToSlotContext.pushDownToNamedExpression(needToSlots);
- Plan normalizedChild = bottomProjects.isEmpty()
- ? aggregate.child()
- : new
LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
-
- // begin normalize aggregate
-
- // replace groupBy and arguments of aggregate function to slot,
may be this output contains
- // some expression on the aggregate functions, e.g. `sum(value) +
1`, we should replace
- // the sum(value) to slot and move the `slot + 1` to the upper
project later.
- List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
- aggregate.getOutputExpressions().stream(),
- aliasOfAggFunInWindowUsedAsAggOutput.stream())
- .map(expr -> groupByAndArgumentToSlotContext
- .normalizeToUseSlotRefUp(expr,
WindowExpression.class::isInstance))
- .collect(Collectors.toList());
-
- Set<Slot> windowInputSlots =
collectWindowInputSlots(aggregate.getOutputExpressions());
- Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
- itemsInWindow.addAll(aggregateFunctionsInWindow);
- NormalizeToSlotContext windowToSlotContext =
- NormalizeToSlotContext.buildContext(existsAliases,
itemsInWindow);
- normalizeOutputPhase1 = normalizeOutputPhase1.stream()
- .map(expr -> windowToSlotContext
- .normalizeToUseSlotRefDown(expr,
WindowExpression.class::isInstance, true))
- .collect(Collectors.toList());
-
- Set<AggregateFunction> normalizedAggregateFunctions =
-
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
-
- existsAliases = ExpressionUtils.collect(normalizeOutputPhase1,
Alias.class::isInstance);
-
- // now reuse the exists alias for the aggregate functions,
- // or create new alias for the aggregate functions
- NormalizeToSlotContext aggregateFunctionToSlotContext =
- NormalizeToSlotContext.buildContext(existsAliases,
normalizedAggregateFunctions);
-
- Set<NamedExpression> normalizedAggregateFunctionsWithAlias =
-
aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions);
-
- List<Slot> normalizedGroupBy =
- (List) groupByAndArgumentToSlotContext
-
.normalizeToUseSlotRef(aggregate.getGroupByExpressions());
-
- // we can safely add all groupBy and aggregate functions to
output, because we will
- // add a project on it, and the upper project can protect the
scope of visible of slot
- List<NamedExpression> normalizedAggregateOutput =
ImmutableList.<NamedExpression>builder()
- .addAll(normalizedGroupBy)
- .addAll(normalizedAggregateFunctionsWithAlias)
+ // step 2: if we only have one distinct agg function, we do push
down for it
+ if (!distinctAggFuncs.isEmpty()) {
+ // process distinct normalize and put it back to
normalizedAggFuncs
+ List<AggregateFunction> newDistinctAggFuncs =
Lists.newArrayList();
+ Map<Expression, Expression> replaceMap = Maps.newHashMap();
+ Map<Expression, NamedExpression> aliasCache =
Maps.newHashMap();
+ for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
+ List<Expression> newChildren = Lists.newArrayList();
+ for (Expression child : distinctAggFunc.children()) {
+ if (child instanceof SlotReference) {
+ newChildren.add(child);
+ } else {
+ NamedExpression alias;
+ if (aliasCache.containsKey(child)) {
+ alias = aliasCache.get(child);
+ } else {
+ alias = new Alias(child, child.toSql());
+ aliasCache.put(child, alias);
+ }
+ bottomProjects.add(alias);
+ newChildren.add(alias.toSlot());
+ }
+ }
+ AggregateFunction newDistinctAggFunc =
distinctAggFunc.withChildren(newChildren);
+ replaceMap.put(distinctAggFunc, newDistinctAggFunc);
+ newDistinctAggFuncs.add(newDistinctAggFunc);
+ }
+ aggregateOutput = aggregateOutput.stream()
+ .map(e -> ExpressionUtils.replace(e, replaceMap))
+ .map(NamedExpression.class::cast)
+ .collect(Collectors.toList());
+ distinctAggFuncs = newDistinctAggFuncs;
+ }
+ normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
+ normalizedAggFuncs.addAll(distinctAggFuncs);
+ // TODO: process redundant expressions in aggregate functions
children
+ // build normalized agg output
+ NormalizeToSlotContext normalizedAggFuncsToSlotContext =
+ NormalizeToSlotContext.buildContext(existsAlias,
normalizedAggFuncs);
+ // agg output include 2 part, normalized group by slots and
normalized agg functions
+ List<NamedExpression> normalizedAggOutput =
ImmutableList.<NamedExpression>builder()
+
.addAll(bottomGroupByProjects.stream().map(NamedExpression::toSlot).iterator())
+
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.build();
+ // add normalized agg's input slots to bottom projects
+ Set<Slot> bottomProjectSlots = bottomProjects.stream()
+ .map(NamedExpression::toSlot)
+ .collect(Collectors.toSet());
+ Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
+ .map(Expression::getInputSlots)
+ .flatMap(Set::stream)
+ .filter(e -> !bottomProjectSlots.contains(e))
+ .collect(Collectors.toSet());
+ bottomProjects.addAll(aggInputSlots);
+ // build group by exprs
+ List<Expression> normalizedGroupExprs =
groupByToSlotContext.normalizeToUseSlotRef(groupingByExprs);
+ // build upper project, use two context to do pop up, because agg
output maybe contain two part:
+ // group by keys and agg expressions
+ List<NamedExpression> upperProjects = groupByToSlotContext
+
.normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput);
+ upperProjects =
normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects);
+ // process Expression like Alias(SlotReference#0)#0
+ upperProjects = upperProjects.stream().map(e -> {
+ if (e instanceof Alias) {
+ Alias alias = (Alias) e;
+ if (alias.child() instanceof SlotReference) {
+ SlotReference slotReference = (SlotReference)
alias.child();
+ if
(slotReference.getExprId().equals(alias.getExprId())) {
+ return slotReference;
+ }
+ }
+ }
+ return e;
+ }).collect(Collectors.toList());
+
+ Plan bottomPlan;
+ if (!bottomProjects.isEmpty()) {
+ bottomPlan = new
LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
+ } else {
+ bottomPlan = aggregate.child();
+ }
- LogicalAggregate<Plan> normalizedAggregate =
aggregate.withNormalized(
- (List) normalizedGroupBy, normalizedAggregateOutput,
normalizedChild);
-
-
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
- // exclude same-name functions in WindowExpression
- List<NamedExpression> upperProjects =
normalizeOutputPhase1.stream()
-
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
- return new LogicalProject<>(upperProjects, normalizedAggregate);
+ return new LogicalProject<>(upperProjects,
+ aggregate.withNormalized(normalizedGroupExprs,
normalizedAggOutput, bottomPlan));
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
- private Set<Expression>
collectGroupByAndArgumentsOfAggregateFunctions(LogicalAggregate<? extends Plan>
aggregate) {
- // 2 parts need push down:
- // groupingByExpr, argumentsOfAggregateFunction
-
- Set<Expression> groupingByExpr =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
-
- Set<AggregateFunction> aggregateFunctions =
collectNonWindowedAggregateFunctions(
- aggregate.getOutputExpressions());
+ private static class CollectNonWindowedAggFuncs extends
DefaultExpressionVisitor<Void, List<AggregateFunction>> {
- Set<Expression> argumentsOfAggregateFunction =
aggregateFunctions.stream()
- .flatMap(function -> function.getArguments().stream()
- .map(expr -> expr instanceof OrderExpression ?
expr.child(0) : expr))
- .collect(ImmutableSet.toImmutableSet());
+ private static final CollectNonWindowedAggFuncs INSTANCE = new
CollectNonWindowedAggFuncs();
- Set<Expression> windowFunctionKeys =
collectWindowFunctionKeys(aggregate.getOutputExpressions());
-
- Set<Expression> needPushDown = ImmutableSet.<Expression>builder()
- // group by should be pushed down, e.g. group by (k + 1),
- // we should push down the `k + 1` to the bottom plan
- .addAll(groupingByExpr)
- // e.g. sum(k + 1), we should push down the `k + 1` to the
bottom plan
- .addAll(argumentsOfAggregateFunction)
- .addAll(windowFunctionKeys)
- .build();
- return needPushDown;
- }
-
- private Set<Expression> collectWindowFunctionKeys(List<NamedExpression>
aggOutput) {
- Set<Expression> windowInputs = Sets.newHashSet();
- for (Expression expr : aggOutput) {
- Set<WindowExpression> windows =
expr.collect(WindowExpression.class::isInstance);
- for (WindowExpression win : windows) {
- windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk
-> pk.getInputSlots().stream()).collect(
- Collectors.toList()));
- windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok ->
ok.getInputSlots().stream()).collect(
- Collectors.toList()));
+ @Override
+ public Void visitWindow(WindowExpression windowExpression,
List<AggregateFunction> context) {
+ for (Expression child :
windowExpression.getExpressionsInWindowSpec()) {
+ child.accept(this, context);
}
+ return null;
}
- return windowInputs;
- }
- /**
- * select sum(c2), avg(min(c2)) over (partition by max(c1) order by
count(c1)) from T ...
- * extract {sum, min, max, count}. avg is not extracted.
- */
- private Set<AggregateFunction>
collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
- return ExpressionUtils.collect(aggOutput, expr -> {
- if (expr instanceof AggregateFunction) {
- return !((AggregateFunction) expr).isWindowFunction();
- }
- return false;
- });
- }
-
- private Set<AggregateFunction>
collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
-
- List<WindowExpression> windows = Lists.newArrayList(
- ExpressionUtils.collect(aggOutput,
WindowExpression.class::isInstance));
- return ExpressionUtils.collect(windows, expr -> {
- if (expr instanceof AggregateFunction) {
- return !((AggregateFunction) expr).isWindowFunction();
- }
- return false;
- });
- }
-
- private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput)
{
- List<WindowExpression> windows = Lists.newArrayList(
- ExpressionUtils.collect(aggOutput,
WindowExpression.class::isInstance));
- return windows.stream().flatMap(win ->
win.getInputSlots().stream()).collect(Collectors.toSet());
+ @Override
+ public Void visitAggregateFunction(AggregateFunction
aggregateFunction, List<AggregateFunction> context) {
+ context.add(aggregateFunction);
+ return null;
+ }
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
index 8ef966496e..974655a80b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
@@ -21,17 +21,20 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
-import java.util.function.Predicate;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
/** NormalizeToSlot */
@@ -45,9 +48,16 @@ public interface NormalizeToSlot {
this.normalizeToSlotMap = normalizeToSlotMap;
}
- /** buildContext */
+ /**
+ * build normalization context by follow step.
+ * 1. collect all exists alias by input parameters existsAliases
build a reverted map: expr -> alias
+ * 2. for all input source expressions, use existsAliasMap to
construct triple:
+ * origin expr, pushed expr and alias to replace origin expr,
+ * see more detail in {@link NormalizeToSlotTriplet}
+ * 3. construct a map: original expr -> triple constructed by step 2
+ */
public static NormalizeToSlotContext buildContext(
- Set<Alias> existsAliases, Set<? extends Expression>
sourceExpressions) {
+ Set<Alias> existsAliases, Collection<? extends Expression>
sourceExpressions) {
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap =
Maps.newLinkedHashMap();
Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap();
@@ -70,13 +80,21 @@ public interface NormalizeToSlot {
return normalizeToUseSlotRef(ImmutableList.of(expression)).get(0);
}
- /** normalizeToUseSlotRef, no custom normalize */
- public <E extends Expression> List<E> normalizeToUseSlotRef(List<E>
expressions) {
+ /**
+ * normalizeToUseSlotRef, without custom normalization.
+ * This function passes a lambda that always returns the original expression as customNormalize,
+ * so normalizeToSlotMap is always used to perform the normalization when this function is called.
+ */
+ public <E extends Expression> List<E>
normalizeToUseSlotRef(Collection<E> expressions) {
return normalizeToUseSlotRef(expressions, (context, expr) -> expr);
}
- /** normalizeToUseSlotRef */
- public <E extends Expression> List<E> normalizeToUseSlotRef(List<E>
expressions,
+ /**
+ * normalizeToUseSlotRef.
+ * try customNormalize first. If customNormalize cannot handle the current expression,
+ * use normalizeToSlotMap to get the default replacement expression.
+ */
+ public <E extends Expression> List<E>
normalizeToUseSlotRef(Collection<E> expressions,
BiFunction<NormalizeToSlotContext, Expression, Expression>
customNormalize) {
return expressions.stream()
.map(expr -> (E) expr.rewriteDownShortCircuit(child -> {
@@ -89,22 +107,11 @@ public interface NormalizeToSlot {
})).collect(ImmutableList.toImmutableList());
}
- public <E extends Expression> E normalizeToUseSlotRefUp(E expression,
Predicate skip) {
- return (E) expression.rewriteDownShortCircuitUp(child -> {
- NormalizeToSlotTriplet normalizeToSlotTriplet =
normalizeToSlotMap.get(child);
- return normalizeToSlotTriplet == null ? child :
normalizeToSlotTriplet.remainExpr;
- }, skip);
- }
-
- /**
- * rewrite subtrees whose root matches predicate border
- * when we traverse to the node satisfies border predicate,
aboveBorder becomes false
- */
- public <E extends Expression> E normalizeToUseSlotRefDown(E
expression, Predicate border, boolean aboveBorder) {
- return (E) expression.rewriteDownShortCircuitDown(child -> {
- NormalizeToSlotTriplet normalizeToSlotTriplet =
normalizeToSlotMap.get(child);
- return normalizeToSlotTriplet == null ? child :
normalizeToSlotTriplet.remainExpr;
- }, border, aboveBorder);
+ public <E extends Expression> List<E>
normalizeToUseSlotRefWithoutWindowFunction(
+ Collection<E> expressions) {
+ return expressions.stream()
+ .map(e -> (E)
e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap))
+ .collect(Collectors.toList());
}
/**
@@ -124,6 +131,54 @@ public interface NormalizeToSlot {
}
}
+ /**
+ * replace any expression except window functions,
+ * because a window function could be identical to an aggregate function and must never be replaced.
+ */
+ class NormalizeWithoutWindowFunction
+ extends DefaultExpressionRewriter<Map<Expression,
NormalizeToSlotTriplet>> {
+
+ public static final NormalizeWithoutWindowFunction INSTANCE = new
NormalizeWithoutWindowFunction();
+
+ private NormalizeWithoutWindowFunction() {
+ }
+
+ @Override
+ public Expression visit(Expression expr, Map<Expression,
NormalizeToSlotTriplet> replaceMap) {
+ if (replaceMap.containsKey(expr)) {
+ return replaceMap.get(expr).remainExpr;
+ }
+ return super.visit(expr, replaceMap);
+ }
+
+ @Override
+ public Expression visitWindow(WindowExpression windowExpression,
+ Map<Expression, NormalizeToSlotTriplet> replaceMap) {
+ if (replaceMap.containsKey(windowExpression)) {
+ return replaceMap.get(windowExpression).remainExpr;
+ }
+ List<Expression> newChildren = new ArrayList<>();
+ Expression function = super.visit(windowExpression.getFunction(),
replaceMap);
+ newChildren.add(function);
+ boolean hasNewChildren = function !=
windowExpression.getFunction();
+ for (Expression partitionKey :
windowExpression.getPartitionKeys()) {
+ Expression newChild = partitionKey.accept(this, replaceMap);
+ if (newChild != partitionKey) {
+ hasNewChildren = true;
+ }
+ newChildren.add(newChild);
+ }
+ for (Expression orderKey : windowExpression.getOrderKeys()) {
+ Expression newChild = orderKey.accept(this, replaceMap);
+ if (newChild != orderKey) {
+ hasNewChildren = true;
+ }
+ newChildren.add(newChild);
+ }
+ return hasNewChildren ? windowExpression.withChildren(newChildren)
: windowExpression;
+ }
+ }
+
/** NormalizeToSlotTriplet */
class NormalizeToSlotTriplet {
// which expression need to normalized to slot?
@@ -142,7 +197,12 @@ public interface NormalizeToSlot {
this.pushedExpr = pushedExpr;
}
- /** toTriplet */
+ /**
+ * construct a triplet according to one of three cases.
+ * 1. an alias already exists: use this alias as the pushed expr
+ * 2. the expression is a {@link NamedExpression}: use the expression itself as the pushed expr
+ * 3. any other expression: construct a new Alias wrapping the current expression as the pushed expr
+ */
public static NormalizeToSlotTriplet toTriplet(Expression expression,
@Nullable Alias existsAlias) {
if (existsAlias != null) {
return new NormalizeToSlotTriplet(expression,
existsAlias.toSlot(), existsAlias);
@@ -150,9 +210,7 @@ public interface NormalizeToSlot {
if (expression instanceof NamedExpression) {
NamedExpression namedExpression = (NamedExpression) expression;
- NormalizeToSlotTriplet normalizeToSlotTriplet =
- new NormalizeToSlotTriplet(expression,
namedExpression.toSlot(), namedExpression);
- return normalizeToSlotTriplet;
+ return new NormalizeToSlotTriplet(expression,
namedExpression.toSlot(), namedExpression);
}
Alias alias = new Alias(expression, expression.toSql());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 0394ebea87..3c64a043d6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -96,33 +96,6 @@ public interface TreeNode<NODE_TYPE extends
TreeNode<NODE_TYPE>> {
return currentNode;
}
- /**
- * same as rewriteDownShortCircuit,
- * except that subtrees, whose root satisfies predicate is satisfied, are
not rewritten
- */
- default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE>
rewriteFunction, Predicate skip) {
- NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
- if (skip.test(currentNode)) {
- return currentNode;
- }
- if (currentNode == this) {
- Builder<NODE_TYPE> newChildren =
ImmutableList.builderWithExpectedSize(arity());
- boolean changed = false;
- for (NODE_TYPE child : children()) {
- NODE_TYPE newChild =
child.rewriteDownShortCircuitUp(rewriteFunction, skip);
- if (child != newChild) {
- changed = true;
- }
- newChildren.add(newChild);
- }
-
- if (changed) {
- currentNode = currentNode.withChildren(newChildren.build());
- }
- }
- return currentNode;
- }
-
/**
* similar to rewriteDownShortCircuit, except that only subtrees, whose
root satisfies
* border predicate are rewritten.
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
index ffc0522498..831acbf5a8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
@@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.UnaryNode;
-import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@@ -55,9 +54,6 @@ public class WindowExpression extends Expression {
.addAll(orderKeys)
.build().toArray(new Expression[0]));
this.function = function;
- if (function instanceof AggregateFunction) {
- ((AggregateFunction) function).setWindowFunction(true);
- }
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.empty();
@@ -73,9 +69,6 @@ public class WindowExpression extends Expression {
.add(windowFrame)
.build().toArray(new Expression[0]));
this.function = function;
- if (function instanceof AggregateFunction) {
- ((AggregateFunction) function).setWindowFunction(true);
- }
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index a170ae0dd5..a7e523dfdb 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -38,7 +38,6 @@ import java.util.stream.Collectors;
public abstract class AggregateFunction extends BoundFunction implements
ExpectsInputTypes {
protected final boolean distinct;
- protected boolean isWindowFunction = false;
public AggregateFunction(String name, Expression... arguments) {
this(name, false, arguments);
@@ -78,14 +77,6 @@ public abstract class AggregateFunction extends
BoundFunction implements Expects
return distinct;
}
- public boolean isWindowFunction() {
- return isWindowFunction;
- }
-
- public void setWindowFunction(boolean windowFunction) {
- isWindowFunction = windowFunction;
- }
-
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -95,8 +86,7 @@ public abstract class AggregateFunction extends BoundFunction
implements Expects
return false;
}
AggregateFunction that = (AggregateFunction) o;
- return isWindowFunction == that.isWindowFunction
- && Objects.equals(distinct, that.distinct)
+ return Objects.equals(distinct, that.distinct)
&& Objects.equals(getName(), that.getName())
&& Objects.equals(children, that.children);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java
index 96c833a1de..be7da80ed1 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpressionTest.java
@@ -197,7 +197,7 @@ public class ExtractAndNormalizeWindowExpressionTest
implements MemoPatternMatch
// when Window's function is same as
AggregateFunction.
// In this example, agg function [sum(id+1)] is
same as Window's function [sum(id+1) over...]
List<NamedExpression> projects =
project.getProjects();
- return projects.get(1).child(0) instanceof
SlotReference
+ return projects.get(1) instanceof SlotReference
&& projects.get(2).equals(windowAlias);
})
)
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
index 254684eedb..ee0316e67f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregateTest.java
@@ -88,10 +88,9 @@ public class NormalizeAggregateTest implements
MemoPatternMatchSupported {
.equals(aggregateFunction.child(0)))
.when(FieldChecker.check("normalized",
true))
).when(project ->
project.getProjects().get(0).equals(key))
- .when(project -> project.getProjects().get(1)
instanceof Alias)
.when(project ->
(project.getProjects().get(1)).getExprId()
.equals(aggregateFunction.getExprId()))
- .when(project -> project.getProjects().get(1).child(0)
instanceof SlotReference)
+ .when(project -> project.getProjects().get(1)
instanceof SlotReference)
);
}
@@ -102,8 +101,8 @@ public class NormalizeAggregateTest implements
MemoPatternMatchSupported {
*
* after rewrite:
* LogicalProject ( projects=[(sum((id * 1))#6 + 2) AS `(sum((id * 1)) +
2)`#4] )
- * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum((id * 1)#5) AS
`sum((id * 1))`#6], groupByExpr=[name#2] )
- * +--LogicalProject ( projects=[name#2, (id#0 * 1) AS `(id * 1)`#5] )
+ * +--LogicalAggregate ( phase=LOCAL, outputExpr=[sum(id#0 * 1) AS
`sum((id * 1))`#6], groupByExpr=[name#2] )
+ * +--LogicalProject ( projects=[name#2, id#0] )
* +--GroupPlan( GroupId#0 )
*/
@Test
@@ -126,8 +125,6 @@ public class NormalizeAggregateTest implements
MemoPatternMatchSupported {
logicalProject(
logicalOlapScan()
).when(project ->
project.getProjects().size() == 2)
- .when(project ->
project.getProjects().get(0) instanceof SlotReference)
- .when(project ->
project.getProjects().get(1).child(0).equals(multiply))
).when(agg ->
agg.getGroupByExpressions().equals(
ImmutableList.of(rStudent.getOutput().get(2)))
)
diff --git a/regression-test/suites/nereids_syntax_p0/explain.groovy
b/regression-test/suites/nereids_syntax_p0/explain.groovy
index 91a2abb95d..3a98d42f62 100644
--- a/regression-test/suites/nereids_syntax_p0/explain.groovy
+++ b/regression-test/suites/nereids_syntax_p0/explain.groovy
@@ -25,7 +25,6 @@ suite("nereids_explain") {
explain {
sql("select count(2) + 1, sum(2) + sum(lo_suppkey) from lineorder")
contains "(sum(2) + sum(lo_suppkey))[#"
- contains "project output tuple id: 1"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]