This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 82716ec99d [fix](Nereids) type coercion for subquery (#17661)
82716ec99d is described below

commit 82716ec99d28d4174a9bef9158b7ed7abb5ead6b
Author: zhengshiJ <[email protected]>
AuthorDate: Tue Mar 21 20:38:06 2023 +0800

    [fix](Nereids) type coercion for subquery (#17661)
    
    Complete type coercion for subqueries during the FunctionBinder phase.
    
    Expressions generated for nested subqueries are now uniformly given implicit casts in the analyze stage.
    Approach: add a typeCoercionExpr field to SubqueryExpr to store the generated cast expression.
    
    Fix implicit type conversion for arithmetic expressions that contain a ScalarSubquery.
---
 .../nereids/rules/analysis/BindExpression.java     |  3 +-
 .../nereids/rules/analysis/FunctionBinder.java     | 19 ++++++
 .../nereids/rules/analysis/SubqueryToApply.java    | 64 ++++++++++----------
 .../rules/rewrite/logical/InApplyToJoin.java       | 15 +----
 .../doris/nereids/trees/expressions/Exists.java    | 16 ++++-
 .../nereids/trees/expressions/InSubquery.java      | 28 ++++++++-
 .../doris/nereids/trees/expressions/ListQuery.java | 14 +++++
 .../nereids/trees/expressions/ScalarSubquery.java  | 18 +++++-
 .../nereids/trees/expressions/SubqueryExpr.java    | 29 +++++++--
 .../doris/nereids/util/TypeCoercionUtils.java      | 13 ++++-
 .../doris/nereids/trees/plans/MarkJoinTest.java    | 68 +++++++++++++++++-----
 .../nereids_syntax_p0/sub_query_correlated.out     | 19 ++++++
 .../nereids_syntax_p0/sub_query_correlated.groovy  |  9 +++
 13 files changed, 244 insertions(+), 71 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 701e18c9d1..6e74cc1dbf 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -77,6 +77,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -142,7 +143,7 @@ public class BindExpression implements AnalysisRuleFactory {
                     Set<Expression> boundConjuncts = 
filter.getConjuncts().stream()
                             .map(expr -> bindSlot(expr, filter.children(), 
ctx.cascadesContext))
                             .map(expr -> bindFunction(expr, 
ctx.cascadesContext))
-                            .collect(Collectors.toSet());
+                            
.collect(Collectors.toCollection(LinkedHashSet::new));
                     return new LogicalFilter<>(boundConjuncts, filter.child());
                 })
             ),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
index 8fe09fb510..031d8c74d7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FunctionBinder.java
@@ -31,9 +31,12 @@ import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.Divide;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.InSubquery;
 import org.apache.doris.nereids.trees.expressions.IntegralDivide;
+import org.apache.doris.nereids.trees.expressions.ListQuery;
 import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
@@ -212,4 +215,20 @@ public class FunctionBinder extends 
DefaultExpressionRewriter<CascadesContext> {
         Between newBetween = between.withChildren(rewrittenChildren);
         return TypeCoercionUtils.processBetween(newBetween);
     }
