This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 32f8f6671197b7c0a7c61d9385290f0e28c964c3
Author: starocean999 <[email protected]>
AuthorDate: Fri Jul 28 15:08:56 2023 +0800

    [fix](nereids) SubqueryToApply may lose conjunct (#22262)
    
    Consider the following SQL:
    ```
    SELECT *
            FROM sub_query_correlated_subquery1 t1
            WHERE coalesce(bitand(
            cast(
                (SELECT sum(k1)
                FROM sub_query_correlated_subquery3 ) AS int),
                cast(t1.k1 AS int)),
                coalesce(t1.k1, t1.k2)) is NULL
            ORDER BY  t1.k1, t1.k2;
    ```
    The `IS NULL` conjunct is lost by the SubqueryToApply rule. This PR fixes it.
---
 .../nereids/rules/analysis/SubqueryToApply.java    | 140 +++++----------------
 .../nereids_syntax_p0/sub_query_correlated.out     |  17 +++
 .../nereids_syntax_p0/sub_query_correlated.groovy  |  23 ++--
 3 files changed, 66 insertions(+), 114 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
index fb39794c23..6b89d02782 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
@@ -24,15 +24,11 @@ import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.BinaryOperator;
 import org.apache.doris.nereids.trees.expressions.CaseWhen;
-import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
-import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.Exists;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InSubquery;
-import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -46,6 +42,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
@@ -196,7 +193,7 @@ public class SubqueryToApply implements AnalysisRuleFactory 
{
                 subquery.getCorrelateSlots(),
                 subquery, Optional.empty(),
                 subqueryToMarkJoinSlot.get(subquery),
-                mergeScalarSubConjectAndFilterConject(
+                mergeScalarSubConjunctAndFilterConjunct(
                     subquery, subqueryCorrespondingConject,
                     conjunct, needAddSubOutputToProjects, singleSubquery), 
isProject,
                 childPlan, subquery.getQueryPlan());
@@ -245,7 +242,7 @@ public class SubqueryToApply implements AnalysisRuleFactory 
{
      *      LogicalJoin(otherConjunct[k2 = c2])  ---> inSub
      *          LogicalJoin(otherConjunct[k1 > sum(c1)])  ---> scalarSub
      */
-    private Optional<Expression> mergeScalarSubConjectAndFilterConject(
+    private Optional<Expression> mergeScalarSubConjunctAndFilterConjunct(
                     SubqueryExpr subquery,
                     Map<SubqueryExpr, Expression> subqueryCorrespondingConject,
                     Optional<Expression> conjunct,
@@ -292,8 +289,20 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
                     .collect(ImmutableSet.toImmutableSet());
         }
 
-        public Expression replace(Expression expressions, SubqueryContext 
subqueryContext) {
-            return expressions.accept(this, subqueryContext);
+        public Expression replace(Expression expression, SubqueryContext 
subqueryContext) {
+            Expression replacedExpr = doReplace(expression, subqueryContext);
+            if (subqueryContext.onlySingleSubquery() && !isMarkJoin) {
+                // if there is only one subquery and it's not a mark join,
+                // we can merge the filter with the join conjunct to eliminate 
the filter node
+                // to do that, we need update the subquery's corresponding 
conjunct use replacedExpr
+                // see mergeScalarSubConjunctAndFilterConjunct() for more info
+                
subqueryContext.updateSubqueryCorrespondingConjunct(replacedExpr);
+            }
+            return replacedExpr;
+        }
+
+        public Expression doReplace(Expression expression, SubqueryContext 
subqueryContext) {
+            return expression.accept(this, subqueryContext);
         }
 
         @Override
@@ -342,52 +351,16 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
             return isMarkJoin ? markJoinSlotReference : 
scalar.getSubqueryOutput();
         }
 
-        @Override
-        public Expression visitNot(Not not, SubqueryContext context) {
-            // Need to re-update scalarSubQuery unequal conditions into 
subqueryCorrespondingConject
-            if (not.child() instanceof BinaryOperator
-                    && (((BinaryOperator) 
not.child()).left().containsType(ScalarSubquery.class)
-                    || ((BinaryOperator) 
not.child()).right().containsType(ScalarSubquery.class))) {
-                Expression newChild = replace(not.child(), context);
-                ScalarSubquery subquery = 
collectScalarSubqueryForBinaryOperator((BinaryOperator) not.child());
-                context.updateSubqueryCorrespondingConjunctInNot(subquery);
-                return 
context.getSubqueryToMarkJoinSlotValue(subquery).isPresent() ? newChild : new 
Not(newChild);
-            }
-
-            return visit(not, context);
-        }
-
-        @Override
-        public Expression visitIsNull(IsNull isNull, SubqueryContext context) {
-            // Need to re-update scalarSubQuery unequal conditions into 
subqueryCorrespondingConject
-            if (isNull.child() instanceof BinaryOperator
-                    && (((BinaryOperator) 
isNull.child()).left().containsType(ScalarSubquery.class)
-                    || ((BinaryOperator) 
isNull.child()).right().containsType(ScalarSubquery.class))) {
-                Expression newChild = replace(isNull.child(), context);
-                ScalarSubquery subquery = 
collectScalarSubqueryForBinaryOperator((BinaryOperator) isNull.child());
-                context.updateSubqueryCorrespondingConjunctIsNull(subquery);
-                return 
context.getSubqueryToMarkJoinSlotValue(subquery).isPresent() ? newChild : new 
IsNull(newChild);
-            }
-
-            return visit(isNull, context);
-        }
-
         @Override
         public Expression visitBinaryOperator(BinaryOperator binaryOperator, 
SubqueryContext context) {
-            boolean atLeastOneChildContainsScalarSubquery =
-                    binaryOperator.left().containsType(ScalarSubquery.class)
-                        || 
binaryOperator.right().containsType(ScalarSubquery.class);
-            boolean currentMarkJoin = 
((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
-                                        || 
binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
-                                      && (binaryOperator instanceof Or)) || 
isMarkJoin;
-            isMarkJoin = currentMarkJoin;
-            Expression left = replace(binaryOperator.left(), context);
-            isMarkJoin = currentMarkJoin;
-            Expression right = replace(binaryOperator.right(), context);
-
-            if (atLeastOneChildContainsScalarSubquery && !(binaryOperator 
instanceof CompoundPredicate)) {
-                return context.replaceBinaryOperator(binaryOperator, left, 
right, isProject);
-            }
+            // update isMarkJoin flag
+            isMarkJoin =
+                    isMarkJoin || 
((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
+                            || 
binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
+                            && (binaryOperator instanceof Or));
+
+            Expression left = doReplace(binaryOperator.left(), context);
+            Expression right = doReplace(binaryOperator.right(), context);
 
             return binaryOperator.withChildren(left, right);
         }
@@ -399,11 +372,11 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
      * For inSubquery and exists: it will be directly replaced by 
markSlotReference
      *  e.g.
      *  logicalFilter(predicate=exists) ---> logicalFilter(predicate=$c$1)
-     * For scalarSubquery: will replace the connected ComparisonPredicate with 
markSlotReference
+     * For scalarSubquery: it will be replaced by markSlotReference too
      *  e.g.
-     *  logicalFilter(predicate=k1 > scalarSubquery) ---> 
logicalFilter(predicate=$c$1)
+     *  logicalFilter(predicate=k1 > scalarSubquery) ---> 
logicalFilter(predicate=k1 > $c$1)
      *
-     * subqueryCorrespondingConject: Record the conject corresponding to the 
subquery.
+     * subqueryCorrespondingConjunct: Record the conject corresponding to the 
subquery.
      * rule:
      *
      *
@@ -427,10 +400,6 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
             return subqueryCorrespondingConjunct;
         }
 
-        private Optional<MarkJoinSlotReference> 
getSubqueryToMarkJoinSlotValue(SubqueryExpr subqueryExpr) {
-            return subqueryToMarkJoinSlot.get(subqueryExpr);
-        }
-
         private void setSubqueryToMarkJoinSlot(SubqueryExpr subquery,
                                               Optional<MarkJoinSlotReference> 
markJoinSlotReference) {
             subqueryToMarkJoinSlot.put(subquery, markJoinSlotReference);
@@ -445,56 +414,13 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
             return subqueryToMarkJoinSlot.size() == 1;
         }
 
-        private void updateSubqueryCorrespondingConjunctInNot(SubqueryExpr 
subquery) {
-            if (subqueryCorrespondingConjunct.containsKey(subquery)) {
-                subqueryCorrespondingConjunct.replace(subquery,
-                    new Not(subqueryCorrespondingConjunct.get(subquery)));
-            }
-        }
-
-        private void updateSubqueryCorrespondingConjunctIsNull(SubqueryExpr 
subquery) {
-            if (subqueryCorrespondingConjunct.containsKey(subquery)) {
-                subqueryCorrespondingConjunct.replace(subquery,
-                        new 
IsNull(subqueryCorrespondingConjunct.get(subquery)));
-            }
+        private void updateSubqueryCorrespondingConjunct(Expression 
expression) {
+            Preconditions.checkState(onlySingleSubquery(),
+                    "onlySingleSubquery must be true");
+            subqueryCorrespondingConjunct
+                    .forEach((k, v) -> subqueryCorrespondingConjunct.put(k, 
expression));
         }
 
-        /**
-         * For scalarSubQuery and MarkJoin, it will be replaced by 
markSlotReference
-         *  e.g.
-         *  logicalFilter(predicate=k1 > scalarSub or exists)
-         *  -->
-         *  logicalFilter(predicate=$c$1 or $c$2)
-         */
-        private Expression replaceBinaryOperator(BinaryOperator binaryOperator,
-                                                Expression left,
-                                                Expression right,
-                                                boolean isProject) {
-            boolean leftContaionsScalar = 
binaryOperator.left().containsType(ScalarSubquery.class);
-            ScalarSubquery subquery = 
collectScalarSubqueryForBinaryOperator(binaryOperator);
-
-            // record the result in subqueryCorrespondingConjunct
-            Expression newLeft = leftContaionsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
-                    ? subqueryCorrespondingConjunct.get(subquery) : left;
-            Expression newRight = !leftContaionsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
-                    ? subqueryCorrespondingConjunct.get(subquery) : right;
-            Expression newBinary = binaryOperator.withChildren(newLeft, 
newRight);
-            subqueryCorrespondingConjunct.put(subquery,
-                    (isProject ? (leftContaionsScalar ? newLeft : newRight) : 
newBinary));
-
-            if (subqueryToMarkJoinSlot.get(subquery).isPresent() && 
binaryOperator instanceof ComparisonPredicate) {
-                return subqueryToMarkJoinSlot.get(subquery).get();
-            }
-            return newBinary;
-        }
     }
 
-    private static ScalarSubquery 
collectScalarSubqueryForBinaryOperator(BinaryOperator binaryOperator) {
-        boolean leftContaionsScalar = 
binaryOperator.left().containsType(ScalarSubquery.class);
-        return leftContaionsScalar
-                ? (ScalarSubquery) ((ImmutableSet) binaryOperator.left()
-                .collect(ScalarSubquery.class::isInstance)).asList().get(0)
-                : (ScalarSubquery) ((ImmutableSet) binaryOperator.right()
-                .collect(ScalarSubquery.class::isInstance)).asList().get(0);
-    }
 }
diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out 
b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
index d6eae9c01a..492b4c9545 100644
--- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
+++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
@@ -161,6 +161,23 @@
 22     3
 24     4
 
+-- !scalar_subquery1 --
+
+-- !scalar_subquery2 --
+
+-- !in_subquery --
+1      abc     2       3       4
+1      abcd    3       3       4
+
+-- !exist_subquery --
+2      uvw     3       4       2
+2      uvw     3       4       2
+2      xyz     2       4       2
+
+-- !in_subquery --
+
+-- !exist_subquery --
+
 -- !scalar_subquery_with_order --
 20     2
 22     3
diff --git 
a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy 
b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
index 665ad67e48..d1fe2d7a88 100644
--- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
+++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
@@ -205,13 +205,24 @@ suite ("sub_query_correlated") {
     """
 
     //----------complex subqueries----------
-    //----------remove temporarily---------
-    /*qt_scalar_subquery """
+    qt_scalar_subquery1 """
         select * from sub_query_correlated_subquery1
             where k1 = (select sum(k1) from sub_query_correlated_subquery3 
where sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.v1 and 
sub_query_correlated_subquery3.v2 = 2)
             order by k1, k2
     """
 
+    qt_scalar_subquery2 """
+        SELECT *
+        FROM sub_query_correlated_subquery1 t1
+        WHERE coalesce(bitand( 
+        cast(
+            (SELECT sum(k1)
+            FROM sub_query_correlated_subquery3 ) AS int), 
+            cast(t1.k1 AS int)), 
+            coalesce(t1.k1, t1.k2)) is NULL
+        ORDER BY  t1.k1, t1.k2;
+    """
+
     qt_in_subquery """
         select * from sub_query_correlated_subquery3
             where (k1 = 1 or k1 = 2 or k1 = 3) and v1 in (select k1 from 
sub_query_correlated_subquery1 where sub_query_correlated_subquery1.k2 = 
sub_query_correlated_subquery3.v2 and sub_query_correlated_subquery1.k1 = 3)
@@ -222,12 +233,10 @@ suite ("sub_query_correlated") {
         select * from sub_query_correlated_subquery3
             where k1 = 2 and exists (select * from 
sub_query_correlated_subquery1 where sub_query_correlated_subquery1.k1 = 
sub_query_correlated_subquery3.v2 and sub_query_correlated_subquery1.k2 = 4)
             order by k1, k2
-    """*/
+    """
 
     //----------complex nonEqual subqueries----------
-
-    //----------remove temporarily---------
-    /*qt_in_subquery """
+    qt_in_subquery """
         select * from sub_query_correlated_subquery3
             where (k1 = 1 or k1 = 2 or k1 = 3) and v1 in (select k1 from 
sub_query_correlated_subquery1 where sub_query_correlated_subquery1.k2 > 
sub_query_correlated_subquery3.v2 and sub_query_correlated_subquery1.k1 = 3)
             order by k1, k2
@@ -237,7 +246,7 @@ suite ("sub_query_correlated") {
         select * from sub_query_correlated_subquery3
             where k1 = 2 and exists (select * from 
sub_query_correlated_subquery1 where sub_query_correlated_subquery1.k1 < 
sub_query_correlated_subquery3.v2 and sub_query_correlated_subquery1.k2 = 4)
             order by k1, k2
-    """*/
+    """
 
     //----------subquery with order----------
     order_qt_scalar_subquery_with_order """


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to