This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 59fe081a8296d26e96943a06ebec062aefcda7b9 Author: starocean999 <[email protected]> AuthorDate: Mon Aug 21 15:38:22 2023 +0800 [fix](nereids)scalar subquery shouldn't be used in mark join (#22907) * [fix](nereids)scalar subquery shouldn't be used in mark join --- .../nereids/rules/analysis/SubqueryToApply.java | 132 +++------------------ .../join/LogicalJoinSemiJoinTranspose.java | 6 +- .../join/LogicalJoinSemiJoinTransposeProject.java | 7 +- .../join/SemiJoinSemiJoinTranspose.java | 2 +- .../rewrite/AggScalarSubQueryToWindowFunction.java | 25 ++-- .../nereids/rules/rewrite/ExistsApplyToJoin.java | 19 +-- .../doris/nereids/rules/rewrite/InApplyToJoin.java | 3 - ...CorrelatedFilterUnderApplyAggregateProject.java | 4 +- .../rules/rewrite/PullUpProjectUnderApply.java | 4 +- .../rewrite/PushdownFilterThroughProject.java | 36 +++--- .../nereids/rules/rewrite/ScalarApplyToJoin.java | 17 +-- .../rewrite/UnCorrelatedApplyAggregateFilter.java | 4 +- .../rules/rewrite/UnCorrelatedApplyFilter.java | 4 +- .../rewrite/UnCorrelatedApplyProjectFilter.java | 4 +- .../trees/copier/LogicalPlanDeepCopier.java | 4 +- .../nereids/trees/plans/logical/LogicalApply.java | 48 ++++---- .../rules/analysis/AnalyzeWhereSubqueryTest.java | 62 +++++----- .../AggScalarSubQueryToWindowFunctionTest.java | 3 + .../nereids_syntax_p0/sub_query_correlated.out | 3 - .../nereids_tpcds_shape_sf100_p0/shape/query1.out | 18 +-- .../nereids_tpcds_shape_sf100_p0/shape/query30.out | 15 +-- .../nereids_tpcds_shape_sf100_p0/shape/query41.out | 2 +- .../nereids_tpcds_shape_sf100_p0/shape/query6.out | 88 +++++++------- .../nereids_tpcds_shape_sf100_p0/shape/query81.out | 15 +-- .../nereids_tpch_shape_sf1000_p0/shape/q20.out | 2 +- .../data/nereids_tpch_shape_sf500_p0/shape/q20.out | 2 +- .../nereids_syntax_p0/sub_query_correlated.groovy | 6 +- .../shape/query6.groovy | 3 + 28 files changed, 223 insertions(+), 315 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index 6b89d02782..6dfe95c116 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -42,7 +42,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -100,24 +99,21 @@ public class SubqueryToApply implements AnalysisRuleFactory { applyPlan = subqueryToApply(subqueryExprs.stream() .collect(ImmutableList.toImmutableList()), tmpPlan, context.getSubqueryToMarkJoinSlot(), - context.getSubqueryCorrespondingConjunct(), ctx.cascadesContext, + ctx.cascadesContext, Optional.of(conjunct), false); tmpPlan = applyPlan; - if (!(subqueryExprs.size() == 1 - && subqueryExprs.stream().anyMatch(ScalarSubquery.class::isInstance))) { - newConjuncts.add(conjunct); - } + newConjuncts.add(conjunct); } - Set<Expression> conjects = new LinkedHashSet<>(); - conjects.addAll(newConjuncts.build()); - Plan newFilter = new LogicalFilter<>(conjects, applyPlan); - if (conjects.stream().flatMap(c -> c.children().stream()) + Set<Expression> conjuncts = new LinkedHashSet<>(); + conjuncts.addAll(newConjuncts.build()); + Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); + if (conjuncts.stream().flatMap(c -> c.children().stream()) .anyMatch(MarkJoinSlotReference.class::isInstance)) { return new LogicalProject<>(applyPlan.getOutput().stream() .filter(s -> !(s instanceof MarkJoinSlotReference)) .collect(ImmutableList.toImmutableList()), newFilter); } - return new LogicalFilter<>(conjects, applyPlan); + return new LogicalFilter<>(conjuncts, applyPlan); }) ), RuleType.PROJECT_SUBQUERY_TO_APPLY.build( @@ -144,7 +140,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { subqueryToApply( subqueryExprs.stream().collect(ImmutableList.toImmutableList()), (LogicalPlan) project.child(), - context.getSubqueryToMarkJoinSlot(), context.getSubqueryCorrespondingConjunct(), + context.getSubqueryToMarkJoinSlot(), ctx.cascadesContext, Optional.empty(), true )); @@ -155,7 +151,6 @@ public class SubqueryToApply implements AnalysisRuleFactory { private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan, Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot, - Map<SubqueryExpr, Expression> subqueryCorrespondingConject, CascadesContext ctx, Optional<Expression> conjunct, boolean isProject) { LogicalPlan tmpPlan = childPlan; @@ -167,7 +162,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { if (!ctx.subqueryIsAnalyzed(subqueryExpr)) { tmpPlan = addApply(subqueryExpr, tmpPlan, - subqueryToMarkJoinSlot, subqueryCorrespondingConject, ctx, conjunct, + subqueryToMarkJoinSlot, ctx, conjunct, isProject, subqueryExprs.size() == 1); } } @@ -183,19 +178,16 @@ public class SubqueryToApply implements AnalysisRuleFactory { private LogicalPlan addApply(SubqueryExpr subquery, LogicalPlan childPlan, Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot, - Map<SubqueryExpr, Expression> subqueryCorrespondingConject, CascadesContext ctx, Optional<Expression> conjunct, boolean isProject, boolean singleSubquery) { ctx.setSubqueryExprIsAnalyzed(subquery, true); - boolean needAddSubOutputToProjects = isScalarAndFilterContainsSubqueryOutput( + boolean needAddScalarSubqueryOutputToProjects = isConjunctContainsScalarSubqueryOutput( subquery, conjunct, isProject, singleSubquery); LogicalApply newApply = new LogicalApply( subquery.getCorrelateSlots(), subquery, Optional.empty(), subqueryToMarkJoinSlot.get(subquery), - mergeScalarSubConjunctAndFilterConjunct( - subquery, subqueryCorrespondingConject, - conjunct, needAddSubOutputToProjects, singleSubquery), isProject, + needAddScalarSubqueryOutputToProjects, isProject, childPlan, subquery.getQueryPlan()); List<NamedExpression> projects = ImmutableList.<NamedExpression>builder() @@ -205,57 +197,21 @@ public class SubqueryToApply implements AnalysisRuleFactory { .addAll(subqueryToMarkJoinSlot.get(subquery).isPresent() ? ImmutableList.of(subqueryToMarkJoinSlot.get(subquery).get()) : ImmutableList.of()) // scalarSubquery output - .addAll(needAddSubOutputToProjects + .addAll(needAddScalarSubqueryOutputToProjects ? ImmutableList.of(subquery.getQueryPlan().getOutput().get(0)) : ImmutableList.of()) .build(); return new LogicalProject(projects, newApply); } - private boolean checkSingleScalarWithOr(SubqueryExpr subquery, - Optional<Expression> conjunct) { - return subquery instanceof ScalarSubquery - && conjunct.isPresent() && conjunct.get() instanceof Or - && subquery.getCorrelateSlots().isEmpty(); - } - - private boolean isScalarAndFilterContainsSubqueryOutput( + private boolean isConjunctContainsScalarSubqueryOutput( SubqueryExpr subqueryExpr, Optional<Expression> conjunct, boolean isProject, boolean singleSubquery) { return subqueryExpr instanceof ScalarSubquery - && ((!singleSubquery && conjunct.isPresent() - && ((ImmutableSet) conjunct.get().collect(SlotReference.class::isInstance)) + && ((conjunct.isPresent() && ((ImmutableSet) conjunct.get().collect(SlotReference.class::isInstance)) .contains(subqueryExpr.getQueryPlan().getOutput().get(0))) || isProject); } - /** - * For a single scalarSubQuery, when there is a disjunction, - * directly use all connection conditions as the join conjunct of scalarSubQuery. - * e.g. - * select * from t1 where k1 > scalarSub(sum(c1)) or k2 > 10; - * LogicalJoin(otherConjunct[k1 > sum(c1) or k2 > 10]) - * - * For other scalarSubQuery, you only need to use the connection as the join conjunct. - * e.g. - * select * from t1 where k1 > scalarSub(sum(c1)) or k2 in inSub(c2) or k2 > 10; - * LogicalFilter($c$1 or $c$2 or k2 > 10) - * LogicalJoin(otherConjunct[k2 = c2]) ---> inSub - * LogicalJoin(otherConjunct[k1 > sum(c1)]) ---> scalarSub - */ - private Optional<Expression> mergeScalarSubConjunctAndFilterConjunct( - SubqueryExpr subquery, - Map<SubqueryExpr, Expression> subqueryCorrespondingConject, - Optional<Expression> conjunct, - boolean isProject, - boolean singleSubquery) { - if (singleSubquery && checkSingleScalarWithOr(subquery, conjunct)) { - return conjunct; - } else if (subqueryCorrespondingConject.containsKey(subquery) && !isProject) { - return Optional.of(subqueryCorrespondingConject.get(subquery)); - } - return Optional.empty(); - } - /** * The Subquery in the LogicalFilter will change to LogicalApply, so we must replace the origin Subquery. * LogicalFilter(predicate(contain subquery)) -> LogicalFilter(predicate(not contain subquery) @@ -284,24 +240,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { this.isProject = isProject; } - public Set<Expression> replace(Set<Expression> expressions, SubqueryContext subqueryContext) { - return expressions.stream().map(expr -> expr.accept(this, subqueryContext)) - .collect(ImmutableSet.toImmutableSet()); - } - public Expression replace(Expression expression, SubqueryContext subqueryContext) { - Expression replacedExpr = doReplace(expression, subqueryContext); - if (subqueryContext.onlySingleSubquery() && !isMarkJoin) { - // if there is only one subquery and it's not a mark join, - // we can merge the filter with the join conjunct to eliminate the filter node - // to do that, we need update the subquery's corresponding conjunct use replacedExpr - // see mergeScalarSubConjunctAndFilterConjunct() for more info - subqueryContext.updateSubqueryCorrespondingConjunct(replacedExpr); - } - return replacedExpr; - } - - public Expression doReplace(Expression expression, SubqueryContext subqueryContext) { return expression.accept(this, subqueryContext); } @@ -336,19 +275,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { @Override public Expression visitScalarSubquery(ScalarSubquery scalar, SubqueryContext context) { - context.setSubqueryCorrespondingConject(scalar, scalar.getSubqueryOutput()); - // When there is only one scalarSubQuery and CorrelateSlots is empty - // it will not be processed by MarkJoin, so it can be returned directly - if (context.onlySingleSubquery() && scalar.getCorrelateSlots().isEmpty()) { - return scalar.getSubqueryOutput(); - } - - MarkJoinSlotReference markJoinSlotReference = - new MarkJoinSlotReference(statementContext.generateColumnName()); - if (isMarkJoin) { - context.setSubqueryToMarkJoinSlot(scalar, Optional.of(markJoinSlotReference)); - } - return isMarkJoin ? markJoinSlotReference : scalar.getSubqueryOutput(); + return scalar.getSubqueryOutput(); } @Override @@ -359,8 +286,8 @@ public class SubqueryToApply implements AnalysisRuleFactory { || binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance)) && (binaryOperator instanceof Or)); - Expression left = doReplace(binaryOperator.left(), context); - Expression right = doReplace(binaryOperator.right(), context); + Expression left = replace(binaryOperator.left(), context); + Expression right = replace(binaryOperator.right(), context); return binaryOperator.withChildren(left, right); } @@ -372,7 +299,7 @@ public class SubqueryToApply implements AnalysisRuleFactory { * For inSubquery and exists: it will be directly replaced by markSlotReference * e.g. * logicalFilter(predicate=exists) ---> logicalFilter(predicate=$c$1) - * For scalarSubquery: it will be replaced by markSlotReference too + * For scalarSubquery: it will be replaced by scalarSubquery's output slot * e.g. * logicalFilter(predicate=k1 > scalarSubquery) ---> logicalFilter(predicate=k1 > $c$1) * @@ -384,11 +311,8 @@ public class SubqueryToApply implements AnalysisRuleFactory { private static class SubqueryContext { private final Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot; - private final Map<SubqueryExpr, Expression> subqueryCorrespondingConjunct; - public SubqueryContext(Set<SubqueryExpr> subqueryExprs) { this.subqueryToMarkJoinSlot = new LinkedHashMap<>(subqueryExprs.size()); - this.subqueryCorrespondingConjunct = new LinkedHashMap<>(subqueryExprs.size()); subqueryExprs.forEach(subqueryExpr -> subqueryToMarkJoinSlot.put(subqueryExpr, Optional.empty())); } @@ -396,31 +320,11 @@ public class SubqueryToApply implements AnalysisRuleFactory { return subqueryToMarkJoinSlot; } - private Map<SubqueryExpr, Expression> getSubqueryCorrespondingConjunct() { - return subqueryCorrespondingConjunct; - } - private void setSubqueryToMarkJoinSlot(SubqueryExpr subquery, Optional<MarkJoinSlotReference> markJoinSlotReference) { subqueryToMarkJoinSlot.put(subquery, markJoinSlotReference); } - private void setSubqueryCorrespondingConject(SubqueryExpr subquery, - Expression expression) { - subqueryCorrespondingConjunct.put(subquery, expression); - } - - private boolean onlySingleSubquery() { - return subqueryToMarkJoinSlot.size() == 1; - } - - private void updateSubqueryCorrespondingConjunct(Expression expression) { - Preconditions.checkState(onlySingleSubquery(), - "onlySingleSubquery must be true"); - subqueryCorrespondingConjunct - .forEach((k, v) -> subqueryCorrespondingConjunct.put(k, expression)); - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTranspose.java index 5cfcb3c0b3..0f0df1567b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTranspose.java @@ -43,7 +43,8 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory { .when(topJoin -> (topJoin.left().getJoinType().isLeftSemiOrAntiJoin() && (topJoin.getJoinType().isInnerJoin() || topJoin.getJoinType().isLeftOuterJoin()))) - .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint()) + .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint() + || topJoin.left().isMarkJoin()) .whenNot(LogicalJoin::isMarkJoin) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); @@ -59,7 +60,8 @@ public class LogicalJoinSemiJoinTranspose implements ExplorationRuleFactory { .when(topJoin -> (topJoin.right().getJoinType().isLeftSemiOrAntiJoin() && (topJoin.getJoinType().isInnerJoin() || topJoin.getJoinType().isRightOuterJoin()))) - .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.right().hasJoinHint()) + .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.right().hasJoinHint() + || topJoin.right().isMarkJoin()) .whenNot(LogicalJoin::isMarkJoin) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java index 7f72dcd5b3..8d0ef36ef3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java @@ -43,7 +43,9 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto .when(topJoin -> (topJoin.left().child().getJoinType().isLeftSemiOrAntiJoin() && (topJoin.getJoinType().isInnerJoin() || topJoin.getJoinType().isLeftOuterJoin()))) - .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().child().hasJoinHint()) + .whenNot(topJoin -> topJoin.hasJoinHint() + || topJoin.left().child().hasJoinHint() + || topJoin.left().child().isMarkJoin()) .whenNot(LogicalJoin::isMarkJoin) .when(join -> join.left().isAllSlots()) .then(topJoin -> { @@ -63,6 +65,9 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto .when(topJoin -> (topJoin.right().child().getJoinType().isLeftSemiOrAntiJoin() && (topJoin.getJoinType().isInnerJoin() || topJoin.getJoinType().isRightOuterJoin()))) + .whenNot(topJoin -> topJoin.hasJoinHint() + || topJoin.right().child().hasJoinHint() + || topJoin.right().child().isMarkJoin()) .when(join -> join.right().isAllSlots()) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.right().child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java index 7a58bbd6ae..12966e9ac8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java @@ -63,7 +63,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory { return logicalJoin(logicalJoin(), group()) .when(this::typeChecker) .whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint()) - .whenNot(LogicalJoin::isMarkJoin) + .whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin()) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java index 9f588d4c5d..5b70d2d2f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java @@ -156,9 +156,7 @@ public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobCo return apply.isScalar() && !apply.isMarkJoin() && apply.right() instanceof LogicalAggregate - && apply.isCorrelated() - && apply.getSubCorrespondingConjunct().isPresent() - && apply.getSubCorrespondingConjunct().get() instanceof ComparisonPredicate; + && apply.isCorrelated(); } /** @@ -326,7 +324,17 @@ public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobCo // it's a simple case, but we may meet some complex cases in ut. // TODO: support compound predicate and multi apply node. - Expression windowFilterConjunct = apply.getSubCorrespondingConjunct().get(); + Map<Boolean, Set<Expression>> conjuncts = filter.getConjuncts().stream() + .collect(Collectors.groupingBy(conjunct -> Sets + .intersection(conjunct.getInputSlotExprIds(), agg.getOutputExprIdSet()) + .isEmpty(), Collectors.toSet())); + Set<Expression> correlatedConjuncts = conjuncts.get(false); + if (correlatedConjuncts.isEmpty() || correlatedConjuncts.size() > 1 + || !(correlatedConjuncts.iterator().next() instanceof ComparisonPredicate)) { + //TODO: only support simple comparison predicate now + return filter; + } + Expression windowFilterConjunct = correlatedConjuncts.iterator().next(); windowFilterConjunct = PlanUtils.maybeCommuteComparisonPredicate( (ComparisonPredicate) windowFilterConjunct, apply.left()); @@ -349,13 +357,10 @@ public class AggScalarSubQueryToWindowFunction extends DefaultPlanRewriter<JobCo aggOutExpr = ExpressionUtils.replace(aggOutExpr, ImmutableMap .of(functions.get(0), windowFunctionAlias.toSlot())); - // we change the child contains the original agg output to agg output expr. - // for comparison predicate, it is always the child(1), since we ensure the window agg slot is in child(0) - // for in predicate, we should extract the options and find the corresponding child. - windowFilterConjunct = windowFilterConjunct - .withChildren(windowFilterConjunct.child(0), aggOutExpr); + windowFilterConjunct = ExpressionUtils.replace(windowFilterConjunct, + ImmutableMap.of(aggOut.toSlot(), aggOutExpr)); - LogicalFilter<Plan> newFilter = (LogicalFilter<Plan>) filter.withChildren(apply.left()); + LogicalFilter<Plan> newFilter = filter.withConjunctsAndChild(conjuncts.get(true), apply.left()); LogicalWindow<Plan> newWindow = new LogicalWindow<>(ImmutableList.of(windowFunctionAlias), newFilter); LogicalFilter<Plan> windowFilter = new LogicalFilter<>(ImmutableSet.of(windowFilterConjunct), newWindow); return windowFilter; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExistsApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExistsApplyToJoin.java index 04fbc55340..f4b2cbc9a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExistsApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExistsApplyToJoin.java @@ -90,16 +90,7 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory { private Plan correlatedToJoin(LogicalApply apply) { Optional<Expression> correlationFilter = apply.getCorrelationFilter(); - Expression predicate = null; - if (correlationFilter.isPresent() && apply.getSubCorrespondingConjunct().isPresent()) { - predicate = ExpressionUtils.and(correlationFilter.get(), - (Expression) apply.getSubCorrespondingConjunct().get()); - } else if (apply.getSubCorrespondingConjunct().isPresent()) { - predicate = (Expression) apply.getSubCorrespondingConjunct().get(); - } else if (correlationFilter.isPresent()) { - predicate = correlationFilter.get(); - } - + Expression predicate = correlationFilter.get(); if (((Exists) apply.getSubqueryExpr()).isNot()) { return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, ExpressionUtils.EMPTY_CONDITION, predicate != null @@ -133,9 +124,7 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory { LogicalAggregate newAgg = new LogicalAggregate<>(new ArrayList<>(), ImmutableList.of(alias), newLimit); LogicalJoin newJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, - unapply.getSubCorrespondingConjunct().isPresent() - ? ExpressionUtils.extractConjunction((Expression) unapply.getSubCorrespondingConjunct().get()) - : ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, unapply.getMarkJoinSlotReference(), + ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, unapply.getMarkJoinSlotReference(), (LogicalPlan) unapply.left(), newAgg); return new LogicalFilter<>(ImmutableSet.of(new EqualTo(newAgg.getOutput().get(0), new IntegerLiteral(0))), newJoin); @@ -144,9 +133,7 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory { private Plan unCorrelatedExist(LogicalApply unapply) { LogicalLimit newLimit = new LogicalLimit<>(1, 0, LimitPhase.ORIGIN, (LogicalPlan) unapply.right()); return new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, - unapply.getSubCorrespondingConjunct().isPresent() - ? ExpressionUtils.extractConjunction((Expression) unapply.getSubCorrespondingConjunct().get()) - : ExpressionUtils.EMPTY_CONDITION, + ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, unapply.getMarkJoinSlotReference(), (LogicalPlan) unapply.left(), newLimit); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InApplyToJoin.java index 64d8defcfe..ab8cbc40ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InApplyToJoin.java @@ -104,9 +104,6 @@ public class InApplyToJoin extends OneRewriteRuleFactory { predicate = new EqualTo(left, right); } - if (apply.getSubCorrespondingConjunct().isPresent()) { - predicate = ExpressionUtils.and(predicate, apply.getSubCorrespondingConjunct().get()); - } List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate); if (((InSubquery) apply.getSubqueryExpr()).isNot()) { return new LogicalJoin<>( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java index 4f20f58c59..21cf3ea7f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpCorrelatedFilterUnderApplyAggregateProject.java @@ -79,8 +79,8 @@ public class PullUpCorrelatedFilterUnderApplyAggregateProject extends OneRewrite LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newFilter)); return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(), apply.getCorrelationFilter(), apply.getMarkJoinSlotReference(), - apply.getSubCorrespondingConjunct(), apply.isNeedAddSubOutputToProjects(), - apply.left(), newAgg); + apply.isNeedAddSubOutputToProjects(), + apply.isInProject(), apply.left(), newAgg); }).toRule(RuleType.PULL_UP_CORRELATED_FILTER_UNDER_APPLY_AGGREGATE_PROJECT); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApply.java index 43beb3b042..fbe4844f74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpProjectUnderApply.java @@ -57,8 +57,8 @@ public class PullUpProjectUnderApply extends OneRewriteRuleFactory { LogicalProject<Plan> project = apply.right(); LogicalApply newCorrelate = new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(), apply.getCorrelationFilter(), apply.getMarkJoinSlotReference(), - apply.getSubCorrespondingConjunct(), apply.isNeedAddSubOutputToProjects(), - apply.left(), project.child()); + apply.isNeedAddSubOutputToProjects(), + apply.isInProject(), apply.left(), project.child()); List<NamedExpression> newProjects = new ArrayList<>(); newProjects.addAll(apply.left().getOutput()); if (apply.getSubqueryExpr() instanceof ScalarSubquery) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java index 08827c7ca2..386f4a0119 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownFilterThroughProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -40,26 +41,27 @@ public class PushdownFilterThroughProject implements RewriteRuleFactory { @Override public List<Rule> buildRules() { return ImmutableList.of( - RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build( - logicalFilter(logicalProject()) - .then(PushdownFilterThroughProject::pushdownFilterThroughProject) - ), + RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(logicalFilter(logicalProject()) + .whenNot(filter -> filter.child().getProjects().stream().anyMatch( + expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .then(PushdownFilterThroughProject::pushdownFilterThroughProject)), // filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit, // then we should change filter(limit(project)) to project(filter(limit)) - RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT.build( - logicalFilter(logicalLimit(logicalProject())).then(filter -> { - LogicalLimit<LogicalProject<Plan>> limit = filter.child(); - LogicalProject<Plan> project = limit.child(); + RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT + .build(logicalFilter(logicalLimit(logicalProject())) + .whenNot(filter -> filter.child().child().getProjects().stream() + .anyMatch(expr -> expr + .anyMatch(WindowExpression.class::isInstance))) + .then(filter -> { + LogicalLimit<LogicalProject<Plan>> limit = filter.child(); + LogicalProject<Plan> project = limit.child(); - return project.withProjectsAndChild( - project.getProjects(), - new LogicalFilter<>( - ExpressionUtils.replace(filter.getConjuncts(), project.getAliasToProducer()), - limit.withChildren(project.child()) - ) - ); - }) - ) + return project.withProjectsAndChild(project.getProjects(), + new LogicalFilter<>( + ExpressionUtils.replace(filter.getConjuncts(), + project.getAliasToProducer()), + limit.withChildren(project.child()))); + })) ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ScalarApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ScalarApplyToJoin.java index 82a1398d59..6c10427215 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ScalarApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ScalarApplyToJoin.java @@ -56,14 +56,12 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory { LogicalAssertNumRows assertNumRows = new LogicalAssertNumRows<>( new AssertNumRowsElement( 1, apply.getSubqueryExpr().toString(), - apply.isNeedAddSubOutputToProjects() + apply.isInProject() ? AssertNumRowsElement.Assertion.EQ : AssertNumRowsElement.Assertion.LE), (LogicalPlan) apply.right()); return new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, - apply.getSubCorrespondingConjunct().isPresent() - ? ExpressionUtils.extractConjunction((Expression) apply.getSubCorrespondingConjunct().get()) - : ExpressionUtils.EMPTY_CONDITION, + ExpressionUtils.EMPTY_CONDITION, JoinHint.NONE, apply.getMarkJoinSlotReference(), (LogicalPlan) apply.left(), assertNumRows); @@ -83,14 +81,11 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory { throw new AnalysisException("correlationFilter can't be null in correlatedToJoin"); } - return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, + return new LogicalJoin<>( + apply.isNeedAddSubOutputToProjects() ? JoinType.LEFT_OUTER_JOIN + : JoinType.LEFT_SEMI_JOIN, ExpressionUtils.EMPTY_CONDITION, - ExpressionUtils.extractConjunction( - apply.getSubCorrespondingConjunct().isPresent() - ? ExpressionUtils.and( - (Expression) apply.getSubCorrespondingConjunct().get(), - correlationFilter.get()) - : correlationFilter.get()), + ExpressionUtils.extractConjunction(correlationFilter.get()), JoinHint.NONE, apply.getMarkJoinSlotReference(), apply.children()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java index e7e62c3068..15f1b4a554 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyAggregateFilter.java @@ -88,8 +88,8 @@ public class UnCorrelatedApplyAggregateFilter extends OneRewriteRuleFactory { apply.getSubqueryExpr(), ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(), - apply.getSubCorrespondingConjunct(), apply.isNeedAddSubOutputToProjects(), - apply.left(), newAgg); + apply.isNeedAddSubOutputToProjects(), + apply.isInProject(), apply.left(), newAgg); }).toRule(RuleType.UN_CORRELATED_APPLY_AGGREGATE_FILTER); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyFilter.java index 7a19c1ee2f..af2c83b2b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyFilter.java @@ -68,8 +68,8 @@ public class UnCorrelatedApplyFilter extends OneRewriteRuleFactory { Plan child = PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()); return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(), ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(), - apply.getSubCorrespondingConjunct(), apply.isNeedAddSubOutputToProjects(), - apply.left(), child); + apply.isNeedAddSubOutputToProjects(), + apply.isInProject(), apply.left(), child); }).toRule(RuleType.UN_CORRELATED_APPLY_FILTER); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyProjectFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyProjectFilter.java index 9ddc701128..911bf0eef0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyProjectFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnCorrelatedApplyProjectFilter.java @@ -89,8 +89,8 @@ public class UnCorrelatedApplyProjectFilter extends OneRewriteRuleFactory { LogicalProject newProject = project.withProjectsAndChild(projects, child); return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(), ExpressionUtils.optionalAnd(correlatedPredicate), apply.getMarkJoinSlotReference(), - apply.getSubCorrespondingConjunct(), apply.isNeedAddSubOutputToProjects(), - apply.left(), newProject); + apply.isNeedAddSubOutputToProjects(), + apply.isInProject(), apply.left(), newProject); }).toRule(RuleType.UN_CORRELATED_APPLY_PROJECT_FILTER); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index 5727279ccf..24c839ad12 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -113,10 +113,8 @@ public class LogicalPlanDeepCopier extends DefaultPlanRewriter<DeepCopierContext .map(f -> ExpressionDeepCopier.INSTANCE.deepCopy(f, context)); Optional<MarkJoinSlotReference> markJoinSlotReference = apply.getMarkJoinSlotReference() .map(m -> (MarkJoinSlotReference) ExpressionDeepCopier.INSTANCE.deepCopy(m, context)); - Optional<Expression> subCorrespondingConjunct = apply.getSubCorrespondingConjunct() - .map(c -> ExpressionDeepCopier.INSTANCE.deepCopy(c, context)); return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter, - markJoinSlotReference, subCorrespondingConjunct, apply.isNeedAddSubOutputToProjects(), left, right); + markJoinSlotReference, apply.isNeedAddSubOutputToProjects(), apply.isInProject(), left, right); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java index 4d255c5c04..ed408c0c3a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java @@ -55,9 +55,10 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends // The slot replaced by the subquery in MarkJoin private final Optional<MarkJoinSlotReference> markJoinSlotReference; - private final Optional<Expression> subCorrespondingConjunct; - // Whether the subquery is in logicalProject + private final boolean inProject; + + // Whether adding the subquery's output to projects private final boolean needAddSubOutputToProjects; /** @@ -68,7 +69,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends List<Expression> correlationSlot, SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter, Optional<MarkJoinSlotReference> markJoinSlotReference, - Optional<Expression> subCorrespondingConjunct, + boolean needAddSubOutputToProjects, boolean inProject, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { super(PlanType.LOGICAL_APPLY, groupExpression, logicalProperties, leftChild, rightChild); @@ -76,16 +77,17 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends this.subqueryExpr = Objects.requireNonNull(subqueryExpr, "subquery can not be null"); this.correlationFilter = correlationFilter; this.markJoinSlotReference = markJoinSlotReference; - this.subCorrespondingConjunct = subCorrespondingConjunct; - this.needAddSubOutputToProjects = inProject; + this.needAddSubOutputToProjects = needAddSubOutputToProjects; + this.inProject = inProject; } public LogicalApply(List<Expression> correlationSlot, SubqueryExpr subqueryExpr, Optional<Expression> correlationFilter, Optional<MarkJoinSlotReference> markJoinSlotReference, - Optional<Expression> subCorrespondingConjunct, boolean inProject, + boolean needAddSubOutputToProjects, boolean inProject, LEFT_CHILD_TYPE input, RIGHT_CHILD_TYPE subquery) { - this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr, - correlationFilter, markJoinSlotReference, subCorrespondingConjunct, inProject, input, subquery); + this(Optional.empty(), Optional.empty(), correlationSlot, subqueryExpr, correlationFilter, + markJoinSlotReference, needAddSubOutputToProjects, inProject, input, + subquery); } public List<Expression> getCorrelationSlot() { @@ -128,14 +130,14 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends return markJoinSlotReference; } - public Optional<Expression> getSubCorrespondingConjunct() { - return subCorrespondingConjunct; - } - public boolean isNeedAddSubOutputToProjects() { return needAddSubOutputToProjects; } + public boolean isInProject() { + return inProject; + } + @Override public List<Slot> computeOutput() { return ImmutableList.<Slot>builder() @@ -153,9 +155,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends "correlationSlot", correlationSlot, "correlationFilter", correlationFilter, "isMarkJoin", markJoinSlotReference.isPresent(), - "MarkJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty", - "scalarSubCorrespondingSlot", - subCorrespondingConjunct.isPresent() ? subCorrespondingConjunct.get() : "empty"); + "MarkJoinSlotReference", markJoinSlotReference.isPresent() ? markJoinSlotReference.get() : "empty"); } @Override @@ -171,15 +171,15 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends && Objects.equals(subqueryExpr, that.getSubqueryExpr()) && Objects.equals(correlationFilter, that.getCorrelationFilter()) && Objects.equals(markJoinSlotReference, that.getMarkJoinSlotReference()) - && Objects.equals(subCorrespondingConjunct, that.getSubCorrespondingConjunct()) - && needAddSubOutputToProjects == that.needAddSubOutputToProjects; + && needAddSubOutputToProjects == that.needAddSubOutputToProjects + && inProject == that.inProject; } @Override public int hashCode() { return Objects.hash( correlationSlot, subqueryExpr, correlationFilter, - markJoinSlotReference, subCorrespondingConjunct, needAddSubOutputToProjects); + markJoinSlotReference, needAddSubOutputToProjects, inProject); } @Override @@ -204,23 +204,23 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 2); return new LogicalApply<>(correlationSlot, subqueryExpr, correlationFilter, - markJoinSlotReference, subCorrespondingConjunct, needAddSubOutputToProjects, + markJoinSlotReference, needAddSubOutputToProjects, inProject, children.get(0), children.get(1)); } @Override public Plan withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalApply<>(groupExpression, Optional.of(getLogicalProperties()), - correlationSlot, subqueryExpr, correlationFilter, - markJoinSlotReference, subCorrespondingConjunct, needAddSubOutputToProjects, left(), right()); + correlationSlot, subqueryExpr, correlationFilter, markJoinSlotReference, + needAddSubOutputToProjects, inProject, left(), right()); } @Override public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, List<Plan> children) { Preconditions.checkArgument(children.size() == 2); - return new LogicalApply<>(groupExpression, logicalProperties, correlationSlot, subqueryExpr, correlationFilter, - markJoinSlotReference, subCorrespondingConjunct, needAddSubOutputToProjects, children.get(0), - children.get(1)); + return new LogicalApply<>(groupExpression, logicalProperties, correlationSlot, subqueryExpr, + correlationFilter, markJoinSlotReference, + needAddSubOutputToProjects, inProject, children.get(0), children.get(1)); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java index b3387012f2..bf060d7e5d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeWhereSubqueryTest.java @@ -180,6 +180,10 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP // after aggFilter rule PlanChecker.from(connectContext) .analyze(sql2) + .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) + .applyTopDown(new MergeProjects()) + .applyBottomUp(new PullUpProjectUnderApply()) + .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) .matchesNotCheck( logicalApply( @@ -214,20 +218,19 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP // after Scalar CorrelatedJoin to join PlanChecker.from(connectContext) .analyze(sql2) + .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) + .applyTopDown(new MergeProjects()) + .applyBottomUp(new PullUpProjectUnderApply()) + .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) .applyBottomUp(new ScalarApplyToJoin()) .matchesNotCheck( logicalJoin( any(), logicalAggregate() - ).when(FieldChecker.check("joinType", JoinType.LEFT_SEMI_JOIN)) + ).when(FieldChecker.check("joinType", JoinType.LEFT_OUTER_JOIN)) .when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of(new EqualTo( - new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t6")), - new SlotReference(new ExprId(7), "sum(k3)", BigIntType.INSTANCE, true, - ImmutableList.of()) - ), new EqualTo( new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t7")), new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, @@ -419,6 +422,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .analyze(sql10) .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) .applyTopDown(new MergeProjects()) + .applyBottomUp(new PullUpProjectUnderApply()) .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .matchesNotCheck( logicalApply( @@ -451,6 +455,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP .analyze(sql10) .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) .applyTopDown(new MergeProjects()) + .applyBottomUp(new PullUpProjectUnderApply()) .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) .matchesNotCheck( @@ -473,30 +478,27 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements MemoP @Test public void testSql10AfterScalarToJoin() { - PlanChecker.from(connectContext) - .analyze(sql10) - .applyBottomUp(new LogicalSubQueryAliasToLogicalProject()) - .applyTopDown(new MergeProjects()) - .applyBottomUp(new PullUpCorrelatedFilterUnderApplyAggregateProject()) - .applyBottomUp(new UnCorrelatedApplyAggregateFilter()) - .applyBottomUp(new ScalarApplyToJoin()) + PlanChecker.from(connectContext).analyze(sql10).rewrite() .matchesNotCheck( - leftSemiLogicalJoin( - any(), - logicalAggregate( - logicalProject() - ) - ) - .when(j -> j.getOtherJoinConjuncts().equals(ImmutableList.of( - new LessThan(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t6")), - new SlotReference(new ExprId(8), "max(aa)", BigIntType.INSTANCE, true, - ImmutableList.of())), - new EqualTo(new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t6")), - new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7"))) - ))) - ); + innerLogicalJoin(any(), logicalAggregate(logicalProject())).when(j -> j + .getOtherJoinConjuncts().equals(ImmutableList + .of(new LessThan( + new SlotReference(new ExprId(0), "k1", + BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", + "t6")), + new SlotReference( + new ExprId(8), "max(aa)", + BigIntType.INSTANCE, true, + ImmutableList.of())))) + && j.getHashJoinConjuncts() + .equals(ImmutableList.of(new EqualTo( + new SlotReference(new ExprId(1), "k2", + BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", + "t6")), + new SlotReference(new ExprId(6), "v2", + BigIntType.INSTANCE, true, ImmutableList.of( + "default_cluster:test", "t7"))))))); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java index f70161f465..bc33c4c70e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunctionTest.java @@ -341,6 +341,9 @@ public class AggScalarSubQueryToWindowFunctionTest extends TPCHTestBase implemen System.out.printf("Test:\n%s\n\n", sql); Plan plan = PlanChecker.from(createCascadesContext(sql)) .analyze(sql) + .applyBottomUp(new PullUpProjectUnderApply()) + .applyTopDown(new PushdownFilterThroughProject()) + .customRewrite(new EliminateUnnecessaryProject()) .customRewrite(new AggScalarSubQueryToWindowFunction()) .rewrite() .getPlan(); diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out index 492b4c9545..116d4ec918 100644 --- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out +++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out @@ -387,9 +387,6 @@ 2 5 3 3 3 4 -20 2 -22 3 -24 4 -- !imitate_tpcds_10 -- diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out index 55ff9ac886..8c934fb187 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query1.out @@ -23,8 +23,8 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ----------------PhysicalProject ------------------PhysicalOlapScan[customer] --------------PhysicalDistribute -----------------PhysicalProject -------------------hashJoin[LEFT_SEMI_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +----------------hashJoin[INNER_JOIN](ctr1.ctr_store_sk = ctr2.ctr_store_sk)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------------PhysicalProject --------------------hashJoin[INNER_JOIN](store.s_store_sk = ctr1.ctr_store_sk) ----------------------PhysicalDistribute ------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) @@ -32,11 +32,11 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------------------------PhysicalProject --------------------------filter((cast(s_state as VARCHAR(*)) = 'SD')) ----------------------------PhysicalOlapScan[store] ---------------------PhysicalProject -----------------------hashAgg[GLOBAL] -------------------------PhysicalDistribute ---------------------------hashAgg[LOCAL] -----------------------------PhysicalDistribute -------------------------------PhysicalProject ---------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) +------------------PhysicalProject +--------------------hashAgg[GLOBAL] +----------------------PhysicalDistribute +------------------------hashAgg[LOCAL] +--------------------------PhysicalDistribute +----------------------------PhysicalProject +------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out index 48239bce9e..df28c5bee4 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query30.out @@ -24,18 +24,19 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[LEFT_SEMI_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalDistribute -------------------hashJoin[INNER_JOIN](customer_address.ca_address_sk = customer.c_current_addr_sk) ---------------------PhysicalProject -----------------------PhysicalOlapScan[customer] ---------------------PhysicalDistribute +------------------PhysicalProject +--------------------hashJoin[INNER_JOIN](customer_address.ca_address_sk = customer.c_current_addr_sk) ----------------------PhysicalProject -------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN')) ---------------------------PhysicalOlapScan[customer_address] +------------------------PhysicalOlapScan[customer] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------filter((cast(ca_state as VARCHAR(*)) = 'IN')) +----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute ----------------PhysicalProject ------------------hashAgg[GLOBAL] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out index 9a5ec3b381..964e2d8aec 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query41.out @@ -8,7 +8,7 @@ PhysicalResultSink ----------PhysicalDistribute ------------hashAgg[LOCAL] --------------PhysicalProject -----------------hashJoin[LEFT_SEMI_JOIN](item.i_manufact = i1.i_manufact) +----------------hashJoin[INNER_JOIN](item.i_manufact = i1.i_manufact) ------------------PhysicalProject --------------------filter((i1.i_manufact_id >= 748)(i1.i_manufact_id <= 788)) ----------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query6.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query6.out index 853909e2c9..2cf7b8a32d 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query6.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query6.out @@ -1,52 +1,54 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !ds_shape_6 -- -PhysicalTopN ---PhysicalDistribute -----PhysicalTopN -------filter((cnt >= 10)) ---------hashAgg[GLOBAL] -----------PhysicalDistribute -------------hashAgg[LOCAL] ---------------PhysicalProject -----------------hashJoin[LEFT_SEMI_JOIN](j.i_category = i.i_category)(cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(i_current_price))) -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN](c.c_customer_sk = s.ss_customer_sk) -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](s.ss_item_sk = i.i_item_sk) -----------------------------PhysicalDistribute +PhysicalResultSink +--PhysicalTopN +----PhysicalDistribute +------PhysicalTopN +--------filter((cnt >= 10)) +----------hashAgg[GLOBAL] +------------PhysicalDistribute +--------------hashAgg[LOCAL] +----------------PhysicalProject +------------------hashJoin[INNER_JOIN](j.i_category = i.i_category)(cast(i_current_price as DECIMALV3(38, 5)) > (1.2 * avg(i_current_price))) +--------------------PhysicalProject +----------------------hashJoin[INNER_JOIN](s.ss_item_sk = i.i_item_sk) +------------------------PhysicalDistribute +--------------------------hashJoin[INNER_JOIN](d.d_month_seq = date_dim.d_month_seq) +----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN](s.ss_sold_date_sk = d.d_date_sk) --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[store_sales] ---------------------------------PhysicalDistribute -----------------------------------hashJoin[INNER_JOIN](d.d_month_seq = date_dim.d_month_seq) -------------------------------------PhysicalProject ---------------------------------------PhysicalOlapScan[date_dim] +----------------------------------hashJoin[INNER_JOIN](c.c_customer_sk = s.ss_customer_sk) ------------------------------------PhysicalDistribute ---------------------------------------PhysicalAssertNumRows -----------------------------------------PhysicalDistribute -------------------------------------------hashAgg[GLOBAL] ---------------------------------------------PhysicalDistribute -----------------------------------------------hashAgg[LOCAL] -------------------------------------------------PhysicalProject ---------------------------------------------------filter((date_dim.d_year = 2002)(date_dim.d_moy = 3)) -----------------------------------------------------PhysicalOlapScan[date_dim] -----------------------------PhysicalDistribute -------------------------------PhysicalProject ---------------------------------PhysicalOlapScan[item] -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](a.ca_address_sk = c.c_current_addr_sk) -----------------------------PhysicalDistribute -------------------------------PhysicalProject ---------------------------------PhysicalOlapScan[customer] +--------------------------------------PhysicalProject +----------------------------------------PhysicalOlapScan[store_sales] +------------------------------------PhysicalDistribute +--------------------------------------PhysicalProject +----------------------------------------hashJoin[INNER_JOIN](a.ca_address_sk = c.c_current_addr_sk) +------------------------------------------PhysicalDistribute +--------------------------------------------PhysicalProject +----------------------------------------------PhysicalOlapScan[customer] +------------------------------------------PhysicalDistribute +--------------------------------------------PhysicalProject +----------------------------------------------PhysicalOlapScan[customer_address] +--------------------------------PhysicalDistribute +----------------------------------PhysicalProject +------------------------------------PhysicalOlapScan[date_dim] ----------------------------PhysicalDistribute -------------------------------PhysicalProject ---------------------------------PhysicalOlapScan[customer_address] -------------------PhysicalDistribute ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] +------------------------------PhysicalAssertNumRows +--------------------------------PhysicalDistribute +----------------------------------hashAgg[GLOBAL] +------------------------------------PhysicalDistribute +--------------------------------------hashAgg[LOCAL] +----------------------------------------PhysicalProject +------------------------------------------filter((date_dim.d_year = 2002)(date_dim.d_moy = 3)) +--------------------------------------------PhysicalOlapScan[date_dim] +------------------------PhysicalDistribute --------------------------PhysicalProject ----------------------------PhysicalOlapScan[item] +--------------------PhysicalDistribute +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute +--------------------------hashAgg[LOCAL] +----------------------------PhysicalProject +------------------------------PhysicalOlapScan[item] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out index 15caa8024d..77c7b273ba 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query81.out @@ -24,18 +24,19 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) ------PhysicalDistribute --------PhysicalTopN ----------PhysicalProject -------------hashJoin[LEFT_SEMI_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) +------------hashJoin[INNER_JOIN](ctr1.ctr_state = ctr2.ctr_state)(cast(ctr_total_return as DOUBLE) > cast((avg(ctr_total_return) * 1.2) as DOUBLE)) --------------hashJoin[INNER_JOIN](ctr1.ctr_customer_sk = customer.c_customer_sk) ----------------PhysicalDistribute ------------------PhysicalCteConsumer ( cteId=CTEId#0 ) ----------------PhysicalDistribute -------------------hashJoin[INNER_JOIN](customer_address.ca_address_sk = customer.c_current_addr_sk) ---------------------PhysicalProject -----------------------PhysicalOlapScan[customer] ---------------------PhysicalDistribute +------------------PhysicalProject +--------------------hashJoin[INNER_JOIN](customer_address.ca_address_sk = customer.c_current_addr_sk) ----------------------PhysicalProject -------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA')) ---------------------------PhysicalOlapScan[customer_address] +------------------------PhysicalOlapScan[customer] +----------------------PhysicalDistribute +------------------------PhysicalProject +--------------------------filter((cast(ca_state as VARCHAR(*)) = 'CA')) +----------------------------PhysicalOlapScan[customer_address] --------------PhysicalDistribute ----------------PhysicalProject ------------------hashAgg[GLOBAL] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out index d3901ee9c2..6114877bc9 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q20.out @@ -8,7 +8,7 @@ PhysicalResultSink ----------hashJoin[RIGHT_SEMI_JOIN](supplier.s_suppkey = partsupp.ps_suppkey) ------------PhysicalDistribute --------------PhysicalProject -----------------hashJoin[RIGHT_SEMI_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) +----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) ------------------PhysicalProject --------------------hashAgg[GLOBAL] ----------------------PhysicalDistribute diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out index d3901ee9c2..6114877bc9 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q20.out @@ -8,7 +8,7 @@ PhysicalResultSink ----------hashJoin[RIGHT_SEMI_JOIN](supplier.s_suppkey = partsupp.ps_suppkey) ------------PhysicalDistribute --------------PhysicalProject -----------------hashJoin[RIGHT_SEMI_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) +----------------hashJoin[INNER_JOIN](lineitem.l_partkey = partsupp.ps_partkey)(lineitem.l_suppkey = partsupp.ps_suppkey)(cast(ps_availqty as DECIMALV3(38, 3)) > (0.5 * sum(l_quantity))) ------------------PhysicalProject --------------------hashAgg[GLOBAL] ----------------------PhysicalDistribute diff --git a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy index d1fe2d7a88..349802e21a 100644 --- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy +++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy @@ -38,6 +38,10 @@ suite ("sub_query_correlated") { DROP TABLE IF EXISTS `sub_query_correlated_subquery4` """ + sql """ + DROP TABLE IF EXISTS `sub_query_correlated_subquery5` + """ + sql """ create table if not exists sub_query_correlated_subquery1 (k1 bigint, k2 bigint) @@ -403,7 +407,7 @@ suite ("sub_query_correlated") { """ qt_cast_subquery_in_with_disconjunct """ - SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) or k1 > 10 order by k1, k2; + SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) or k1 > 100 order by k1, k2; """ qt_imitate_tpcds_10 """ diff --git a/regression-test/suites/nereids_tpcds_shape_sf100_p0/shape/query6.groovy b/regression-test/suites/nereids_tpcds_shape_sf100_p0/shape/query6.groovy index 2a43a48427..26e886751e 100644 --- a/regression-test/suites/nereids_tpcds_shape_sf100_p0/shape/query6.groovy +++ b/regression-test/suites/nereids_tpcds_shape_sf100_p0/shape/query6.groovy @@ -30,6 +30,9 @@ suite("query6") { sql 'set enable_nereids_timeout = false' sql 'SET enable_pipeline_engine = true' + // TODO: uncomment following line to get better shape + // sql 'set max_join_number_bushy_tree=6' + qt_ds_shape_6 ''' explain shape plan --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