+
+    @Override
+    public Expression visitInSubquery(InSubquery inSubquery, CascadesContext 
context) {
+        Expression newCompareExpr = inSubquery.getCompareExpr().accept(this, 
context);
+        Expression newListQuery = inSubquery.getListQuery().accept(this, 
context);
+        ComparisonPredicate newCpAfterUnNestingSubquery =
+                new EqualTo(newCompareExpr, ((ListQuery) 
newListQuery).getQueryPlan().getOutput().get(0));
+        ComparisonPredicate afterTypeCoercion = (ComparisonPredicate) 
TypeCoercionUtils.processComparisonPredicate(
+                newCpAfterUnNestingSubquery, newCompareExpr, newListQuery);
+        if (!newCompareExpr.getDataType().isBigIntType() && 
newListQuery.getDataType().isBitmapType()) {
+            newCompareExpr = new Cast(newCompareExpr, BigIntType.INSTANCE);
+        }
+        return new InSubquery(newCompareExpr, (ListQuery) 
afterTypeCoercion.right(),
+            inSubquery.getCorrelateSlots(), ((ListQuery) 
afterTypeCoercion.right()).getTypeCoercionExpr(),
+            inSubquery.isNot());
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
index 7ca11befe9..a0b4bfc652 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.BinaryOperator;
 import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.Exists;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InSubquery;
@@ -42,7 +43,6 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalApply;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -290,11 +290,11 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
 
         @Override
         public Expression visitScalarSubquery(ScalarSubquery scalar, 
SubqueryContext context) {
-            context.setSubqueryCorrespondingConject(scalar, 
scalar.getQueryPlan().getOutput().get(0));
+            context.setSubqueryCorrespondingConject(scalar, 
scalar.getSubqueryOutput());
             // When there is only one scalarSubQuery and CorrelateSlots is 
empty
             // it will not be processed by MarkJoin, so it can be returned 
directly
             if (context.onlySingleSubquery() && 
scalar.getCorrelateSlots().isEmpty()) {
-                return scalar.getQueryPlan().getOutput().get(0);
+                return scalar.getSubqueryOutput();
             }
 
             MarkJoinSlotReference markJoinSlotReference =
@@ -302,19 +302,17 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
             if (isMarkJoin) {
                 context.setSubqueryToMarkJoinSlot(scalar, 
Optional.of(markJoinSlotReference));
             }
-            return isMarkJoin ? markJoinSlotReference : 
scalar.getQueryPlan().getOutput().get(0);
+            return isMarkJoin ? markJoinSlotReference : 
scalar.getSubqueryOutput();
         }
 
         @Override
         public Expression visitNot(Not not, SubqueryContext context) {
             // Need to re-update scalarSubQuery unequal conditions into 
subqueryCorrespondingConject
             if (not.child() instanceof BinaryOperator
-                    && (((BinaryOperator) not.child()).left() instanceof 
ScalarSubquery
-                    || ((BinaryOperator) not.child()).right() instanceof 
ScalarSubquery)) {
+                    && (((BinaryOperator) 
not.child()).left().containsType(ScalarSubquery.class)
+                    || ((BinaryOperator) 
not.child()).right().containsType(ScalarSubquery.class))) {
                 Expression newChild = replace(not.child(), context);
-                ScalarSubquery subquery = ((BinaryOperator) 
not.child()).left() instanceof ScalarSubquery
-                        ? (ScalarSubquery) ((BinaryOperator) 
not.child()).left()
-                        : (ScalarSubquery) ((BinaryOperator) 
not.child()).right();
+                ScalarSubquery subquery = 
collectScalarSubqueryForBinaryOperator((BinaryOperator) not.child());
                 context.updateSubqueryCorrespondingConjunctInNot(subquery);
                 return 
context.getSubqueryToMarkJoinSlotValue(subquery).isPresent() ? newChild : new 
Not(newChild);
             }
@@ -324,8 +322,9 @@ public class SubqueryToApply implements AnalysisRuleFactory 
{
 
         @Override
         public Expression visitBinaryOperator(BinaryOperator binaryOperator, 
SubqueryContext context) {
-            boolean atLeastOneChildIsScalarSubquery =
-                    binaryOperator.left() instanceof ScalarSubquery || 
binaryOperator.right() instanceof ScalarSubquery;
+            boolean atLeastOneChildContainsScalarSubquery =
+                    binaryOperator.left().containsType(ScalarSubquery.class)
+                        || 
binaryOperator.right().containsType(ScalarSubquery.class);
             boolean currentMarkJoin = 
((binaryOperator.left().anyMatch(SubqueryExpr.class::isInstance)
                                         || 
binaryOperator.right().anyMatch(SubqueryExpr.class::isInstance))
                                       && (binaryOperator instanceof Or)) || 
isMarkJoin;
@@ -334,9 +333,10 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
             isMarkJoin = currentMarkJoin;
             Expression right = replace(binaryOperator.right(), context);
 
-            if (atLeastOneChildIsScalarSubquery) {
+            if (atLeastOneChildContainsScalarSubquery && !(binaryOperator 
instanceof CompoundPredicate)) {
                 return context.replaceBinaryOperator(binaryOperator, left, 
right, isProject);
             }
+
             return binaryOperator.withChildren(left, right);
         }
     }
@@ -406,36 +406,36 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
          *  logicalFilter(predicate=k1 > scalarSub or exists)
          *  -->
          *  logicalFilter(predicate=$c$1 or $c$2)
-         *
-         * For non-MarkJoin scalarSubQuery, do implicit type conversion.
-         *  e.g.
-         *  logicalFilter(predicate=k1 > scalarSub(sum(k2)))
-         *  -->
-         *  logicalFilter(predicate=Cast(k1[#0] as BIGINT) = sum(k2)[#1])
          */
         public Expression replaceBinaryOperator(BinaryOperator binaryOperator,
                                                 Expression left,
                                                 Expression right,
                                                 boolean isProject) {
-            boolean leftIsScalar = binaryOperator.left() instanceof 
ScalarSubquery;
-            ScalarSubquery subquery = leftIsScalar
-                    ? (ScalarSubquery) binaryOperator.left() : 
(ScalarSubquery) binaryOperator.right();
-
-            // Perform implicit type conversion on the connection condition of 
scalarSubQuery,
-            // and record the result in subqueryCorrespondingConjunct
-            Expression newLeft = leftIsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
-                    ? ((ScalarSubquery) 
binaryOperator.left()).getQueryPlan().getOutput().get(0) : left;
-            Expression newRight = !leftIsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
-                    ? ((ScalarSubquery) 
binaryOperator.right()).getQueryPlan().getOutput().get(0) : right;
-            Expression newBinary = 
TypeCoercionUtils.processComparisonPredicate(
-                    (ComparisonPredicate) binaryOperator.withChildren(newLeft, 
newRight), newLeft, newRight);
+            boolean leftContaionsScalar = 
binaryOperator.left().containsType(ScalarSubquery.class);
+            ScalarSubquery subquery = 
collectScalarSubqueryForBinaryOperator(binaryOperator);
+
+            // record the result in subqueryCorrespondingConjunct
+            Expression newLeft = leftContaionsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
+                    ? subqueryCorrespondingConjunct.get(subquery) : left;
+            Expression newRight = !leftContaionsScalar && 
subqueryToMarkJoinSlot.get(subquery).isPresent()
+                    ? subqueryCorrespondingConjunct.get(subquery) : right;
+            Expression newBinary = binaryOperator.withChildren(newLeft, 
newRight);
             subqueryCorrespondingConjunct.put(subquery,
-                    (isProject ? (leftIsScalar ? newLeft : newRight) : 
newBinary));
+                    (isProject ? (leftContaionsScalar ? newLeft : newRight) : 
newBinary));
 
-            if (subqueryToMarkJoinSlot.get(subquery).isPresent()) {
+            if (subqueryToMarkJoinSlot.get(subquery).isPresent() && 
binaryOperator instanceof ComparisonPredicate) {
                 return subqueryToMarkJoinSlot.get(subquery).get();
             }
             return newBinary;
         }
     }
+
+    private static ScalarSubquery 
collectScalarSubqueryForBinaryOperator(BinaryOperator binaryOperator) {
+        boolean leftContaionsScalar = 
binaryOperator.left().containsType(ScalarSubquery.class);
+        return leftContaionsScalar
+                ? (ScalarSubquery) ((ImmutableSet) binaryOperator.left()
+                .collect(ScalarSubquery.class::isInstance)).asList().get(0)
+                : (ScalarSubquery) ((ImmutableSet) binaryOperator.right()
+                .collect(ScalarSubquery.class::isInstance)).asList().get(0);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index ce95589a7a..8325ad18fb 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InSubquery;
@@ -36,9 +35,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.util.ExpressionUtils;
-import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
@@ -84,10 +81,6 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
 
                 LogicalAggregate agg = new LogicalAggregate(groupExpressions, 
outputExpressions, apply.right());
                 Expression compareExpr = ((InSubquery) 
apply.getSubqueryExpr()).getCompareExpr();
-                if (!compareExpr.getDataType().isBigIntType()) {
-                    //this rule is after type coercion, we need to add cast by 
hand
-                    compareExpr = new Cast(compareExpr, BigIntType.INSTANCE);
-                }
                 Expression expr = new BitmapContains(agg.getOutput().get(0), 
compareExpr);
                 if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
                     expr = new Not(expr);
@@ -101,14 +94,12 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
             //in-predicate to equal
             Expression predicate;
             Expression left = ((InSubquery) 
apply.getSubqueryExpr()).getCompareExpr();
-            Expression right = apply.right().getOutput().get(0);
+            Expression right = apply.getSubqueryExpr().getSubqueryOutput();
             if (apply.isCorrelated()) {
-                predicate = ExpressionUtils.and(
-                        TypeCoercionUtils.processComparisonPredicate(
-                            new EqualTo(left, right), left, right),
+                predicate = ExpressionUtils.and(new EqualTo(left, right),
                         apply.getCorrelationFilter().get());
             } else {
-                predicate = TypeCoercionUtils.processComparisonPredicate(new 
EqualTo(left, right), left, right);
+                predicate = new EqualTo(left, right);
             }
 
             if (apply.getSubCorrespondingConject().isPresent()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
index 7f2628e03f..28762addd7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Exists.java
@@ -28,6 +28,7 @@ import com.google.common.base.Preconditions;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * Exists subquery expression.
@@ -41,8 +42,16 @@ public class Exists extends SubqueryExpr implements 
LeafExpression {
     }
 
     public Exists(LogicalPlan subquery, List<Slot> correlateSlots, boolean 
isNot) {
+        this(Objects.requireNonNull(subquery, "subquery can not be null"),
+                Objects.requireNonNull(correlateSlots, "subquery can not be 
null"),
+                Optional.empty(), isNot);
+    }
+
+    public Exists(LogicalPlan subquery, List<Slot> correlateSlots,
+                  Optional<Expression> typeCoercionExpr, boolean isNot) {
         super(Objects.requireNonNull(subquery, "subquery can not be null"),
-                Objects.requireNonNull(correlateSlots, "subquery can not be 
null"));
+                Objects.requireNonNull(correlateSlots, "subquery can not be 
null"),
+                typeCoercionExpr);
         this.isNot = Objects.requireNonNull(isNot, "isNot can not be null");
     }
 
@@ -88,4 +97,9 @@ public class Exists extends SubqueryExpr implements 
LeafExpression {
     public int hashCode() {
         return Objects.hash(this.queryPlan, this.isNot);
     }
+
+    @Override
+    public Expression withTypeCoercion(DataType dataType) {
+        return this;
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
index fe1dc5428f..e6487e0c81 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InSubquery.java
@@ -25,6 +25,7 @@ import com.google.common.base.Preconditions;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * In predicate expression.
@@ -43,8 +44,20 @@ public class InSubquery extends SubqueryExpr {
     }
 
     public InSubquery(Expression compareExpr, ListQuery listQuery, List<Slot> 
correlateSlots, boolean isNot) {
+        this(compareExpr, listQuery, correlateSlots, Optional.empty(), isNot);
+    }
+
+    /**
+     * InSubquery Constructor.
+     */
+    public InSubquery(Expression compareExpr,
+                      ListQuery listQuery,
+                      List<Slot> correlateSlots,
+                      Optional<Expression> typeCoercionExpr,
+                      boolean isNot) {
         super(Objects.requireNonNull(listQuery.getQueryPlan(), "subquery can 
not be null"),
-                Objects.requireNonNull(correlateSlots, "correlateSlots can not 
be null"));
+                Objects.requireNonNull(correlateSlots, "correlateSlots can not 
be null"),
+                typeCoercionExpr);
         this.compareExpr = Objects.requireNonNull(compareExpr, "compareExpr 
can not be null");
         this.listQuery = Objects.requireNonNull(listQuery, "listQuery can not 
be null");
         this.isNot = Objects.requireNonNull(isNot, "isNot can not be null");
@@ -99,7 +112,9 @@ public class InSubquery extends SubqueryExpr {
             return false;
         }
         InSubquery inSubquery = (InSubquery) o;
-        return Objects.equals(this.compareExpr, inSubquery.getCompareExpr())
+        return super.equals(inSubquery)
+                && Objects.equals(this.compareExpr, 
inSubquery.getCompareExpr())
+                && Objects.equals(this.listQuery, inSubquery.listQuery)
                 && this.isNot == inSubquery.isNot;
     }
 
@@ -107,4 +122,13 @@ public class InSubquery extends SubqueryExpr {
     public int hashCode() {
         return Objects.hash(this.compareExpr, this.listQuery, this.isNot);
     }
+
+    @Override
+    public Expression withTypeCoercion(DataType dataType) {
+        return new InSubquery(compareExpr, listQuery, correlateSlots,
+            dataType == listQuery.queryPlan.getOutput().get(0).getDataType()
+                ? Optional.of(listQuery.queryPlan.getOutput().get(0))
+                : Optional.of(new Cast(listQuery.queryPlan.getOutput().get(0), 
dataType)),
+            isNot);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
index 961a46ff68..bccc090016 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ListQuery.java
@@ -24,7 +24,9 @@ import org.apache.doris.nereids.types.DataType;
 
 import com.google.common.base.Preconditions;
 
+import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * Encapsulate LogicalPlan as Expression.
@@ -35,6 +37,10 @@ public class ListQuery extends SubqueryExpr implements 
LeafExpression {
         super(Objects.requireNonNull(subquery, "subquery can not be null"));
     }
 
+    public ListQuery(LogicalPlan subquery, List<Slot> correlateSlots, 
Optional<Expression> typeCoercionExpr) {
+        super(subquery, correlateSlots, typeCoercionExpr);
+    }
+
     @Override
     public DataType getDataType() {
         Preconditions.checkArgument(queryPlan.getOutput().size() == 1);
@@ -54,4 +60,12 @@ public class ListQuery extends SubqueryExpr implements 
LeafExpression {
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitListQuery(this, context);
     }
+
+    @Override
+    public Expression withTypeCoercion(DataType dataType) {
+        return new ListQuery(queryPlan, correlateSlots,
+                dataType == queryPlan.getOutput().get(0).getDataType()
+                    ? Optional.of(queryPlan.getOutput().get(0))
+                    : Optional.of(new Cast(queryPlan.getOutput().get(0), 
dataType)));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
index e49e514511..a17cb0701f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
@@ -27,6 +27,7 @@ import com.google.common.base.Preconditions;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * A subquery that will return only one row and one column.
@@ -37,8 +38,15 @@ public class ScalarSubquery extends SubqueryExpr implements 
LeafExpression {
     }
 
     public ScalarSubquery(LogicalPlan subquery, List<Slot> correlateSlots) {
+        this(Objects.requireNonNull(subquery, "subquery can not be null"),
+                Objects.requireNonNull(correlateSlots, "correlateSlots can not 
be null"),
+                Optional.empty());
+    }
+
+    public ScalarSubquery(LogicalPlan subquery, List<Slot> correlateSlots, 
Optional<Expression> typeCoercionExpr) {
         super(Objects.requireNonNull(subquery, "subquery can not be null"),
-                Objects.requireNonNull(correlateSlots, "correlateSlots can not 
be null"));
+                Objects.requireNonNull(correlateSlots, "correlateSlots can not 
be null"),
+                typeCoercionExpr);
     }
 
     @Override
@@ -60,4 +68,12 @@ public class ScalarSubquery extends SubqueryExpr implements 
LeafExpression {
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitScalarSubquery(this, context);
     }
+
+    @Override
+    public Expression withTypeCoercion(DataType dataType) {
+        return new ScalarSubquery(queryPlan, correlateSlots,
+                dataType == queryPlan.getOutput().get(0).getDataType()
+                    ? Optional.of(queryPlan.getOutput().get(0))
+                    : Optional.of(new Cast(queryPlan.getOutput().get(0), 
dataType)));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
index d98a72f06a..759634623b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SubqueryExpr.java
@@ -21,11 +21,13 @@ import org.apache.doris.nereids.exceptions.UnboundException;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.collect.ImmutableList;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * Subquery Expression.
@@ -34,20 +36,32 @@ public abstract class SubqueryExpr extends Expression {
     protected final LogicalPlan queryPlan;
     protected final List<Slot> correlateSlots;
 
+    protected final Optional<Expression> typeCoercionExpr;
+
     public SubqueryExpr(LogicalPlan subquery) {
         this.queryPlan = Objects.requireNonNull(subquery, "subquery can not be 
null");
         this.correlateSlots = ImmutableList.of();
+        this.typeCoercionExpr = Optional.empty();
     }
 
-    public SubqueryExpr(LogicalPlan subquery, List<Slot> correlateSlots) {
+    public SubqueryExpr(LogicalPlan subquery, List<Slot> correlateSlots, 
Optional<Expression> typeCoercionExpr) {
         this.queryPlan = Objects.requireNonNull(subquery, "subquery can not be 
null");
         this.correlateSlots = ImmutableList.copyOf(correlateSlots);
+        this.typeCoercionExpr = typeCoercionExpr;
     }
 
     public List<Slot> getCorrelateSlots() {
         return correlateSlots;
     }
 
+    public Optional<Expression> getTypeCoercionExpr() {
+        return typeCoercionExpr;
+    }
+
+    public Expression getSubqueryOutput() {
+        return typeCoercionExpr.orElseGet(() -> queryPlan.getOutput().get(0));
+    }
+
     @Override
     public DataType getDataType() throws UnboundException {
         throw new UnboundException("getDataType");
@@ -65,8 +79,10 @@ public abstract class SubqueryExpr extends Expression {
 
     @Override
     public String toString() {
-        return "(QueryPlan: " + queryPlan
-                + "), (CorrelatedSlots: " + correlateSlots + ")";
+        return Utils.toSqlString("SubqueryExpr",
+                "QueryPlan", queryPlan,
+                "CorrelatedSlots", correlateSlots,
+                "typeCoercionExpr", typeCoercionExpr.isPresent() ? 
typeCoercionExpr.get() : "null");
     }
 
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
@@ -92,15 +108,18 @@ public abstract class SubqueryExpr extends Expression {
         }
         SubqueryExpr other = (SubqueryExpr) o;
         return Objects.equals(correlateSlots, other.correlateSlots)
-                && queryPlan.deepEquals(other.queryPlan);
+                && queryPlan.deepEquals(other.queryPlan)
+                && Objects.equals(typeCoercionExpr, other.typeCoercionExpr);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(queryPlan, correlateSlots);
+        return Objects.hash(queryPlan, correlateSlots, typeCoercionExpr);
     }
 
     public List<Slot> getOutput() {
         return queryPlan.getOutput();
     }
+
+    public abstract Expression withTypeCoercion(DataType dataType);
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
index 720fc36904..ea443bc610 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
@@ -227,7 +227,7 @@ public class TypeCoercionUtils {
      * cast input type if input's datatype is not same with dateType.
      */
     public static Expression castIfNotSameType(Expression input, DataType 
targetType) {
-        if (input.getDataType().equals(targetType)) {
+        if (input.getDataType().equals(targetType) || 
isSubqueryAndDataTypeIsBitmap(input)) {
             return input;
         } else {
             checkCanCastTo(input.getDataType(), targetType);
@@ -235,6 +235,10 @@ public class TypeCoercionUtils {
         }
     }
 
+    private static boolean isSubqueryAndDataTypeIsBitmap(Expression input) {
+        return input instanceof SubqueryExpr && 
input.getDataType().isBitmapType();
+    }
+
     private static boolean canCastTo(DataType input, DataType target) {
         return Type.canCastTo(input.toCatalogDataType(), 
target.toCatalogDataType());
     }
@@ -263,6 +267,13 @@ public class TypeCoercionUtils {
                 }
             }
         }
+        return recordTypeCoercionForSubQuery(input, dataType);
+    }
+
+    private static Expression recordTypeCoercionForSubQuery(Expression input, 
DataType dataType) {
+        if (input instanceof SubqueryExpr) {
+            return ((SubqueryExpr) input).withTypeCoercion(dataType);
+        }
         return new Cast(input, dataType);
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
index b3d34a3f59..9251c28564 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/MarkJoinTest.java
@@ -29,13 +29,13 @@ public class MarkJoinTest extends TestWithFeService {
         useDatabase("test");
 
         createTable("CREATE TABLE `test_sq_dj1` (\n"
-                + " `c1` bigint(20) NULL,\n"
+                + " `c1` varchar(20) NULL,\n"
                 + " `c2` bigint(20) NULL,\n"
-                + " `c3` bigint(20) not NULL,\n"
-                + " `k4` bigint(20) not NULL,\n"
-                + " `k5` bigint(20) NULL\n"
+                + " `c3` int(20) not NULL,\n"
+                + " `k4` bitmap BITMAP_UNION NULL,\n"
+                + " `k5` bitmap BITMAP_UNION NULL\n"
                 + ") ENGINE=OLAP\n"
-                + "DUPLICATE KEY(`c1`)\n"
+                + "AGGREGATE KEY(`c1`, `c2`, `c3`)\n"
                 + "COMMENT 'OLAP'\n"
                 + "DISTRIBUTED BY HASH(`c2`) BUCKETS 1\n"
                 + "PROPERTIES (\n"
@@ -49,10 +49,10 @@ public class MarkJoinTest extends TestWithFeService {
                 + " `c1` bigint(20) NULL,\n"
                 + " `c2` bigint(20) NULL,\n"
                 + " `c3` bigint(20) not NULL,\n"
-                + " `k4` bigint(20) not NULL,\n"
-                + " `k5` bigint(20) NULL\n"
+                + " `k4` bitmap BITMAP_UNION NULL,\n"
+                + " `k5` bitmap BITMAP_UNION NULL\n"
                 + ") ENGINE=OLAP\n"
-                + "DUPLICATE KEY(`c1`)\n"
+                + "AGGREGATE KEY(`c1`, `c2`, `c3`)\n"
                 + "COMMENT 'OLAP'\n"
                 + "DISTRIBUTED BY HASH(`c2`) BUCKETS 1\n"
                 + "PROPERTIES (\n"
@@ -178,17 +178,17 @@ public class MarkJoinTest extends TestWithFeService {
                 .checkPlannerResult("SELECT CASE\n"
                     + "            WHEN (\n"
                     + "                SELECT COUNT(*) / 2\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            ) > c1 THEN (\n"
                     + "                SELECT AVG(c1)\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            )\n"
                     + "            ELSE (\n"
                     + "                SELECT SUM(c2)\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            )\n"
                     + "            END AS kk4\n"
-                    + "        FROM test_sq_dj1 ;");
+                    + "        FROM test_sq_dj2 ;");
     }
 
     @Test
@@ -197,17 +197,17 @@ public class MarkJoinTest extends TestWithFeService {
                 .checkPlannerResult("SELECT CASE\n"
                     + "            WHEN  exists (\n"
                     + "                SELECT COUNT(*) / 2\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            ) THEN (\n"
                     + "                SELECT AVG(c1)\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            )\n"
                     + "            ELSE (\n"
                     + "                SELECT SUM(c2)\n"
-                    + "                FROM test_sq_dj1\n"
+                    + "                FROM test_sq_dj2\n"
                     + "            )\n"
                     + "            END AS kk4\n"
-                    + "        FROM test_sq_dj1 ;");
+                    + "        FROM test_sq_dj2 ;");
     }
 
     @Test
@@ -246,4 +246,40 @@ public class MarkJoinTest extends TestWithFeService {
                     + " OR c1 < (SELECT sum(c1) FROM test_sq_dj2 WHERE 
test_sq_dj1.c1 = test_sq_dj2.c1)"
                     + " AND exists (SELECT c1 FROM test_sq_dj2 WHERE 
test_sq_dj1.c1 = test_sq_dj2.c1)");
     }
+
+    @Test
+    public void test20() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 < 
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE 
test_sq_dj1.c1 = test_sq_dj2.c1))");
+    }
+
+    @Test
+    public void test21() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 < 
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE 
test_sq_dj1.c1 = test_sq_dj2.c1)) or c1 > 10");
+    }
+
+    @Test
+    public void test22() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c1 != 
(cast('1.2' as decimal(2,1)) * (SELECT sum(c1) FROM test_sq_dj2 WHERE 
test_sq_dj1.c1 = test_sq_dj2.c1)) or c1 > 10");
+    }
+
+    @Test
+    public void test23() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c2 in 
(select k4 from test_sq_dj2)");
+    }
+
+    @Test
+    public void test24() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("SELECT * FROM test_sq_dj1 WHERE c3 in 
(select k4 from test_sq_dj2)");
+    }
+
+    @Test
+    public void test25() {
+        PlanChecker.from(connectContext)
+                .checkPlannerResult("select * from test_sq_dj1 where c1 in 
(select c1 from test_sq_dj1 where c2 in (select c2 from test_sq_dj2) and c2 > 
(select sum(c1) from test_sq_dj2))");
+    }
 }
diff --git a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out 
b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
index b7e57d2613..1073b29d1f 100644
--- a/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
+++ b/regression-test/data/nereids_syntax_p0/sub_query_correlated.out
@@ -355,3 +355,22 @@
 
 -- !multi_subquery_scalar_and_in_or_scalar_and_exists --
 
+-- !cast_subquery_in --
+1      2
+1      3
+2      4
+2      5
+3      3
+3      4
+
+-- !cast_subquery_in_with_disconjunct --
+1      2
+1      3
+2      4
+2      5
+3      3
+3      4
+20     2
+22     3
+24     4
+
diff --git 
a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy 
b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
index 1d153405ea..b816f82366 100644
--- a/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
+++ b/regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
@@ -383,4 +383,13 @@ suite ("sub_query_correlated") {
                                              OR k1 < (SELECT sum(k1) FROM 
sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 = 
sub_query_correlated_subquery3.k1))
                                         and exists (SELECT k1 FROM 
sub_query_correlated_subquery3 WHERE sub_query_correlated_subquery1.k1 = 
sub_query_correlated_subquery3.k1);
     """
+    
+    //----------type coercion subquery-----------
+    qt_cast_subquery_in """
+        SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as 
decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE 
sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) order 
by k1, k2;
+    """
+
+    qt_cast_subquery_in_with_disconjunct """
+        SELECT * FROM sub_query_correlated_subquery1 WHERE k1 < (cast('1.2' as 
decimal(2,1)) * (SELECT sum(k1) FROM sub_query_correlated_subquery3 WHERE 
sub_query_correlated_subquery1.k1 = sub_query_correlated_subquery3.k1)) or k1 > 
10 order by k1, k2;
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to