This is an automated email from the ASF dual-hosted git repository.
starocean999 pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new 803c0521001 [fix](nereids)modify agg function nullability in
PhysicalHashAggregate (#42018)
803c0521001 is described below
commit 803c05210015ac744dd61cb002c208cfe90b64cd
Author: starocean999 <[email protected]>
AuthorDate: Tue Oct 22 11:23:39 2024 +0800
[fix](nereids)modify agg function nullability in PhysicalHashAggregate
(#42018)
## Proposed changes
pick from master https://github.com/apache/doris/pull/41943
<!--Describe your changes.-->
---
.../doris/nereids/jobs/executor/Rewriter.java | 3 +-
.../plans/physical/PhysicalHashAggregate.java | 43 ++++++++++++++++--
.../rules/rewrite/AggregateStrategiesTest.java | 51 +++++++++++++++++++++-
3 files changed, 90 insertions(+), 7 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index a8860bf6cb3..3e1236cd265 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -403,7 +403,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new EliminateFilter(),
new PushDownFilterThroughProject(),
new MergeProjects(),
- new PruneOlapScanTablet()
+ new PruneOlapScanTablet(),
+ new AdjustAggregateNullableForEmptySet()
),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
index a79953dc71e..89e61ab86d2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashAggregate.java
@@ -22,16 +22,20 @@ import
org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
+import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
@@ -91,8 +95,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnar
super(PlanType.PHYSICAL_HASH_AGGREGATE, groupExpression,
logicalProperties, child);
this.groupByExpressions = ImmutableList.copyOf(
Objects.requireNonNull(groupByExpressions, "groupByExpressions
cannot be null"));
- this.outputExpressions = ImmutableList.copyOf(
- Objects.requireNonNull(outputExpressions, "outputExpressions
cannot be null"));
+ this.outputExpressions = adjustNullableForOutputs(
+ Objects.requireNonNull(outputExpressions, "outputExpressions
cannot be null"),
+ groupByExpressions.isEmpty());
this.partitionExpressions = Objects.requireNonNull(
partitionExpressions, "partitionExpressions cannot be null");
this.aggregateParam = Objects.requireNonNull(aggregateParam,
"aggregate param cannot be null");
@@ -118,8 +123,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan>
extends PhysicalUnar
child);
this.groupByExpressions = ImmutableList.copyOf(
Objects.requireNonNull(groupByExpressions, "groupByExpressions
cannot be null"));
- this.outputExpressions = ImmutableList.copyOf(
- Objects.requireNonNull(outputExpressions, "outputExpressions
cannot be null"));
+ this.outputExpressions = adjustNullableForOutputs(
+ Objects.requireNonNull(outputExpressions, "outputExpressions
cannot be null"),
+ groupByExpressions.isEmpty());
this.partitionExpressions = Objects.requireNonNull(
partitionExpressions, "partitionExpressions cannot be null");
this.aggregateParam = Objects.requireNonNull(aggregateParam,
"aggregate param cannot be null");
@@ -299,4 +305,33 @@ public class PhysicalHashAggregate<CHILD_TYPE extends
Plan> extends PhysicalUnar
requireProperties, physicalProperties, statistics,
child());
}
+
+ /**
+ * sql: select sum(distinct c1) from t;
+ * Assume c1 is NOT NULL. Because there is no group by,
+ * sum(distinct c1) is marked alwaysNullable in the rewrite phase.
+ * But in the implementation phase, we may create a 3-phase agg with group by
key c1.
+ * So the nullability of sum(distinct c1) must be changed depending on
whether there are any group by expressions.
+ * This PR updates the agg function's nullability accordingly.
+ */
+ private List<NamedExpression>
adjustNullableForOutputs(List<NamedExpression> outputs, boolean alwaysNullable)
{
+ return ExpressionUtils.rewriteDownShortCircuit(outputs, output -> {
+ if (output instanceof AggregateExpression) {
+ AggregateFunction function = ((AggregateExpression)
output).getFunction();
+ if (function instanceof NullableAggregateFunction
+ && ((NullableAggregateFunction)
function).isAlwaysNullable() != alwaysNullable) {
+ AggregateParam param = ((AggregateExpression)
output).getAggregateParam();
+ Expression child = ((AggregateExpression) output).child();
+ AggregateFunction newFunction =
((NullableAggregateFunction) function)
+ .withAlwaysNullable(alwaysNullable);
+ if (function == child) {
+ // the aggregate function is itself the child expression, so the child must be replaced too
+ child = newFunction;
+ }
+ return new AggregateExpression(newFunction, param, child);
+ }
+ }
+ return output;
+ });
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
index 34c16309181..e1e03e64d98 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategiesTest.java
@@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.AggMode;
@@ -54,6 +56,7 @@ import org.junit.jupiter.api.TestInstance;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class AggregateStrategiesTest implements MemoPatternMatchSupported {
@@ -138,7 +141,7 @@ public class AggregateStrategiesTest implements
MemoPatternMatchSupported {
Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList,
true, Optional.empty(), rStudent);
- Sum localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
+ Sum localOutput0 = new Sum(false, true,
rStudent.getOutput().get(0).toSlot());
PlanChecker.from(MemoTestUtils.createConnectContext(), root)
.applyImplementation(twoPhaseAggregateWithoutDistinct())
@@ -380,6 +383,40 @@ public class AggregateStrategiesTest implements
MemoPatternMatchSupported {
);
}
+ @Test
+ public void distinctApply4PhaseRuleNullableChange() {
+ Slot id = rStudent.getOutput().get(0).toSlot();
+ List<Expression> groupExpressionList = Lists.newArrayList();
+ List<NamedExpression> outputExpressionList = Lists.newArrayList(
+ new Alias(new Count(true, id), "count_id"),
+ new Alias(new Sum(id), "sum_id"));
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList,
+ true, Optional.empty(), rStudent);
+
+ // select count(distinct id), sum(id) from t;
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyImplementation(fourPhaseAggregateWithDistinct())
+ .matches(
+ physicalHashAggregate(
+ physicalHashAggregate(
+ physicalHashAggregate(
+ physicalHashAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(id))
+ .when(agg ->
verifyAlwaysNullableFlag(
+
agg.getAggregateFunctions(), false)))
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(id))
+ .when(agg ->
verifyAlwaysNullableFlag(agg.getAggregateFunctions(),
+ false)))
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+ .when(agg ->
agg.getGroupByExpressions().isEmpty())
+ .when(agg ->
verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)))
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_GLOBAL))
+ .when(agg ->
agg.getGroupByExpressions().isEmpty())
+ .when(agg ->
verifyAlwaysNullableFlag(agg.getAggregateFunctions(), true)));
+ }
+
private Rule twoPhaseAggregateWithoutDistinct() {
return new AggregateStrategies().buildRules()
.stream()
@@ -400,8 +437,18 @@ public class AggregateStrategiesTest implements
MemoPatternMatchSupported {
private Rule fourPhaseAggregateWithDistinct() {
return new AggregateStrategies().buildRules()
.stream()
- .filter(rule -> rule.getRuleType() ==
RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT)
+ .filter(rule -> rule.getRuleType() ==
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT)
.findFirst()
.get();
}
+
+ private boolean verifyAlwaysNullableFlag(Set<AggregateFunction> functions,
boolean alwaysNullable) {
+ for (AggregateFunction f : functions) {
+ if (f instanceof NullableAggregateFunction
+ && ((NullableAggregateFunction) f).isAlwaysNullable() !=
alwaysNullable) {
+ return false;
+ }
+ }
+ return true;
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]