This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 1655f1ea33e branch-3.1: [opt](nereids) opt for adjustting slot
nullable and add exception for changing slot nullable #52748 (#53533)
1655f1ea33e is described below
commit 1655f1ea33e327bccc05654c619791119b57e142
Author: yujun <[email protected]>
AuthorDate: Fri Jul 18 23:16:42 2025 +0800
branch-3.1: [opt](nereids) opt for adjustting slot nullable and add
exception for changing slot nullable #52748 (#53533)
cherry pick from #52748
---
.../doris/nereids/jobs/executor/Analyzer.java | 38 +-
.../doris/nereids/jobs/executor/Rewriter.java | 15 +-
.../org/apache/doris/nereids/rules/RuleType.java | 4 +-
.../AdjustAggregateNullableForEmptySet.java | 109 ++++--
.../nereids/rules/analysis/BindExpression.java | 28 +-
.../nereids/rules/analysis/ExpressionAnalyzer.java | 12 +
.../nereids/rules/analysis/FillUpMissingSlots.java | 8 +-
.../rules/expression/ExpressionRewrite.java | 2 +-
.../nereids/rules/rewrite/AdjustNullable.java | 335 +++++++++++++-----
.../nereids/rules/rewrite/EliminateGroupBy.java | 9 +-
.../nereids/rules/rewrite/EliminateJoinByFK.java | 67 +++-
.../trees/plans/logical/LogicalAggregate.java | 7 +-
.../nereids/trees/plans/logical/LogicalApply.java | 39 +-
.../nereids/trees/plans/logical/LogicalHaving.java | 2 +-
.../rules/analysis/AnalyzeSubQueryTest.java | 32 +-
.../rules/analysis/AnalyzeWhereSubqueryTest.java | 4 +-
.../rules/analysis/NormalizeAggregateTest.java | 392 +++++++++++++++++++++
.../AggScalarSubQueryToWindowFunctionTest.java | 1 +
.../rules/rewrite/EliminateJoinByFkTest.java | 96 ++++-
.../rules/rewrite/PullUpProjectUnderApplyTest.java | 2 +
.../rules/rewrite/mv/SelectMvIndexTest.java | 1 +
.../trees/plans/logical/LogicalAggregateTest.java | 101 ++++++
.../apache/doris/utframe/TestWithFeService.java | 1 +
.../adjust_nullable/test_agg_nullable.out | Bin 0 -> 297 bytes
.../adjust_nullable/test_agg_nullable.groovy | 30 ++
25 files changed, 1141 insertions(+), 194 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 41684e8d63d..b6066102632 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.jobs.executor;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
+import org.apache.doris.nereids.rules.RuleType;
import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AnalyzeCTE;
import org.apache.doris.nereids.rules.analysis.BindExpression;
@@ -44,6 +45,7 @@ import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
import org.apache.doris.nereids.rules.analysis.VariableToLiteral;
+import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;
import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
@@ -112,7 +114,19 @@ public class Analyzer extends AbstractBatchJobExecutor {
new EliminateDistinctConstant(),
new ProjectWithDistinctToAggregate(),
new ReplaceExpressionByChildOutput(),
- new OneRowRelationExtractAggregate()
+ new OneRowRelationExtractAggregate(),
+
+ // ProjectToGlobalAggregate may generate an aggregate with
empty group by expressions.
+ // for sort / having, need to adjust their agg functions'
nullable.
+ // for example: select sum(a) from t having sum(b) > 10
order by sum(c),
+ // then will have:
+ // sort(sum(c)) sort(sum(c))
+ // | |
+ // having(sum(b) > 10) ==> having(sum(b) > 10)
+ // | |
+ // project(sum(a)) agg(sum(a))
+ // then need to adjust SORT and HAVING's sum to nullable.
+ new AdjustAggregateNullableForEmptySet()
),
topDown(
new FillUpMissingSlots(),
@@ -121,7 +135,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
// LogicalProject for normalize. This rule depends on
FillUpMissingSlots to fill up slots.
new NormalizeRepeat()
),
- bottomUp(new AdjustAggregateNullableForEmptySet()),
// consider sql with user defined var @t_zone
// set @t_zone='GMT';
// SELECT
@@ -147,15 +160,18 @@ public class Analyzer extends AbstractBatchJobExecutor {
),
topDown(new LeadingJoin()),
bottomUp(new NormalizeGenerate()),
- bottomUp(new SubqueryToApply())
- /*
- * Notice, MergeProjects rule should NOT be placed after
SubqueryToApply in analyze phase.
- * because in SubqueryToApply, we may add assert_true function with
subquery output slot in projects list.
- * on the other hand, the assert_true function should be not be in
final output.
- * in order to keep the plan unchanged, we add a new project node to
prune the extra assert_true slot.
- * but MergeProjects rule will merge the two projects and keep
assert_true anyway.
- * so we move MergeProjects from analyze to rewrite phase.
- */
+ /*
+ * Notice, MergeProjects rule should NOT be placed after
SubqueryToApply in analyze phase.
+ * because in SubqueryToApply, we may add assert_true function
with subquery output slot in projects list.
+ * on the other hand, the assert_true function should be not be in
final output.
+ * in order to keep the plan unchanged, we add a new project node
to prune the extra assert_true slot.
+ * but MergeProjects rule will merge the two projects and keep
assert_true anyway.
+ * so we move MergeProjects from analyze to rewrite phase.
+ */
+ bottomUp(new SubqueryToApply()),
+ // for cte: analyze producer -> analyze consumer -> rewrite
consumer -> rewrite producer,
+ // in order to ensure cte consumer had right nullable attribute,
need adjust nullable at analyze phase.
+ custom(RuleType.ADJUST_NULLABLE, () -> new AdjustNullable(true))
);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index baa5a703ad2..26e7b14242f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.jobs.rewrite.CostBasedRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
-import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.CheckAfterRewrite;
import
org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
@@ -246,8 +245,6 @@ public class Rewriter extends AbstractBatchJobExecutor {
new EliminateSemiJoin()
)
),
- // please note: this rule must run before NormalizeAggregate
- topDown(new AdjustAggregateNullableForEmptySet()),
// The rule modification needs to be done after the subquery
is unnested,
// because for scalarSubQuery, the connection condition is
stored in apply in the analyzer phase,
// but when normalizeAggregate/normalizeSort is performed, the
members in apply cannot be obtained,
@@ -346,9 +343,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
- new MergeAggregate(),
- // need to adjust min/max/sum nullable
attribute after merge aggregate
- new AdjustAggregateNullableForEmptySet())
+ new MergeAggregate())
),
topic("Eager aggregation",
@@ -414,11 +409,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
new EliminateFilter(),
new PushDownFilterThroughProject(),
new MergeProjects(),
- new PruneOlapScanTablet(),
- // SelectMaterializedIndexWithAggregate may
change the nullability of agg functions
- // need rerun
AdjustAggregateNullableForEmptySet to make the nullability correct
- // TODO: remove
AdjustAggregateNullableForEmptySet when remove rbo mv selection rules
- new AdjustAggregateNullableForEmptySet()
+ new PruneOlapScanTablet()
),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
@@ -577,7 +568,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.REWRITE_CTE_CHILDREN, () ->
new RewriteCteChildren(afterPushDownJobs))
),
topic("whole plan check",
- custom(RuleType.ADJUST_NULLABLE,
AdjustNullable::new)
+ custom(RuleType.ADJUST_NULLABLE, () -> new
AdjustNullable(false))
),
// NullableDependentExpressionRewrite need to be done
after nullable fixed
topic("condition function", bottomUp(ImmutableList.of(
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 6a699bfd1c5..4b65c79a848 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -85,10 +85,8 @@ public enum RuleType {
ANALYZE_CTE(RuleTypeClass.REWRITE),
RELATION_AUTHENTICATION(RuleTypeClass.VALIDATION),
- ADJUST_NULLABLE_FOR_PROJECT_SLOT(RuleTypeClass.REWRITE),
- ADJUST_NULLABLE_FOR_AGGREGATE_SLOT(RuleTypeClass.REWRITE),
+ ADJUST_NULLABLE_FOR_SORT_SLOT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_FOR_HAVING_SLOT(RuleTypeClass.REWRITE),
- ADJUST_NULLABLE_FOR_REPEAT_SLOT(RuleTypeClass.REWRITE),
ADD_DEFAULT_LIMIT(RuleTypeClass.REWRITE),
CHECK_ROW_POLICY(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
index 5543341ae27..0ccb9f2d9de 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
@@ -17,23 +17,24 @@
package org.apache.doris.nereids.rules.analysis;
+import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Set;
-import java.util.stream.Collectors;
/**
* adjust aggregate nullable when: group expr list is empty and function is
NullableAggregateFunction,
@@ -43,57 +44,87 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
- RuleType.ADJUST_NULLABLE_FOR_AGGREGATE_SLOT.build(
- logicalAggregate()
- .then(agg -> {
- List<NamedExpression> outputExprs =
agg.getOutputExpressions();
- boolean noGroupBy =
agg.getGroupByExpressions().isEmpty();
- ImmutableList.Builder<NamedExpression>
newOutput
- =
ImmutableList.builderWithExpectedSize(outputExprs.size());
- for (NamedExpression ne : outputExprs) {
- NamedExpression newExpr =
- ((NamedExpression)
FunctionReplacer.INSTANCE.replace(ne, noGroupBy));
- newOutput.add(newExpr);
- }
- return
agg.withAggOutput(newOutput.build());
- })
- ),
RuleType.ADJUST_NULLABLE_FOR_HAVING_SLOT.build(
logicalHaving(logicalAggregate())
- .then(having -> {
- Set<Expression> conjuncts =
having.getConjuncts();
- boolean noGroupBy =
having.child().getGroupByExpressions().isEmpty();
- ImmutableSet.Builder<Expression>
newConjuncts
- =
ImmutableSet.builderWithExpectedSize(conjuncts.size());
- for (Expression expr : conjuncts) {
- Expression newExpr =
FunctionReplacer.INSTANCE.replace(expr, noGroupBy);
- newConjuncts.add(newExpr);
- }
- return new
LogicalHaving<>(newConjuncts.build(), having.child());
- })
+ .then(having -> replaceHaving(having,
having.child().getGroupByExpressions().isEmpty()))
+ ),
+ RuleType.ADJUST_NULLABLE_FOR_SORT_SLOT.build(
+ logicalSort(logicalAggregate())
+ .then(sort -> replaceSort(sort,
sort.child().getGroupByExpressions().isEmpty()))
+ ),
+ RuleType.ADJUST_NULLABLE_FOR_SORT_SLOT.build(
+ logicalSort(logicalHaving(logicalAggregate()))
+ .then(sort -> replaceSort(sort,
sort.child().child().getGroupByExpressions().isEmpty()))
)
);
}
+ public static Expression replaceExpression(Expression expression, boolean
alwaysNullable) {
+ return FunctionReplacer.INSTANCE.replace(expression, alwaysNullable);
+ }
+
+ private LogicalPlan replaceSort(LogicalSort<?> sort, boolean
alwaysNullable) {
+ ImmutableList.Builder<OrderKey> newOrderKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size());
+ sort.getOrderKeys().forEach(
+ key ->
newOrderKeysBuilder.add(key.withExpression(replaceExpression(key.getExpr(),
alwaysNullable))));
+ List<OrderKey> newOrderKeys = newOrderKeysBuilder.build();
+ if (newOrderKeys.equals(sort.getOrderKeys())) {
+ return null;
+ }
+ return sort.withOrderKeys(newOrderKeys);
+ }
+
+ private LogicalPlan replaceHaving(LogicalHaving<?> having, boolean
alwaysNullable) {
+ Set<Expression> conjuncts = having.getConjuncts();
+ ImmutableSet.Builder<Expression> newConjunctsBuilder
+ = ImmutableSet.builderWithExpectedSize(conjuncts.size());
+ for (Expression expr : conjuncts) {
+ Expression newExpr = replaceExpression(expr, alwaysNullable);
+ newConjunctsBuilder.add(newExpr);
+ }
+ ImmutableSet<Expression> newConjuncts = newConjunctsBuilder.build();
+ if (newConjuncts.equals(having.getConjuncts())) {
+ return null;
+ }
+ return (LogicalPlan) having.withConjuncts(newConjuncts);
+ }
+
+ /**
+ * replace NullableAggregateFunction nullable
+ */
private static class FunctionReplacer extends
DefaultExpressionRewriter<Boolean> {
public static final FunctionReplacer INSTANCE = new FunctionReplacer();
public Expression replace(Expression expression, boolean
alwaysNullable) {
- return expression.accept(INSTANCE, alwaysNullable);
+ return expression.accept(this, alwaysNullable);
}
@Override
public Expression visitWindow(WindowExpression windowExpression,
Boolean alwaysNullable) {
- return windowExpression.withPartitionKeysOrderKeys(
- windowExpression.getPartitionKeys().stream()
- .map(k -> k.accept(INSTANCE, alwaysNullable))
- .collect(Collectors.toList()),
- windowExpression.getOrderKeys().stream()
- .map(k -> (OrderExpression)
k.withChildren(k.children().stream()
- .map(c -> c.accept(INSTANCE,
alwaysNullable))
- .collect(Collectors.toList())))
- .collect(Collectors.toList())
- );
+ ImmutableList.Builder<Expression> newFunctionChildrenBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getFunction().children().size());
+ for (Expression child : windowExpression.getFunction().children())
{
+ newFunctionChildrenBuilder.add(child.accept(this,
alwaysNullable));
+ }
+ Expression newFunction =
windowExpression.getFunction().withChildren(newFunctionChildrenBuilder.build());
+ ImmutableList.Builder<Expression> newPartitionKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getPartitionKeys().size());
+ for (Expression partitionKey :
windowExpression.getPartitionKeys()) {
+ newPartitionKeysBuilder.add(partitionKey.accept(this,
alwaysNullable));
+ }
+ ImmutableList.Builder<OrderExpression> newOrderKeysBuilder
+ =
ImmutableList.builderWithExpectedSize(windowExpression.getOrderKeys().size());
+ for (OrderExpression orderKey : windowExpression.getOrderKeys()) {
+ ImmutableList.Builder<Expression> newChildrenBuilder
+ =
ImmutableList.builderWithExpectedSize(orderKey.children().size());
+ for (Expression child : orderKey.children()) {
+ newChildrenBuilder.add(child.accept(this, alwaysNullable));
+ }
+ newOrderKeysBuilder.add((OrderExpression)
orderKey.withChildren(newChildrenBuilder.build()));
+ }
+ return windowExpression.withFunctionPartitionKeysOrderKeys(
+ newFunction, newPartitionKeysBuilder.build(),
newOrderKeysBuilder.build());
}
@Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 6960b2987e6..b2d7779f902 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -55,6 +55,7 @@ import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import
org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
@@ -62,6 +63,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement
import
org.apache.doris.nereids.trees.expressions.functions.table.TableValuedFunction;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@@ -330,6 +332,7 @@ public class BindExpression implements AnalysisRuleFactory {
SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(
oneRowRelation, cascadesContext, ImmutableList.of(), true,
true);
List<NamedExpression> projects =
analyzer.analyzeToList(oneRowRelation.getProjects());
+ projects = adjustProjectionAggNullable(projects);
return new LogicalOneRowRelation(oneRowRelation.getRelationId(),
projects);
}
@@ -618,7 +621,7 @@ public class BindExpression implements AnalysisRuleFactory {
Supplier<Set<NamedExpression>> boundExcepts = Suppliers.memoize(
() -> analyzer.analyzeToSet(project.getExcepts()));
- Builder<NamedExpression> boundProjections =
ImmutableList.builderWithExpectedSize(project.arity());
+ Builder<NamedExpression> boundProjections =
ImmutableList.builderWithExpectedSize(project.getProjects().size());
StatementContext statementContext = ctx.statementContext;
for (Expression expression : project.getProjects()) {
Expression expr = analyzer.analyze(expression);
@@ -640,7 +643,28 @@ public class BindExpression implements AnalysisRuleFactory {
});
}
}
- return project.withProjects(boundProjections.build());
+ List<NamedExpression> projects =
adjustProjectionAggNullable(boundProjections.build());
+ return project.withProjects(projects);
+ }
+
+ private List<NamedExpression>
adjustProjectionAggNullable(List<NamedExpression> expressions) {
+ boolean hasAggregation = expressions.stream()
+ .anyMatch(expr ->
expr.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
+ if (!hasAggregation) {
+ return expressions;
+ }
+ Builder<NamedExpression> newExpressionsBuilder =
ImmutableList.builderWithExpectedSize(expressions.size());
+ for (NamedExpression expr : expressions) {
+ expr = (NamedExpression) expr.rewriteDownShortCircuit(e -> {
+ // for `select sum(a) from t`, sum(a) is nullable
+ if (e instanceof NullableAggregateFunction) {
+ return ((NullableAggregateFunction)
e).withAlwaysNullable(true);
+ }
+ return e;
+ });
+ newExpressionsBuilder.add(expr);
+ }
+ return newExpressionsBuilder.build();
}
private Plan bindFilter(MatchingContext<LogicalFilter<Plan>> ctx) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
index b27b328f756..ea9638e9360 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
@@ -66,8 +66,10 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.Variable;
import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import
org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
@@ -527,6 +529,16 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
return TypeCoercionUtils.processBoundFunction(boundFunction);
}
+ @Override
+ public Expression visitWindow(WindowExpression windowExpression,
ExpressionRewriteContext context) {
+ windowExpression = (WindowExpression)
super.visitWindow(windowExpression, context);
+ Expression function = windowExpression.getFunction();
+ if (function instanceof NullableAggregateFunction) {
+ return windowExpression.withFunction(((NullableAggregateFunction)
function).withAlwaysNullable(true));
+ }
+ return windowExpression;
+ }
+
/**
* gets the method for calculating the time.
* e.g. YEARS_ADD、YEARS_SUB、DAYS_ADD 、DAYS_SUB
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
index c55ed5957ba..24eeaef6dfa 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
@@ -219,10 +219,14 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
// avoid throw exception even if having have slot
from its child.
// because we will add a project between having
and project.
Resolver resolver = new Resolver(agg, false,
outerScope);
- having.getConjuncts().forEach(resolver::resolve);
+ Set<Expression> adjustAggNullableConjuncts =
having.getConjuncts().stream()
+ .map(conjunct ->
AdjustAggregateNullableForEmptySet.replaceExpression(
+ conjunct, true))
+ .collect(Collectors.toSet());
+
adjustAggNullableConjuncts.forEach(resolver::resolve);
agg =
agg.withAggOutput(resolver.getNewOutputSlots());
Set<Expression> newConjuncts =
ExpressionUtils.replace(
- having.getConjuncts(),
resolver.getSubstitution());
+ adjustAggNullableConjuncts,
resolver.getSubstitution());
ImmutableList.Builder<NamedExpression> projects =
ImmutableList.builder();
projects.addAll(project.getOutputs()).addAll(agg.getOutput());
return new LogicalHaving<>(newConjuncts, new
LogicalProject<>(projects.build(), agg));
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java
index 5aa43fb05b6..376c22c7079 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java
@@ -273,7 +273,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
if (newConjuncts.equals(having.getConjuncts())) {
return having;
}
- return having.withExpressions(newConjuncts);
+ return having.withConjuncts(newConjuncts);
}).toRule(RuleType.REWRITE_HAVING_EXPRESSION);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
index 808288b8fe3..df711a921e9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java
@@ -17,19 +17,22 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.util.DebugUtil;
+import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
-import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
@@ -47,19 +50,23 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.qe.ConnectContext;
-import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
/**
* because some rule could change output's nullable.
@@ -67,6 +74,14 @@ import java.util.Set;
*/
public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>>
implements CustomRewriter {
+ private static final Logger LOG =
LogManager.getLogger(AdjustNullable.class);
+
+ private final boolean isAnalyzedPhase;
+
+ public AdjustNullable(boolean isAnalyzedPhase) {
+ this.isAnalyzedPhase = isAnalyzedPhase;
+ }
+
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
return plan.accept(this, Maps.newHashMap());
@@ -83,66 +98,145 @@ public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink,
Map<ExprId, Slot> replaceMap) {
logicalSink = (LogicalSink<? extends Plan>) super.visit(logicalSink,
replaceMap);
- List<NamedExpression> newOutputExprs =
updateExpressions(logicalSink.getOutputExprs(), replaceMap);
- return logicalSink.withOutputExprs(newOutputExprs);
+ Optional<List<NamedExpression>> newOutputExprs
+ = updateExpressions(logicalSink.getOutputExprs(), replaceMap,
true);
+ if (!newOutputExprs.isPresent()) {
+ return logicalSink;
+ } else {
+ return logicalSink.withOutputExprs(newOutputExprs.get());
+ }
}
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan>
aggregate, Map<ExprId, Slot> replaceMap) {
aggregate = (LogicalAggregate<? extends Plan>) super.visit(aggregate,
replaceMap);
- List<NamedExpression> newOutputs
- = updateExpressions(aggregate.getOutputExpressions(),
replaceMap);
- List<Expression> newGroupExpressions
- = updateExpressions(aggregate.getGroupByExpressions(),
replaceMap);
- newOutputs.forEach(o -> replaceMap.put(o.getExprId(), o.toSlot()));
- return aggregate.withGroupByAndOutput(newGroupExpressions, newOutputs);
+ Optional<List<NamedExpression>> newOutputs
+ = updateExpressions(aggregate.getOutputExpressions(),
replaceMap, true);
+ Optional<List<Expression>> newGroupBy =
updateExpressions(aggregate.getGroupByExpressions(), replaceMap, true);
+ for (NamedExpression newOutput :
newOutputs.orElse(aggregate.getOutputExpressions())) {
+ replaceMap.put(newOutput.getExprId(), newOutput.toSlot());
+ }
+ if (!newOutputs.isPresent() && !newGroupBy.isPresent()) {
+ return aggregate;
+ }
+ return aggregate.withGroupByAndOutput(
+
newGroupBy.orElse(newGroupBy.orElse(aggregate.getGroupByExpressions())),
+ newOutputs.orElse(newOutputs.orElse(aggregate.getOutputs()))
+ );
}
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter,
Map<ExprId, Slot> replaceMap) {
filter = (LogicalFilter<? extends Plan>) super.visit(filter,
replaceMap);
- Set<Expression> conjuncts = updateExpressions(filter.getConjuncts(),
replaceMap);
- return filter.withConjuncts(conjuncts).recomputeLogicalProperties();
+ Optional<Set<Expression>> conjuncts =
updateExpressions(filter.getConjuncts(), replaceMap, true);
+ if (!conjuncts.isPresent()) {
+ return filter;
+ }
+ return filter.withConjunctsAndChild(conjuncts.get(), filter.child());
}
@Override
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate,
Map<ExprId, Slot> replaceMap) {
generate = (LogicalGenerate<? extends Plan>) super.visit(generate,
replaceMap);
- List<Function> newGenerators =
updateExpressions(generate.getGenerators(), replaceMap);
- Plan newGenerate =
generate.withGenerators(newGenerators).recomputeLogicalProperties();
- newGenerate.getOutputSet().forEach(o -> replaceMap.put(o.getExprId(),
o));
+ Optional<List<Function>> newGenerators =
updateExpressions(generate.getGenerators(), replaceMap, true);
+ Plan newGenerate = generate;
+ if (newGenerators.isPresent()) {
+ newGenerate =
generate.withGenerators(newGenerators.get()).recomputeLogicalProperties();
+ }
+ for (Slot slot : newGenerate.getOutput()) {
+ replaceMap.put(slot.getExprId(), slot);
+ }
return newGenerate;
}
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan>
join, Map<ExprId, Slot> replaceMap) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join,
replaceMap);
- List<Expression> hashConjuncts =
updateExpressions(join.getHashJoinConjuncts(), replaceMap);
- List<Expression> markConjuncts;
- if (hashConjuncts.isEmpty()) {
- // if hashConjuncts is empty, mark join conjuncts may used to
build hash table
+ Optional<List<Expression>> hashConjuncts =
updateExpressions(join.getHashJoinConjuncts(), replaceMap, true);
+ Optional<List<Expression>> markConjuncts = Optional.empty();
+ boolean hadUpdatedMarkConjuncts = false;
+ if (isAnalyzedPhase || join.getHashJoinConjuncts().isEmpty()) {
+ // if hashConjuncts is empty, mark join conjuncts may use to build
hash table
// so need call updateExpressions for mark join conjuncts before
adjust nullable by output slot
- markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap);
- } else {
- markConjuncts = null;
+ markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap, true);
+ hadUpdatedMarkConjuncts = true;
+ }
+ // in fact, otherConjuncts shouldn't use the join output's nullable
attribute;
+ // they should use the left and right children's original nullable attribute.
+ // but for historical reasons, BE uses the join output's nullable attribute when
evaluating the other conditions.
+ // so here we distinguish two cases:
+ // 1) at the analyzed phase, still update other conjuncts without
using join output nullables.
+ // later, at the rewrite phase, the join conditions may be pushed down,
and the pushed-down conditions will carry the proper
+ // nullable attribute.
+ // 2) at the end of the rewrite phase, update other conjuncts with
join output nullables.
+ // This keeps the FE consistent with BE.
+ Optional<List<Expression>> otherConjuncts = Optional.empty();
+ if (isAnalyzedPhase) {
+ otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(),
replaceMap, true);
}
- join.getOutputSet().forEach(o -> replaceMap.put(o.getExprId(), o));
- if (markConjuncts == null) {
- // hashConjuncts is not empty, mark join conjuncts are processed
like other join conjuncts
- Preconditions.checkState(!hashConjuncts.isEmpty(), "hash conjuncts
should not be empty");
- markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap);
+ for (Slot slot : join.getOutput()) {
+ replaceMap.put(slot.getExprId(), slot);
}
- List<Expression> otherConjuncts =
updateExpressions(join.getOtherJoinConjuncts(), replaceMap);
- return join.withJoinConjuncts(hashConjuncts, otherConjuncts,
markConjuncts,
- join.getJoinReorderContext()).recomputeLogicalProperties();
+ if (!hadUpdatedMarkConjuncts) {
+ markConjuncts = updateExpressions(join.getMarkJoinConjuncts(),
replaceMap, false);
+ }
+ if (!isAnalyzedPhase) {
+ otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(),
replaceMap, false);
+ }
+ if (!hashConjuncts.isPresent() && !markConjuncts.isPresent() &&
!otherConjuncts.isPresent()) {
+ return join;
+ }
+ return join.withJoinConjuncts(
+ hashConjuncts.orElse(join.getHashJoinConjuncts()),
+ otherConjuncts.orElse(join.getOtherJoinConjuncts()),
+ markConjuncts.orElse(join.getMarkJoinConjuncts()),
+ join.getJoinReorderContext()
+ ).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project,
Map<ExprId, Slot> replaceMap) {
project = (LogicalProject<? extends Plan>) super.visit(project,
replaceMap);
- List<NamedExpression> newProjects =
updateExpressions(project.getProjects(), replaceMap);
- newProjects.forEach(p -> replaceMap.put(p.getExprId(), p.toSlot()));
- return project.withProjects(newProjects);
+ Optional<List<NamedExpression>> newProjects =
updateExpressions(project.getProjects(), replaceMap, true);
+ for (NamedExpression newProject :
newProjects.orElse(project.getProjects())) {
+ replaceMap.put(newProject.getExprId(), newProject.toSlot());
+ }
+ if (!newProjects.isPresent()) {
+ return project;
+ }
+ return project.withProjects(newProjects.get());
+ }
+
+ @Override
+ public Plan visitLogicalApply(LogicalApply<? extends Plan, ? extends Plan>
apply, Map<ExprId, Slot> replaceMap) {
+ apply = (LogicalApply<? extends Plan, ? extends Plan>)
super.visit(apply, replaceMap);
+ Optional<Expression> newCompareExpr =
updateExpression(apply.getCompareExpr(), replaceMap, true);
+ Optional<Expression> newTypeCoercionExpr =
updateExpression(apply.getTypeCoercionExpr(), replaceMap, true);
+ Optional<List<Slot>> newCorrelationSlot =
updateExpressions(apply.getCorrelationSlot(), replaceMap, true);
+ Optional<Expression> newCorrelationFilter =
updateExpression(apply.getCorrelationFilter(), replaceMap, true);
+ Optional<MarkJoinSlotReference> newMarkJoinSlotReference =
+ updateExpression(apply.getMarkJoinSlotReference(), replaceMap,
true);
+
+ for (Slot slot : apply.getOutput()) {
+ replaceMap.put(slot.getExprId(), slot);
+ }
+ if (!newCompareExpr.isPresent() && !newTypeCoercionExpr.isPresent() &&
!newCorrelationSlot.isPresent()
+ && !newCorrelationFilter.isPresent() &&
!newMarkJoinSlotReference.isPresent()) {
+ return apply;
+ }
+
+ return new LogicalApply<>(
+ newCorrelationSlot.orElse(apply.getCorrelationSlot()),
+ apply.getSubqueryType(),
+ apply.isNot(),
+ newCompareExpr.isPresent() ? newCompareExpr :
apply.getCompareExpr(),
+ newTypeCoercionExpr.isPresent() ? newTypeCoercionExpr :
apply.getTypeCoercionExpr(),
+ newCorrelationFilter.isPresent() ? newCorrelationFilter :
apply.getCorrelationFilter(),
+ newMarkJoinSlotReference.isPresent() ?
newMarkJoinSlotReference : apply.getMarkJoinSlotReference(),
+ apply.isNeedAddSubOutputToProjects(),
+ apply.isMarkJoinSlotNotNull(),
+ apply.left(),
+ apply.right());
}
@Override
@@ -152,13 +246,15 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
List<NamedExpression> newOutputs = Lists.newArrayList();
for (NamedExpression output : repeat.getOutputExpressions()) {
+ NamedExpression newOutput;
if (flattenGroupingSetExpr.contains(output)) {
- newOutputs.add(output);
+ newOutput = output;
} else {
- newOutputs.add(updateExpression(output, replaceMap));
+ newOutput = updateExpression(output, replaceMap,
true).orElse(output);
}
+ newOutputs.add(newOutput);
+ replaceMap.put(newOutput.getExprId(), newOutput.toSlot());
}
- newOutputs.forEach(o -> replaceMap.put(o.getExprId(), o.toSlot()));
return repeat.withGroupSetsAndOutput(repeat.getGroupingSets(),
newOutputs).recomputeLogicalProperties();
}
@@ -227,37 +323,70 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
@Override
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Map<ExprId,
Slot> replaceMap) {
sort = (LogicalSort<? extends Plan>) super.visit(sort, replaceMap);
- List<OrderKey> newKeys = sort.getOrderKeys().stream()
- .map(old -> old.withExpression(updateExpression(old.getExpr(),
replaceMap)))
- .collect(ImmutableList.toImmutableList());
- return sort.withOrderKeys(newKeys).recomputeLogicalProperties();
+ boolean changed = false;
+ ImmutableList.Builder<OrderKey> newOrderKeys = ImmutableList.builder();
+ for (OrderKey orderKey : sort.getOrderKeys()) {
+ Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap, true);
+ if (!newOrderKey.isPresent()) {
+ newOrderKeys.add(orderKey);
+ } else {
+ changed = true;
+ newOrderKeys.add(orderKey.withExpression(newOrderKey.get()));
+ }
+ }
+ if (!changed) {
+ return sort;
+ }
+ return sort.withOrderKeysAndChild(newOrderKeys.build(), sort.child());
}
@Override
public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Map<ExprId,
Slot> replaceMap) {
topN = (LogicalTopN<? extends Plan>) super.visit(topN, replaceMap);
- List<OrderKey> newKeys = topN.getOrderKeys().stream()
- .map(old -> old.withExpression(updateExpression(old.getExpr(),
replaceMap)))
- .collect(ImmutableList.toImmutableList());
- return topN.withOrderKeys(newKeys).recomputeLogicalProperties();
+ boolean changed = false;
+ ImmutableList.Builder<OrderKey> newOrderKeys = ImmutableList.builder();
+ for (OrderKey orderKey : topN.getOrderKeys()) {
+ Optional<Expression> newOrderKey =
updateExpression(orderKey.getExpr(), replaceMap, true);
+ if (!newOrderKey.isPresent()) {
+ newOrderKeys.add(orderKey);
+ } else {
+ changed = true;
+ newOrderKeys.add(orderKey.withExpression(newOrderKey.get()));
+ }
+ }
+ if (!changed) {
+ return topN;
+ }
+ return
topN.withOrderKeys(newOrderKeys.build()).recomputeLogicalProperties();
}
@Override
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window,
Map<ExprId, Slot> replaceMap) {
window = (LogicalWindow<? extends Plan>) super.visit(window,
replaceMap);
- List<NamedExpression> windowExpressions =
- updateExpressions(window.getWindowExpressions(), replaceMap);
- windowExpressions.forEach(w -> replaceMap.put(w.getExprId(),
w.toSlot()));
- return window.withExpressionsAndChild(windowExpressions,
window.child());
+ Optional<List<NamedExpression>> windowExpressions =
+ updateExpressions(window.getWindowExpressions(), replaceMap,
true);
+ for (NamedExpression w :
windowExpressions.orElse(window.getWindowExpressions())) {
+ replaceMap.put(w.getExprId(), w.toSlot());
+ }
+ if (!windowExpressions.isPresent()) {
+ return window;
+ }
+ return window.withExpressionsAndChild(windowExpressions.get(),
window.child());
}
@Override
public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan>
partitionTopN,
Map<ExprId, Slot> replaceMap) {
partitionTopN = (LogicalPartitionTopN<? extends Plan>)
super.visit(partitionTopN, replaceMap);
- List<Expression> partitionKeys =
updateExpressions(partitionTopN.getPartitionKeys(), replaceMap);
- List<OrderExpression> orderKeys =
updateExpressions(partitionTopN.getOrderKeys(), replaceMap);
- return partitionTopN.withPartitionKeysAndOrderKeys(partitionKeys,
orderKeys);
+ Optional<List<Expression>> partitionKeys
+ = updateExpressions(partitionTopN.getPartitionKeys(),
replaceMap, true);
+ Optional<List<OrderExpression>> orderKeys =
updateExpressions(partitionTopN.getOrderKeys(), replaceMap, true);
+ if (!partitionKeys.isPresent() && !orderKeys.isPresent()) {
+ return partitionTopN;
+ }
+ return partitionTopN.withPartitionKeysAndOrderKeys(
+ partitionKeys.orElse(partitionTopN.getPartitionKeys()),
orderKeys.orElse(partitionTopN.getOrderKeys())
+ );
}
@Override
@@ -265,54 +394,94 @@ public class AdjustNullable extends
DefaultPlanRewriter<Map<ExprId, Slot>> imple
Map<Slot, Slot> consumerToProducerOutputMap = new LinkedHashMap<>();
Multimap<Slot, Slot> producerToConsumerOutputMap =
LinkedHashMultimap.create();
for (Slot producerOutputSlot :
cteConsumer.getConsumerToProducerOutputMap().values()) {
- Slot newProducerOutputSlot = updateExpression(producerOutputSlot,
replaceMap);
+ Optional<Slot> newProducerOutputSlot =
updateExpression(producerOutputSlot, replaceMap, true);
for (Slot consumerOutputSlot :
cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot)) {
- Slot newConsumerOutputSlot =
consumerOutputSlot.withNullable(newProducerOutputSlot.nullable());
- producerToConsumerOutputMap.put(newProducerOutputSlot,
newConsumerOutputSlot);
- consumerToProducerOutputMap.put(newConsumerOutputSlot,
newProducerOutputSlot);
+ Slot slot = newProducerOutputSlot.orElse(producerOutputSlot);
+ Slot newConsumerOutputSlot =
consumerOutputSlot.withNullable(slot.nullable());
+ producerToConsumerOutputMap.put(slot, newConsumerOutputSlot);
+ consumerToProducerOutputMap.put(newConsumerOutputSlot, slot);
replaceMap.put(newConsumerOutputSlot.getExprId(),
newConsumerOutputSlot);
}
}
return cteConsumer.withTwoMaps(consumerToProducerOutputMap,
producerToConsumerOutputMap);
}
- private <T extends Expression> T updateExpression(T input, Map<ExprId,
Slot> replaceMap) {
- return (T) input.rewriteDownShortCircuit(e ->
e.accept(SlotReferenceReplacer.INSTANCE, replaceMap));
+ private <T extends Expression> Optional<T> updateExpression(Optional<T>
input,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
+ return input.isPresent() ? updateExpression(input.get(), replaceMap,
debugCheck) : Optional.empty();
}
- private <T extends Expression> List<T> updateExpressions(List<T> inputs,
Map<ExprId, Slot> replaceMap) {
+ private <T extends Expression> Optional<T> updateExpression(T input,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
+ AtomicBoolean changed = new AtomicBoolean(false);
+ Expression replaced = input.rewriteDownShortCircuit(e -> {
+ if (e instanceof SlotReference) {
+ SlotReference slotReference = (SlotReference) e;
+ Slot newSlotReference = slotReference;
+ Slot replacedSlot = replaceMap.get(slotReference.getExprId());
+ if (replacedSlot != null) {
+ if (replacedSlot.getDataType().isAggStateType()) {
+ if (slotReference.nullable() != replacedSlot.nullable()
+ ||
!slotReference.getDataType().equals(replacedSlot.getDataType())) {
+ // we must replace data type, because nested type
and agg state contains nullable
+ // of their children.
+ // TODO: remove if statement after we ensure be
constant folding do not change
+ // expr type at all.
+ changed.set(true);
+ newSlotReference =
slotReference.withNullableAndDataType(
+ replacedSlot.nullable(),
replacedSlot.getDataType());
+ }
+ } else if (slotReference.nullable() !=
replacedSlot.nullable()) {
+ changed.set(true);
+ newSlotReference =
slotReference.withNullable(replacedSlot.nullable());
+ }
+ }
+ // for join other conditions debugCheck = false; in all other
cases debugCheck is true.
+ // Because join other conditions use the join output's nullable
attribute, the check may fail for outer joins.
+ // At the analyzed phase, slot reference nullability may change,
for example, NormalRepeat may adjust some
+ // slot references to nullable; after that rule, nodes above the
repeat need adjusting.
+ // so the analyzed phase does not assert not-nullable -> nullable,
otherwise adjusting the plan above the
+ // repeat may fail the check.
+ if (!slotReference.nullable() && newSlotReference.nullable()
+ && !isAnalyzedPhase && debugCheck &&
ConnectContext.get() != null) {
+ if (ConnectContext.get().getSessionVariable().feDebug) {
+ throw new AnalysisException("AdjustNullable convert
slot " + slotReference
+ + " from not-nullable to nullable. You can
disable check by set fe_debug = false.");
+ } else {
+ LOG.warn("adjust nullable convert slot '" +
slotReference
+ + "' from not-nullable to nullable for query "
+ +
DebugUtil.printId(ConnectContext.get().queryId()));
+ }
+ }
+ return newSlotReference;
+ } else {
+ return e;
+ }
+ });
+ return changed.get() ? Optional.of((T) replaced) : Optional.empty();
+ }
+
+ private <T extends Expression> Optional<List<T>> updateExpressions(List<T>
inputs,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
ImmutableList.Builder<T> result =
ImmutableList.builderWithExpectedSize(inputs.size());
+ boolean changed = false;
for (T input : inputs) {
- result.add(updateExpression(input, replaceMap));
+ Optional<T> newInput = updateExpression(input, replaceMap,
debugCheck);
+ changed |= newInput.isPresent();
+ result.add(newInput.orElse(input));
}
- return result.build();
+ return changed ? Optional.of(result.build()) : Optional.empty();
}
- private <T extends Expression> Set<T> updateExpressions(Set<T> inputs,
Map<ExprId, Slot> replaceMap) {
+ private <T extends Expression> Optional<Set<T>> updateExpressions(Set<T>
inputs,
+ Map<ExprId, Slot> replaceMap, boolean debugCheck) {
+ boolean changed = false;
ImmutableSet.Builder<T> result =
ImmutableSet.builderWithExpectedSize(inputs.size());
for (T input : inputs) {
- result.add(updateExpression(input, replaceMap));
- }
- return result.build();
- }
-
- private static class SlotReferenceReplacer extends
DefaultExpressionRewriter<Map<ExprId, Slot>> {
- public static SlotReferenceReplacer INSTANCE = new
SlotReferenceReplacer();
-
- @Override
- public Expression visitSlotReference(SlotReference slotReference,
Map<ExprId, Slot> context) {
- if (context.containsKey(slotReference.getExprId())) {
- Slot slot = context.get(slotReference.getExprId());
- if (slot.getDataType().isAggStateType()) {
- // we must replace data type, because nested type and agg
state contains nullable of their children.
- // TODO: remove if statement after we ensure be constant
folding do not change expr type at all.
- return
slotReference.withNullableAndDataType(slot.nullable(), slot.getDataType());
- } else {
- return slotReference.withNullable(slot.nullable());
- }
- } else {
- return slotReference;
- }
+ Optional<T> newInput = updateExpression(input, replaceMap,
debugCheck);
+ changed |= newInput.isPresent();
+ result.add(newInput.orElse(input));
}
+ return changed ? Optional.of(result.build()) : Optional.empty();
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
index 9325607dd70..4a749dd0b1d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
@@ -72,10 +72,8 @@ public class EliminateGroupBy extends OneRewriteRuleFactory {
}
}
List<NamedExpression> outputExpressions =
agg.getOutputExpressions();
-
ImmutableList.Builder<NamedExpression> newOutput
=
ImmutableList.builderWithExpectedSize(outputExpressions.size());
-
for (NamedExpression ne : outputExpressions) {
if (ne instanceof Alias && ne.child(0) instanceof
AggregateFunction) {
AggregateFunction f = (AggregateFunction)
ne.child(0);
@@ -84,8 +82,7 @@ public class EliminateGroupBy extends OneRewriteRuleFactory {
.castIfNotSameType(f.child(0),
f.getDataType()), ne.getName()));
} else if (f instanceof Count) {
newOutput.add((NamedExpression)
ne.withChildren(
- new If(new IsNull(f.child(0)), new
BigIntLiteral(0),
- new BigIntLiteral(1))));
+ ifNullElse(f.child(0), new
BigIntLiteral(0), new BigIntLiteral(1))));
} else {
throw new IllegalStateException("Unexpected
aggregate function: " + f);
}
@@ -96,4 +93,8 @@ public class EliminateGroupBy extends OneRewriteRuleFactory {
return PlanUtils.projectOrSelf(newOutput.build(), child);
}).toRule(RuleType.ELIMINATE_GROUP_BY);
}
+
+ private Expression ifNullElse(Expression conditionExpr, Expression ifExpr,
Expression elseExpr) {
+ return conditionExpr.nullable() ? new If(new IsNull(conditionExpr),
ifExpr, elseExpr) : elseExpr;
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
index 44c24a3a004..b8e2ab6a003 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@@ -25,7 +26,7 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@@ -33,9 +34,11 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.JoinUtils;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.List;
@@ -79,24 +82,46 @@ public class EliminateJoinByFK extends
OneRewriteRuleFactory {
}
Set<Slot> output = project.getInputSlots();
Set<Slot> foreignKeys = Sets.intersection(foreign.getOutputSet(),
equalSet.getAllItemSet());
- Map<Expression, Expression> outputToForeign =
- tryMapOutputToForeignPlan(foreign, output, equalSet);
+ Map<Slot, Slot> outputToForeign = tryMapOutputToForeignPlan(foreign,
output, equalSet);
if (outputToForeign != null) {
+ Pair<Plan, Set<Slot>> newChildPair =
applyNullCompensationFilter(foreign, foreignKeys);
+ Map<Slot, Expression> replacedSlots =
getReplaceSlotMap(outputToForeign, newChildPair.second);
List<NamedExpression> newProjects = project.getProjects().stream()
- .map(e -> outputToForeign.containsKey(e)
- ? new Alias(e.getExprId(), outputToForeign.get(e),
e.toSql())
- : (NamedExpression) e.rewriteUp(s ->
outputToForeign.getOrDefault(s, s)))
+ .map(e -> replacedSlots.containsKey(e)
+ ? new Alias(e.getExprId(), replacedSlots.get(e),
e.toSql())
+ : (NamedExpression) e.rewriteUp(s ->
replacedSlots.getOrDefault(s, s)))
.collect(ImmutableList.toImmutableList());
- return project.withProjects(newProjects)
- .withChildren(applyNullCompensationFilter(foreign,
foreignKeys));
+ return
project.withProjects(newProjects).withChildren(newChildPair.first);
}
return project;
}
- private @Nullable Map<Expression, Expression>
tryMapOutputToForeignPlan(Plan foreignPlan,
+ /**
+ * get replace slots, include replace the primary slots and replace the
nullable foreign slots.
+ * @param outputToForeign primary slot to foreign slot map
+ * @param compensationForeignSlots foreign slots which are nullable but
add a filter 'slot is not null'
+ * @return the replaced map, include primary slot to foreign slot, and
foreign nullable slot to non-nullable(slot)
+ */
+ @VisibleForTesting
+ public Map<Slot, Expression> getReplaceSlotMap(Map<Slot, Slot>
outputToForeign,
+ Set<Slot> compensationForeignSlots) {
+ Map<Slot, Expression> replacedSlots = Maps.newHashMap();
+ for (Map.Entry<Slot, Slot> entry : outputToForeign.entrySet()) {
+ Slot forgeinSlot = entry.getValue();
+ Expression replacedExpr =
compensationForeignSlots.contains(forgeinSlot)
+ ? new NonNullable(forgeinSlot) : forgeinSlot;
+ replacedSlots.put(entry.getKey(), replacedExpr);
+ }
+ for (Slot forgeinSlot : compensationForeignSlots) {
+ replacedSlots.put(forgeinSlot, new NonNullable(forgeinSlot));
+ }
+ return replacedSlots;
+ }
+
+ private @Nullable Map<Slot, Slot> tryMapOutputToForeignPlan(Plan
foreignPlan,
Set<Slot> output, ImmutableEqualSet<Slot> equalSet) {
Set<Slot> residualPrimary = Sets.difference(output,
foreignPlan.getOutputSet());
- ImmutableMap.Builder<Expression, Expression> builder = new
ImmutableMap.Builder<>();
+ ImmutableMap.Builder<Slot, Slot> builder = new
ImmutableMap.Builder<>();
for (Slot primarySlot : residualPrimary) {
Optional<Slot> replacedForeign =
equalSet.calEqualSet(primarySlot).stream()
.filter(foreignPlan.getOutputSet()::contains)
@@ -109,14 +134,20 @@ public class EliminateJoinByFK extends
OneRewriteRuleFactory {
return builder.build();
}
- private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots)
{
- Set<Expression> predicates = childSlots.stream()
- .filter(ExpressionTrait::nullable)
- .map(s -> new Not(new IsNull(s)))
- .collect(ImmutableSet.toImmutableSet());
- if (predicates.isEmpty()) {
- return child;
+ /**
+ * add a filter for foreign slots which are nullable; the filter is 'slot
is not null'
+ */
+ @VisibleForTesting
+ public Pair<Plan, Set<Slot>> applyNullCompensationFilter(Plan child,
Set<Slot> childSlots) {
+ ImmutableSet.Builder<Expression> predicatesBuilder =
ImmutableSet.builder();
+ Set<Slot> filterNotNullSlots = Sets.newHashSet();
+ for (Slot slot : childSlots) {
+ if (slot.nullable()) {
+ filterNotNullSlots.add(slot);
+ predicatesBuilder.add(new Not(new IsNull(slot)));
+ }
}
- return new LogicalFilter<>(predicates, child);
+ Plan newChild = filterNotNullSlots.isEmpty() ? child : new
LogicalFilter<>(predicatesBuilder.build(), child);
+ return Pair.of(newChild, filterNotNullSlots);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index d96dd8a15c2..9bc7fbfd5e1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.LogicalProperties;
+import
org.apache.doris.nereids.rules.analysis.AdjustAggregateNullableForEmptySet;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@@ -143,7 +144,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
CHILD_TYPE child) {
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties,
child);
this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
- this.outputExpressions = ImmutableList.copyOf(outputExpressions);
+ boolean noGroupby = groupByExpressions.isEmpty();
+ ImmutableList.Builder<NamedExpression> builder =
ImmutableList.builder();
+ outputExpressions.forEach(output -> builder.add(
+ (NamedExpression)
AdjustAggregateNullableForEmptySet.replaceExpression(output, noGroupby)));
+ this.outputExpressions = builder.build();
this.normalized = normalized;
this.ordinalIsResolved = ordinalIsResolved;
this.generated = generated;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
index 5d21b1156f2..e2c7194fba9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
@@ -184,11 +184,44 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan,
RIGHT_CHILD_TYPE extends
if (markJoinSlotReference.isPresent()) {
builder.add(markJoinSlotReference.get());
}
+ // only a scalar apply can have needAddSubOutputToProjects = true
if (needAddSubOutputToProjects) {
- if (isScalar()) {
-
builder.add(ScalarSubquery.getScalarQueryOutputAdjustNullable(right(),
correlationSlot));
+ // a correlated apply's right child may contain multiple output slots.
+ // rule ScalarApplyToJoin checks '(isCorrelated() &&
correlationFilter.isPresent())',
+ // but at the analyzed phase correlationFilter is still empty; it is only
set by rule UnCorrelatedApplyAggregateFilter,
+ // so we skip checking
correlationFilter here.
+ // correlated apply will change to a left outer join, then all the
right child output will be nullable.
+ if (isCorrelated()) {
+ // for sql:
+ // `select t1.a,
+ // (select if(sum(t2.a) > 10, count(t2.b), max(t2.c))
as k from t2 where t1.a = t2.a)
+ // from t1`,
+ // its plan is:
+ // LogicalProject(t1.a, if(sum(t2.a) > 10, count(t2.b),
max(t2.c)) as k)
+ // |-- LogicalProject(..., if(sum(t2.a) > 10,
ifnull(count(t2.b), 0), max(t2.c)) as k)
+ // |-- LogicalApply(correlationSlot = [t1.a])
+ // |-- LogicalOlapScan(t1)
+ // |-- LogicalAggregate(output = [sum(t2.a),
count(t2.b), max(t2.c)])
+ for (Slot slot : right().getOutput()) {
+ // in fact some slots may be non-nullable, like count.
+ // but after converting the correlated apply to a left outer join,
all of the join's right child slots
+ // become nullable, so we make all slots nullable here;
then their nullability won't change
+ // even after the conversion to a join.
+ builder.add(slot.toSlot().withNullable(true));
+ }
} else {
- builder.add(right().getOutput().get(0));
+ // uncorrelated apply right child always contains one output
slot.
+ // for sql:
+ // `select t1.a,
+ // (select if(sum(t2.a) > 10, count(t2.b), max(t2.c))
as k from t2)
+ // from t1`,
+ // its plan is:
+ // LogicalProject(t1.a, k)
+ // |--LogicalApply(correlationSlot = [])
+ // |- LogicalOlapScan(t1)
+ // |- LogicalProject(if(sum(t2.a) > 10, count(t2.b),
max(t2.c)) as k)
+ // |-- LogicalAggregate(output = [sum(t2.a),
count(t2.b), max(t2.c)])
+
builder.add(ScalarSubquery.getScalarQueryOutputAdjustNullable(right(),
correlationSlot));
}
}
return builder.build();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java
index 680988b39f6..7dd227c6d6c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java
@@ -91,7 +91,7 @@ public class LogicalHaving<CHILD_TYPE extends Plan> extends
LogicalUnary<CHILD_T
return new LogicalHaving<>(conjuncts, groupExpression,
logicalProperties, children.get(0));
}
- public Plan withExpressions(Set<Expression> expressions) {
+ public Plan withConjuncts(Set<Expression> expressions) {
return new LogicalHaving<Plan>(expressions, Optional.empty(),
Optional.of(getLogicalProperties()), child());
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
index d0ce4bedc91..f73ee3425af 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
@@ -26,10 +26,12 @@ import
org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.types.BigIntType;
@@ -247,15 +249,15 @@ public class AnalyzeSubQueryTest extends
TestWithFeService implements MemoPatter
private void checkScalarSubquerySlotNullable(String sql, boolean
outputNullable) {
Plan root = PlanChecker.from(connectContext)
.analyze(sql)
- .applyTopDown(new LogicalSubQueryAliasToLogicalProject())
.getPlan();
List<LogicalProject<?>> projectList = Lists.newArrayList();
+ List<LogicalPlan> plansAboveApply = Lists.newArrayList();
root.foreach(plan -> {
if (plan instanceof LogicalProject && plan.child(0) instanceof
LogicalApply) {
projectList.add((LogicalProject<?>) plan);
- return true;
- } else {
- return false;
+ }
+ if (!(plan instanceof LogicalApply) && plan.anyMatch(p -> p
instanceof LogicalApply)) {
+ plansAboveApply.add((LogicalPlan) plan);
}
});
@@ -272,10 +274,26 @@ public class AnalyzeSubQueryTest extends
TestWithFeService implements MemoPatter
.findFirst().orElse(null);
Assertions.assertNotNull(output);
Assertions.assertEquals(outputNullable, output.nullable());
- output = apply.getOutput().stream()
+
+ Slot applySubqueySlot = apply.getOutput().stream()
.filter(e -> slotKName.contains(e.getName()))
.findFirst().orElse(null);
- Assertions.assertNotNull(output);
- Assertions.assertEquals(outputNullable, output.nullable());
+ Assertions.assertNotNull(applySubqueySlot);
+ if (apply.isCorrelated()) {
+ // apply will change to outer join
+ Assertions.assertTrue(applySubqueySlot.nullable());
+ } else {
+ Assertions.assertEquals(outputNullable,
applySubqueySlot.nullable());
+ }
+
+ for (LogicalPlan plan : plansAboveApply) {
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot ->
slot.getExprId().equals(applySubqueySlot.getExprId()))
+ .allMatch(slot -> slot.nullable() ==
applySubqueySlot.nullable()));
+
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot ->
slot.getExprId().equals(applySubqueySlot.getExprId()))
+ .allMatch(slot -> slot.nullable() ==
applySubqueySlot.nullable()));
+ }
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
index 28993a66c5c..800590ec7d6 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java
@@ -194,7 +194,7 @@ public class AnalyzeWhereSubqueryTest extends
TestWithFeService implements MemoP
logicalAggregate().when(FieldChecker.check("outputExpressions",
ImmutableList.of(
new Alias(new ExprId(7), (new
Sum(
new SlotReference(new
ExprId(4), "k3", BigIntType.INSTANCE, true,
-
ImmutableList.of("test", "t7")))).withAlwaysNullable(true),
+
ImmutableList.of("test", "t7")))),
"sum(t7.k3)"),
new SlotReference(new
ExprId(6), "v2", BigIntType.INSTANCE, true,
ImmutableList.of("test", "t7"))
@@ -473,7 +473,7 @@ public class AnalyzeWhereSubqueryTest extends
TestWithFeService implements MemoP
logicalProject()
).when(FieldChecker.check("outputExpressions",
ImmutableList.of(
new Alias(new ExprId(8), (new Max(new
SlotReference(new ExprId(7), "aa", BigIntType.INSTANCE, true,
-
ImmutableList.of("t2")))).withAlwaysNullable(true), "max(aa)"),
+ ImmutableList.of("t2")))),
"max(aa)"),
new SlotReference(new ExprId(6), "v2",
BigIntType.INSTANCE, true,
ImmutableList.of("test",
"t7")))))
.when(FieldChecker.check("groupByExpressions", ImmutableList.of(
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
index 73844d7db67..2451bd3c46f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregateTest.java
@@ -19,9 +19,11 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -29,8 +31,10 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.FieldChecker;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@@ -41,10 +45,14 @@ import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
+import java.util.stream.Collectors;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class NormalizeAggregateTest extends TestWithFeService implements
MemoPatternMatchSupported {
@@ -59,6 +67,15 @@ public class NormalizeAggregateTest extends
TestWithFeService implements MemoPat
createTables(
"CREATE TABLE IF NOT EXISTS t1 (\n"
+ " id int not null,\n"
+ + " no int not null,\n"
+ + " name char\n"
+ + ")\n"
+ + "DUPLICATE KEY(id)\n"
+ + "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ + "PROPERTIES (\"replication_num\" = \"1\")\n",
+ "CREATE TABLE IF NOT EXISTS t2 (\n"
+ + " id int not null,\n"
+ + " no int not null,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
@@ -302,4 +319,379 @@ public class NormalizeAggregateTest extends
TestWithFeService implements MemoPat
agg.getGroupByExpressions().size() == 1
&&
agg.getOutputExpressions().stream().anyMatch(e ->
e.toString().contains("COUNT"))));
}
+
+ @Test
+ void testAggFunctionNullabe() {
+ List<String> aggNullableSqls = ImmutableList.of(
+ // one row relation
+ "select sum(1) as k",
+
+ "select sum(id) as k from t1",
+ "select sum(id) as k from t1 where id > 10",
+
+ // sub query alias
+ "select * from (select sum(id) as k from t1) t",
+ "select * from (select sum(id) as k from t1 where id > 10) t",
+
+ // project sub query
+ "select id, (select sum(t2.id) as k from t2) from t1",
+ "select id, (select sum(t2.id) as k from t2 where t2.id > 10)
from t1",
+ "select id, (select sum(t2.id) as k from t2 where t1.id =
t2.id) from t1",
+
+ // filter sub query
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2)",
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2 where t2.id > 10)",
+ "select * from t1 where t1.id > (select sum(t2.id) as k from
t2 where t1.name = t2.name)"
+ );
+ for (String sql : aggNullableSqls) {
+ checkAggFunctionNullable(sql, true);
+ }
+
+ List<String> aggNotNullableSqls = ImmutableList.of(
+ "select sum(id) as k from t1 group by name",
+ "select sum(id) as k from t1 group by 'abcde' ",
+ "select sum(id) as k from t1 where id > 10 group by name",
+ "select sum(id) as k from t1 where id > 10 group by 'abcde' ",
+
+ // sub query alias
+ "select * from (select sum(id) as k from t1 group by name) t",
+ "select * from (select sum(id) as k from t1 group by 'abcde')
t",
+ "select * from (select sum(id) as k from t1 where id > 10
group by name) t",
+ "select * from (select sum(id) as k from t1 where id > 10
group by 'abcde') t"
+ );
+ for (String sql : aggNotNullableSqls) {
+ checkAggFunctionNullable(sql, false);
+ }
+ }
+
+ private void checkAggFunctionNullable(String sql, boolean nullable) {
+ List<LogicalAggregate<?>> aggList = Lists.newArrayList();
+ List<LogicalProject<?>> projectList = Lists.newArrayList();
+ List<LogicalApply<?, ?>> applyList = Lists.newArrayList();
+ List<LogicalPlan> planAboveApply = Lists.newArrayList();
+ List<LogicalPlan> planAboveAgg = Lists.newArrayList();
+ Plan root = PlanChecker.from(connectContext)
+ .analyze(sql).getPlan();
+ root.foreach(plan -> {
+ if (plan instanceof LogicalAggregate) {
+ aggList.add((LogicalAggregate<?>) plan);
+ } else if (plan instanceof LogicalProject) {
+ projectList.add((LogicalProject<?>) plan);
+ } else if (plan instanceof LogicalApply) {
+ applyList.add((LogicalApply<?, ?>) plan);
+ }
+
+ if (!(plan instanceof LogicalApply) && plan.anyMatch(p -> p
instanceof LogicalApply)) {
+ planAboveApply.add((LogicalPlan) plan);
+ }
+ if (!(plan instanceof LogicalAggregate)
+ && plan.anyMatch(p -> p instanceof LogicalAggregate)
+ && !(plan.anyMatch(p -> p instanceof LogicalApply))) {
+ planAboveAgg.add((LogicalPlan) plan);
+ }
+ });
+ List<String> slotKName = ImmutableList.of("k");
+
+ Assertions.assertEquals(1, aggList.size());
+ LogicalAggregate<?> agg = aggList.get(0);
+ NamedExpression slotK = agg.getOutputExpressions().stream()
+ .filter(output -> slotKName.contains(output.getName()))
+ .findFirst().orElse(null);
+ Assertions.assertNotNull(slotK);
+ Assertions.assertEquals(nullable, slotK.nullable());
+
+ Assertions.assertTrue(applyList.size() <= 1);
+ Slot applySlot = null;
+ if (applyList.size() == 1) {
+ LogicalApply<?, ?> apply = applyList.get(0);
+ applySlot = apply.getOutput().stream()
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ Assertions.assertNotNull(applySlot);
+ Assertions.assertTrue(applySlot.nullable());
+ }
+ for (LogicalProject<?> project : projectList) {
+ if (!project.anyMatch(plan -> plan instanceof LogicalAggregate)) {
+ continue;
+ }
+
+ NamedExpression expr = project.getProjects().stream()
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ if (expr == null) {
+ expr = project.getProjects().stream()
+ .map(output -> output instanceof Alias &&
output.child(0) instanceof SlotReference
+ ? (SlotReference) output.child(0) : output)
+ .filter(output ->
output.getExprId().equals(slotK.getExprId()))
+ .findFirst().orElse(null);
+ }
+ if (expr == null) {
+ continue;
+ }
+
+ boolean aboveApply = project.anyMatch(plan -> plan instanceof
LogicalApply);
+ if (aboveApply) {
+ Assertions.assertTrue(expr.nullable());
+ } else {
+ Assertions.assertEquals(nullable, expr.nullable());
+ }
+ }
+
+ if (applySlot != null) {
+ ExprId applySlotExprId = applySlot.getExprId();
+ boolean applySlotNullable = applySlot.nullable();
+ for (LogicalPlan plan : planAboveApply) {
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot ->
slot.getExprId().equals(applySlotExprId))
+ .allMatch(slot -> slot.nullable() ==
applySlotNullable));
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot ->
slot.getExprId().equals(applySlotExprId))
+ .allMatch(slot -> slot.nullable() ==
applySlotNullable));
+ }
+ }
+ for (LogicalPlan plan : planAboveAgg) {
+ ExprId kSlotExprId = slotK.getExprId();
+ boolean kSlotNullable = slotK.nullable();
+ Assertions.assertTrue(plan.getInputSlots().stream()
+ .filter(slot -> slot.getExprId().equals(kSlotExprId))
+ .allMatch(slot -> slot.nullable() == kSlotNullable));
+ Assertions.assertTrue(plan.getOutput().stream()
+ .filter(slot -> slot.getExprId().equals(kSlotExprId))
+ .allMatch(slot -> slot.nullable() == kSlotNullable));
+ }
+ }
+
+ @Test
+ void testAggFunctionNullabe2() {
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id) from t1")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+ checkExprsToSql(output, "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select 1 from t1 having sum(id) > 10")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalFilter(
+ logicalProject(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output
= agg.getOutput();
+
checkExprsToSql(output, "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression>
projects = project.getProjects();
+ checkExprsToSql(projects,
"sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "1 AS
`1`", "sum(id)");
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(filter -> {
+ List<Expression> conjuncts =
filter.getExpressions();
+ checkExprsToSql(conjuncts, "(sum(id) >
10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id), sum(no) from t1 having sum(id) > 10")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalProject(
+ logicalFilter(
+
logicalAggregate().when(agg -> {
+ List<Slot> output
= agg.getOutput();
+
checkExprsToSql(output, "sum(id)", "sum(no)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(filter -> {
+ List<Expression> conjuncts
= filter.getExpressions();
+ checkExprsToSql(conjuncts,
"(sum(id) > 10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects,
"sum(id)", "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(id)",
"sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(id), sum(no) from t1 order by sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalSort(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+
checkExprsToSql(output, "sum(id)", "sum(no)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects
= project.getProjects();
+ checkExprsToSql(projects,
"sum(id)", "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(no) from t1 order by sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalSort(
+ logicalProject(
+
logicalAggregate().when(agg -> {
+ List<Slot> output =
agg.getOutput();
+
checkExprsToSql(output, "sum(no)", "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects
= project.getProjects();
+ checkExprsToSql(projects,
"sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ PlanChecker.from(connectContext)
+ .analyze("select sum(no) from t1 having sum(no) > 10 order by
sum(id)")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalSort(
+ logicalProject(
+ logicalProject(
+ logicalFilter(
+
logicalAggregate().when(agg -> {
+
List<Slot> output = agg.getOutput();
+
checkExprsToSql(output, "sum(no)", "sum(id)");
+
Assertions.assertTrue(output.get(0).nullable());
+
Assertions.assertTrue(output.get(1).nullable());
+
return true;
+ })
+ ).when(filter
-> {
+
List<Expression> conjuncts = filter.getExpressions();
+
checkExprsToSql(conjuncts, "(sum(no) > 10)");
+
Assertions.assertTrue(conjuncts.get(0).child(0).nullable());
+ return
true;
+ })
+ ).when(project -> {
+
List<NamedExpression> projects = project.getProjects();
+
checkExprsToSql(projects, "sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression>
projects = project.getProjects();
+ checkExprsToSql(projects,
"sum(no)", "sum(id)");
+
Assertions.assertTrue(projects.get(0).nullable());
+
Assertions.assertTrue(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sort -> {
+ List<? extends Expression> keys =
sort.getExpressions();
+ checkExprsToSql(keys, "sum(id)");
+
Assertions.assertTrue(keys.get(0).nullable());
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(no)");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ )
+ );
+
+ // a window function, not agg
+ PlanChecker.from(connectContext)
+ .analyze("select sum(1) over()")
+ .matchesFromRoot(
+ logicalResultSink(
+ logicalProject(
+ logicalOneRowRelation()
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ checkExprsToSql(projects, "sum(1) OVER()
AS `sum(1) over()`");
+
Assertions.assertTrue(projects.get(0).nullable());
+ return true;
+ })
+ ).when(sink -> {
+
Assertions.assertTrue(sink.getOutput().get(0).nullable());
+ return true;
+ })
+ );
+ }
+
+ private void checkExprsToSql(Collection<? extends Expression> expressions,
String... exprsToSql) {
+ Assertions.assertEquals(Arrays.asList(exprsToSql),
+
expressions.stream().map(Expression::toSql).collect(Collectors.toList()));
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
index d7dad886ac3..443dbaebd8f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java
@@ -74,6 +74,7 @@ public class AggScalarSubQueryToWindowFunctionTest extends
TPCHTestBase implemen
@Test
public void testRuleOnTPCHTest() {
+ connectContext.getSessionVariable().feDebug = false;
check(TPCHUtils.Q2);
check(TPCHUtils.Q17);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
index 71c1ccbfb7a..a0d9e5aba56 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
@@ -17,13 +17,34 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.NonNullable;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
+import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
class EliminateJoinByFkTest extends TestWithFeService implements
MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
@@ -113,15 +134,37 @@ class EliminateJoinByFkTest extends TestWithFeService
implements MemoPatternMatc
@Test
void testNull() throws Exception {
- String sql = "select pri.id1 from pri inner join foreign_null on
pri.id1 = foreign_null.id3";
+ String sql = "select pri.id1, 1 + foreign_null.id3 as k from pri inner
join foreign_null on pri.id1 = foreign_null.id3";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin())
- .matches(logicalFilter().when(f -> {
- Assertions.assertTrue(f.getPredicate().toSql().contains("(
not id3 IS NULL)"));
- return true;
- }))
+ .matches(
+ logicalResultSink(
+ logicalProject(
+ logicalFilter().when(f -> {
+
Assertions.assertTrue(f.getPredicate().toSql().contains("( not id3 IS NULL)"));
+ return true;
+ })
+ ).when(project -> {
+ List<NamedExpression> projects =
project.getProjects();
+ Assertions.assertEquals(2, projects.size());
+ Assertions.assertEquals("non_nullable(id3) AS
`id1`", projects.get(0).toSql());
+ Assertions.assertEquals("(non_nullable(id3) + 1)
AS `k`", projects.get(1).toSql());
+ Assertions.assertFalse(projects.get(0).nullable());
+ Assertions.assertFalse(projects.get(1).nullable());
+ return true;
+ })
+ ).when(sink -> {
+ List<NamedExpression> projects = sink.getOutputExprs();
+ Assertions.assertEquals(2, projects.size());
+ Assertions.assertEquals("id1",
projects.get(0).toSql());
+ Assertions.assertFalse(projects.get(0).nullable());
+ Assertions.assertEquals("k", projects.get(1).toSql());
+ Assertions.assertFalse(projects.get(1).nullable());
+ return true;
+ })
+ )
.printlnTree();
sql = "select foreign_null.id3 from pri inner join foreign_null on
pri.id1 = foreign_null.id3";
PlanChecker.from(connectContext)
@@ -205,4 +248,47 @@ class EliminateJoinByFkTest extends TestWithFeService
implements MemoPatternMatc
.rewrite()
.matches(logicalOlapScan().when(scan ->
scan.getTable().getName().equals("pri")));
}
+
+ @Test
+ void testReplaceMap() {
+ Slot a = new SlotReference("a", IntegerType.INSTANCE);
+ Slot b = new SlotReference("b", IntegerType.INSTANCE);
+ Slot x = new SlotReference("x", IntegerType.INSTANCE);
+ Slot y = new SlotReference("y", IntegerType.INSTANCE);
+ Slot z = new SlotReference("z", IntegerType.INSTANCE);
+ Map<Slot, Slot> outputToForeign = Maps.newHashMap();
+ outputToForeign.put(a, x);
+ outputToForeign.put(b, y);
+
+ Set<Slot> compensationForeignSlots = Sets.newHashSet();
+ compensationForeignSlots.add(x);
+ compensationForeignSlots.add(z);
+
+ Map<Slot, Expression> replacedSlots = new
EliminateJoinByFK().getReplaceSlotMap(outputToForeign,
compensationForeignSlots);
+ Map<Slot, Expression> expectedReplacedSlots = Maps.newHashMap();
+ expectedReplacedSlots.put(a, new NonNullable(x));
+ expectedReplacedSlots.put(b, y);
+ expectedReplacedSlots.put(x, new NonNullable(x));
+ expectedReplacedSlots.put(z, new NonNullable(z));
+ Assertions.assertEquals(expectedReplacedSlots, replacedSlots);
+ }
+
+ @Test
+ void testyNullCompensationFilter() {
+ EliminateJoinByFK instance = new EliminateJoinByFK();
+ SlotReference notNull1 = new SlotReference("notNull1",
IntegerType.INSTANCE, false);
+ SlotReference notNull2 = new SlotReference("notNull2",
IntegerType.INSTANCE, false);
+ SlotReference null1 = new SlotReference("null1", IntegerType.INSTANCE,
true);
+ SlotReference null2 = new SlotReference("null2", IntegerType.INSTANCE,
true);
+ LogicalOneRowRelation oneRowRelation = new LogicalOneRowRelation(new
RelationId(100), ImmutableList.of());
+ Pair<Plan, Set<Slot>> result1 =
instance.applyNullCompensationFilter(oneRowRelation, ImmutableSet.of(notNull1,
notNull2));
+ Assertions.assertEquals(ImmutableSet.of(), result1.second);
+ Assertions.assertEquals(oneRowRelation, result1.first);
+ Pair<Plan, Set<Slot>> result2 =
instance.applyNullCompensationFilter(oneRowRelation, ImmutableSet.of(notNull1,
notNull2, null1, null2));
+ Assertions.assertEquals(ImmutableSet.of(null1, null2), result2.second);
+ LogicalFilter<?> expectFilter = new LogicalFilter<>(
+ ImmutableSet.of(new Not(new IsNull(null1)), new Not(new
IsNull(null2))),
+ oneRowRelation);
+ Assertions.assertEquals(expectFilter, result2.first);
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
index 3e1c4d58c40..6cb8f9afb8a 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApplyTest.java
@@ -56,6 +56,8 @@ class PullUpProjectUnderApplyTest extends TestWithFeService
implements MemoPatte
@Test
void testPullUpProjectUnderApply() {
+ connectContext.getSessionVariable().feDebug = false;
+
List<String> testSql = ImmutableList.of(
"select * from T as T1 where id = (select max(id) from T as T2
where T1.score = T2.score)",
"select * from T as T1 where id = (select max(id) + 1 from T
as T2 where T1.score = T2.score)"
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
index c769e1c210c..ac18dddaa62 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
@@ -78,6 +78,7 @@ class SelectMvIndexTest extends
BaseMaterializedIndexSelectTest implements MemoP
useDatabase(HR_DB_NAME);
connectContext.getSessionVariable().enableNereidsTimeout = false;
connectContext.getSessionVariable().setEnableSyncMvCostBasedRewrite(false);
+ connectContext.getSessionVariable().feDebug = false;
}
@BeforeEach
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
new file mode 100644
index 00000000000..eb6e3ea5fc8
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregateTest.java
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans.logical;
+
+import org.apache.doris.nereids.properties.OrderKey;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.OrderExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.expressions.WindowExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
+import org.apache.doris.nereids.types.IntegerType;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class LogicalAggregateTest {
+
+ @Test
+ void testAdjustAggNullableWithEmptyGroupBy() {
+ SlotReference a = new SlotReference("a", IntegerType.INSTANCE, false);
+ SlotReference b = new SlotReference("b", IntegerType.INSTANCE, false);
+
+ LogicalOneRowRelation oneRowRelation = new
LogicalOneRowRelation(StatementScopeIdGenerator.newRelationId(),
+ ImmutableList.of(a, b));
+
+ // agg with empty group by
+ NamedExpression originOutput1 = new Alias(new Add(new Sum(a), new
IntegerLiteral(1)));
+ NamedExpression originOutput2 = new Alias(new WindowExpression(
+ new Sum(false, true, new Add(new Sum(b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true)))));
+ Assertions.assertFalse(originOutput1.nullable());
+ LogicalAggregate<LogicalOneRowRelation> agg = new LogicalAggregate<>(
+ ImmutableList.of(), ImmutableList.of(originOutput1,
originOutput2), oneRowRelation);
+ NamedExpression output1 = agg.getOutputs().get(0);
+ NamedExpression output2 = agg.getOutputs().get(1);
+ Assertions.assertNotEquals(originOutput1, output1);
+ Assertions.assertNotEquals(originOutput2, output2);
+ Assertions.assertTrue(output1.nullable());
+ Expression expectOutput1Child = new Add(new Sum(false, true, a), new
IntegerLiteral(1));
+ Expression expectOutput2Child = new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, true, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true))));
+ Assertions.assertEquals(expectOutput1Child, output1.child(0));
+ Assertions.assertEquals(expectOutput2Child, output2.child(0));
+ }
+
+ @Test
+ void testAdjustAggNullableWithNotEmptyGroupBy() {
+ SlotReference a = new SlotReference("a", IntegerType.INSTANCE, false);
+ SlotReference b = new SlotReference("b", IntegerType.INSTANCE, false);
+
+ LogicalOneRowRelation oneRowRelation = new
LogicalOneRowRelation(StatementScopeIdGenerator.newRelationId(),
+ ImmutableList.of(a, b));
+
+ // agg with not empty group by
+ NamedExpression originOutput1 = new Alias(new Add(new Sum(false, true,
a), new IntegerLiteral(1)));
+ NamedExpression originOutput2 = new Alias(new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, true, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true)))));
+ Assertions.assertTrue(originOutput1.nullable());
+ LogicalAggregate<LogicalOneRowRelation> agg = new LogicalAggregate<>(
+ ImmutableList.of(new TinyIntLiteral((byte) 1)),
ImmutableList.of(originOutput1, originOutput2), oneRowRelation);
+ NamedExpression output1 = agg.getOutputs().get(0);
+ NamedExpression output2 = agg.getOutputs().get(1);
+ Assertions.assertNotEquals(originOutput1, output1);
+ Assertions.assertNotEquals(originOutput2, output2);
+ Assertions.assertFalse(output1.nullable());
+ Expression expectOutput1Child = new Add(new Sum(false, false, a), new
IntegerLiteral(1));
+ Expression expectOutput2Child = new WindowExpression(
+ new Sum(false, true, new Add(new Sum(false, false, b), new
IntegerLiteral(1))),
+ ImmutableList.of(a),
+ ImmutableList.of(new OrderExpression(new OrderKey(b, true,
true))));
+ Assertions.assertEquals(expectOutput1Child, output1.child(0));
+ Assertions.assertEquals(expectOutput2Child, output2.child(0));
+ }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
index 5e2572a2ab8..95d9a4e0e98 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
@@ -160,6 +160,7 @@ public abstract class TestWithFeService {
FeConstants.enableInternalSchemaDb = false;
beforeCreatingConnectContext();
connectContext = createDefaultCtx();
+ connectContext.getSessionVariable().feDebug = true;
beforeCluster();
createDorisCluster();
runBeforeAll();
diff --git
a/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
new file mode 100644
index 00000000000..7ec09edaf5d
Binary files /dev/null and
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_agg_nullable.out
differ
diff --git
a/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
new file mode 100644
index 00000000000..60e3342ea0f
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_agg_nullable.groovy
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite('test_agg_nullable') {
+ sql 'DROP TABLE IF EXISTS test_agg_nullable_t1 FORCE'
+ sql "CREATE TABLE test_agg_nullable_t1(a int not null, b int not null, c
int not null) distributed by hash(a) properties('replication_num' = '1')"
+ sql "SET detail_shape_nodes='PhysicalProject'"
+ order_qt_agg_nullable '''
+ select k > 10 and k < 5 from (select sum(a) as k from
test_agg_nullable_t1) s
+ '''
+ qt_agg_nullable_shape '''explain shape plan
+ select k > 10 and k < 5 from (select sum(a) as k from
test_agg_nullable_t1) s
+ '''
+ sql 'DROP TABLE IF EXISTS test_agg_nullable_t1 FORCE'
+}
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]