This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 58f2593ba1 [Fix](Nereids) Add cast comparison with slot reference when
inferring predicate (#21171)
58f2593ba1 is described below
commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c
Author: LiBinfeng <[email protected]>
AuthorDate: Wed Jul 19 23:14:26 2023 +0800
[Fix](Nereids) Add cast comparison with slot reference when inferring
predicate (#21171)
Problem:
When inferring predicates, we assume that only slot references need to be
inferred. But in this case:
create table tb1(l1 smallint) ...;
create table tb2(l2 int) ...;
select * from tb1 inner join tb2 where tb1.l1 = tb2.l2 and tb2.l2 = 1;
We cannot infer the filter tb1.l1 = 1, because a cast is added to l1, making
the join condition cast(l1 as int) = l2.
Solved:
Take casts into account when inferring predicates, and treat a slot reference
and a cast on that slot as equal when comparing expressions. However, inferring
a predicate through a cast from a wider type to a narrower type is a logical error.
For example:
select * from tb1 inner join tb2 where tb1.l1 = cast(tb2.l2 as smallint)
and tb2.l2 = (number between smallint max and intmax);
The tb2.l2 value cannot be propagated to the left side, because tb1.l1 would
get a wrong value; if we then add one more condition such as
tb1.l1 = tb3.l3 (smallint), it would cause that predicate to be false.
---
.../rules/rewrite/PredicatePropagation.java | 40 ++++++++++++++++----
.../doris/nereids/types/coercion/IntegralType.java | 4 ++
.../apache/doris/nereids/util/ExpressionUtils.java | 29 +++++++++++++++
.../infer_predicate/infer_predicate.groovy | 43 ++++++++++++++++++++++
4 files changed, 108 insertions(+), 8 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index 71ceee713a..9602bb4a56 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -17,11 +17,15 @@
package org.apache.doris.nereids.rules.rewrite;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.coercion.IntegralType;
+import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Sets;
@@ -70,19 +74,38 @@ public class PredicatePropagation {
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp,
Void context) {
- if (cp.left().isSlot() && cp.right().isConstant()) {
- return replaceSlot(cp);
- } else if (cp.left().isConstant() && cp.right().isSlot()) {
- return replaceSlot(cp);
+ // we need to get expression covered by cast, because we want
to infer different datatype
+ if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left())
&& (cp.right().isConstant())) {
+ return replaceSlot(cp,
ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
+ } else if
(ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) &&
cp.left().isConstant()) {
+ return replaceSlot(cp,
ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
}
return super.visit(cp, context);
}
- private Expression replaceSlot(Expression expr) {
+ private boolean isOriginDataTypeBigger(DataType originDataType,
Expression expr) {
+ if ((leftSlotEqualToRightSlot.child(0).getDataType()
instanceof IntegralType)
+ && (leftSlotEqualToRightSlot.child(1).getDataType()
instanceof IntegralType)
+ && (originDataType instanceof IntegralType)) {
+ // infer filter can not be lower than original datatype,
or dataset would be wrong
+ if (((IntegralType) originDataType).widerThan(
+ (IntegralType)
leftSlotEqualToRightSlot.child(0).getDataType())
+ || ((IntegralType)
originDataType).widerThan(
+ (IntegralType)
leftSlotEqualToRightSlot.child(1).getDataType())) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private Expression replaceSlot(Expression expr, DataType
originDataType) {
return expr.rewriteUp(e -> {
- if (e.equals(leftSlotEqualToRightSlot.child(0))) {
+ if (isOriginDataTypeBigger(originDataType,
leftSlotEqualToRightSlot)) {
+ return e;
+ }
+ if (ExpressionUtils.isTwoExpressionEqualWithCast(e,
leftSlotEqualToRightSlot.child(0))) {
return leftSlotEqualToRightSlot.child(1);
- } else if (e.equals(leftSlotEqualToRightSlot.child(1))) {
+ } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e,
leftSlotEqualToRightSlot.child(1))) {
return leftSlotEqualToRightSlot.child(0);
} else {
return e;
@@ -98,7 +121,8 @@ public class PredicatePropagation {
*/
private boolean canEquivalentInfer(Expression predicate) {
return predicate instanceof EqualTo
- && predicate.children().stream().allMatch(e -> e instanceof
SlotReference)
+ && predicate.children().stream().allMatch(e ->
+ (e instanceof SlotReference) || (e instanceof Cast &&
e.child(0).isSlot()))
&&
predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
index 7c147ff017..542f9df993 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
@@ -41,4 +41,8 @@ public class IntegralType extends NumericType {
public String simpleString() {
return "integral";
}
+
+ public boolean widerThan(IntegralType other) {
+ return this.width() > other.width();
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index a3a3ca1b80..71f9808ad2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -38,6 +38,7 @@ import
org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
@@ -251,6 +252,34 @@ public class ExpressionUtils {
}
}
+ /**
+ * get slot covered by cast
+ * example: input: cast(cast(table.columnA)) output: columnA.datatype
+ *
+ */
+ public static DataType getDatatypeCoveredByCast(Expression expr) {
+ if (expr instanceof Cast) {
+ return getDatatypeCoveredByCast(((Cast) expr).child());
+ }
+ return expr.getDataType();
+ }
+
+ /**
+ * judge if expression is slot covered by cast
+ * example: cast(cast(table.columnA))
+ */
+ public static boolean isExpressionSlotCoveredByCast(Expression expr) {
+ if (expr instanceof Cast) {
+ return isExpressionSlotCoveredByCast(((Cast) expr).child());
+ }
+ return expr instanceof SlotReference;
+ }
+
+ public static boolean isTwoExpressionEqualWithCast(Expression left,
Expression right) {
+ return ExpressionUtils.extractSlotOrCastOnSlot(left)
+ .equals(ExpressionUtils.extractSlotOrCastOnSlot(right));
+ }
+
/**
* Replace expression node in the expression tree by `replaceMap` in
top-down manner.
* For example.
diff --git
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
new file mode 100644
index 0000000000..ac46201185
--- /dev/null
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -0,0 +1,43 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_infer_predicate") {
+ sql 'set enable_nereids_planner=true'
+ sql 'set enable_fallback_to_original_planner=false'
+
+ sql 'drop table if exists infer_tb1;'
+ sql 'drop table if exists infer_tb2;'
+
+ sql '''create table infer_tb1 (k1 int, k2 int) distributed by hash(k1)
buckets 3 properties('replication_num' = '1');'''
+
+ sql '''create table infer_tb2 (k1 tinyint, k2 smallint, k3 int, k4 bigint,
k5 largeint, k6 date, k7 datetime, k8 float, k9 double) distributed by hash(k1)
buckets 3 properties('replication_num' = '1');'''
+
+ explain {
+ sql "select * from infer_tb1 inner join infer_tb2 where infer_tb2.k1 =
infer_tb1.k2 and infer_tb2.k1 = 1;"
+ contains "PREDICATES: k2[#20] = 1"
+ }
+
+ explain {
+ sql "select * from infer_tb1 inner join infer_tb2 where infer_tb1.k2 =
infer_tb2.k1 and infer_tb2.k1 = 1;"
+ contains "PREDICATES: k2[#20] = 1"
+ }
+
+ explain {
+ sql "select * from infer_tb1 inner join infer_tb2 where
cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;"
+ notContains "PREDICATES: k2[#20] = 1"
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]