This is an automated email from the ASF dual-hosted git repository.
lingmiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new f758e1166a [fix] Fix RewriteBinaryPredicatesRule which causes wrong
query results in some cases. (#10551)
f758e1166a is described below
commit f758e1166a2844d507806c2ddc5376cea036d31b
Author: luozenglin <[email protected]>
AuthorDate: Wed Jul 6 15:39:27 2022 +0800
[fix] Fix RewriteBinaryPredicatesRule which causes wrong query results in
some cases. (#10551)
During query planning, the binary-predicate rewrite converts a DecimalLiteral
to an integer, and that conversion may overflow, so a predicate such as
"id = 12345678901.0" can be wrongly folded to false (see the issue for
detailed examples). This PR fixes the possible overflow and optimizes the
case where the DecimalLiteral lies outside the value range of the column type.
Issue Number: close #10544
---
.../org/apache/doris/analysis/DecimalLiteral.java | 6 ++
.../java/org/apache/doris/analysis/IntLiteral.java | 22 ++++
.../doris/rewrite/RewriteBinaryPredicatesRule.java | 72 +++++++++----
.../org/apache/doris/analysis/SelectStmtTest.java | 40 +++----
.../java/org/apache/doris/planner/PlannerTest.java | 4 +-
.../rewrite/RewriteBinaryPredicatesRuleTest.java | 118 +++++++++++++++++++++
6 files changed, 219 insertions(+), 43 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
index e938a46361..3e5bf9abc7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
@@ -246,6 +246,12 @@ public class DecimalLiteral extends LiteralExpr {
} else if (targetType.isFloatingPointType()) {
return new FloatLiteral(value.doubleValue(), targetType);
} else if (targetType.isIntegerType()) {
+ // If the integer part of BigDecimal is too big to fit into long,
+ // longValue() will only return the low-order 64-bit value.
+ if (value.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) > 0
+ || value.compareTo(BigDecimal.valueOf(Long.MIN_VALUE)) <
0) {
+ throw new AnalysisException("Integer part of " + value + "
exceeds storage range of Long Type.");
+ }
return new IntLiteral(value.longValue(), targetType);
} else if (targetType.isStringType()) {
return new StringLiteral(value.toString());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
index 00662c5e6a..4d4f673822 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
@@ -172,6 +172,28 @@ public class IntLiteral extends LiteralExpr {
return new IntLiteral(value);
}
+ /** Returns an IntLiteral holding TINY_INT_MAX, SMALL_INT_MAX, INT_MAX or BIG_INT_MAX, selected by {@code type}'s primitive type; any other primitive type fails {@code Preconditions.checkState}. */ public static IntLiteral createMaxValue(Type type) {
+ long value = 0L; // placeholder; every non-default case below overwrites it
+ switch (type.getPrimitiveType()) {
+ case TINYINT:
+ value = TINY_INT_MAX;
+ break;
+ case SMALLINT:
+ value = SMALL_INT_MAX;
+ break;
+ case INT:
+ value = INT_MAX;
+ break;
+ case BIGINT:
+ value = BIG_INT_MAX;
+ break;
+ default:
+ Preconditions.checkState(false); // only the four integer types above are supported
+ }
+
+ return new IntLiteral(value); // NOTE(review): single-arg constructor, so the literal's type is not taken from 'type' — confirm this is intended
+ }
+
@Override
public boolean isMinValue() {
switch (type.getPrimitiveType()) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
index 4ed232b4b1..a18797b657 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
@@ -19,10 +19,13 @@ package org.apache.doris.rewrite;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.BinaryPredicate;
+import org.apache.doris.analysis.BinaryPredicate.Operator;
import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.DecimalLiteral;
import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.IntLiteral;
+import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
@@ -53,36 +56,61 @@ public class RewriteBinaryPredicatesRule implements
ExprRewriteRule {
* 3) "select * from T where t1 != 2.0" is converted to "select * from T
where t1 != 2"
* 4) "select * from T where t1 != 2.1" is converted to "select * from T"
* 5) "select * from T where t1 <= 2.0" is converted to "select * from T
where t1 <= 2"
- * 6) "select * from T where t1 <= 2.1" is converted to "select * from T
where t1 <3"
+ * 6) "select * from T where t1 <= 2.1" is converted to "select * from T
where t1 <=2"
* 7) "select * from T where t1 >= 2.0" is converted to "select * from T
where t1 >= 2"
* 8) "select * from T where t1 >= 2.1" is converted to "select * from T
where t1> 2"
* 9) "select * from T where t1 <2.0" is converted to "select * from T
where t1 <2"
- * 10) "select * from T where t1 <2.1" is converted to "select * from T
where t1 <3"
+ * 10) "select * from T where t1 <2.1" is converted to "select * from T
where t1 <=2"
* 11) "select * from T where t1> 2.0" is converted to "select * from T
where t1> 2"
* 12) "select * from T where t1> 2.1" is converted to "select * from T
where t1> 2"
*/
- private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, Expr
expr1, BinaryPredicate.Operator op)
- throws AnalysisException {
- if (((DecimalLiteral) expr1).getDoubleValue() % (int)
(((DecimalLiteral) expr1).getDoubleValue()) != 0) {
- if (op == BinaryPredicate.Operator.EQ || op ==
BinaryPredicate.Operator.EQ_FOR_NULL) {
- return new BoolLiteral(false);
- } else if (op == BinaryPredicate.Operator.NE) {
+ private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0,
DecimalLiteral expr1,
+ BinaryPredicate.Operator op) {
+ Type columnType = expr0.getSrcSlotRef().getColumn().getType();
+ try {
+ // Convert childExpr to column type and compare the converted
values. There are 3 possible situations:
+ // case 1. The value of childExpr exceeds the range of the column
type, then castTo() will throw an
+ // exception. For example, the value of childExpr is 128.0 and
the column type is tinyint.
+ // case 2. childExpr is converted to column type, but the value of
childExpr loses precision.
+ // For example, 2.1 is converted to 2;
+ // case 3. childExpr is precisely converted to column type. For
example, 2.0 is converted to 2.
+ LiteralExpr newExpr = (LiteralExpr) expr1.castTo(columnType);
+ int compResult = expr1.compareLiteral(newExpr);
+ // case 2
+ if (compResult != 0) {
+ if (op == Operator.EQ || op == Operator.EQ_FOR_NULL) {
+ return new BoolLiteral(false);
+ } else if (op == Operator.NE) {
+ return new BoolLiteral(true);
+ }
+
+ if (compResult > 0) {
+ if (op == Operator.LT) {
+ op = Operator.LE;
+ } else if (op == Operator.GE) {
+ op = Operator.GT;
+ }
+ } else {
+ if (op == Operator.LE) {
+ op = Operator.LT;
+ } else if (op == Operator.GT) {
+ op = Operator.GE;
+ }
+ }
+ }
+ // case 3
+ return new BinaryPredicate(op, expr0.castTo(columnType), newExpr);
+ } catch (AnalysisException e) {
+ // case 1
+ IntLiteral colTypeMinValue = IntLiteral.createMinValue(columnType);
+ IntLiteral colTypeMaxValue = IntLiteral.createMaxValue(columnType);
+ if (op == Operator.NE || ((expr1).compareLiteral(colTypeMinValue)
< 0 && (op == Operator.GE
+ || op == Operator.GT)) ||
((expr1).compareLiteral(colTypeMaxValue) > 0 && (op == Operator.LE
+ || op == Operator.LT))) {
return new BoolLiteral(true);
- } else if (op == BinaryPredicate.Operator.LE) {
- ((DecimalLiteral) expr1).roundCeiling();
- op = BinaryPredicate.Operator.LT;
- } else if (op == BinaryPredicate.Operator.GE) {
- ((DecimalLiteral) expr1).roundFloor();
- op = BinaryPredicate.Operator.GT;
- } else if (op == BinaryPredicate.Operator.LT) {
- ((DecimalLiteral) expr1).roundCeiling();
- } else if (op == BinaryPredicate.Operator.GT) {
- ((DecimalLiteral) expr1).roundFloor();
}
+ return new BoolLiteral(false);
}
- expr0 = expr0.getChild(0);
- expr1 = expr1.castTo(Type.BIGINT);
- return new BinaryPredicate(op, expr0, expr1);
}
@Override
@@ -95,7 +123,7 @@ public class RewriteBinaryPredicatesRule implements
ExprRewriteRule {
Expr expr1 = expr.getChild(1);
if (expr0 instanceof CastExpr && expr0.getType() == Type.DECIMALV2 &&
expr0.getChild(0) instanceof SlotRef
&& expr0.getChild(0).getType().getResultType() == Type.BIGINT
&& expr1 instanceof DecimalLiteral) {
- return rewriteBigintSlotRefCompareDecimalLiteral(expr0, expr1, op);
+ return rewriteBigintSlotRefCompareDecimalLiteral(expr0,
(DecimalLiteral) expr1, op);
}
return expr;
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 5577a566be..9083fea261 100755
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -272,25 +272,27 @@ public class SelectStmtTest {
+ " );";
SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql,
ctx);
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(),
ctx).getExprRewriter());
- String rewritedFragment1 = "((`t1`.`k2` = `t4`.`k2` AND `t3`.`k3` =
`t1`.`k3` "
- + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 200) AND "
- + "(`t3`.`k1` = 'D' OR `t3`.`k1` = 'S' OR `t3`.`k1` = 'W') "
- + "AND (`t4`.`k3` = '2 yr Degree' OR `t4`.`k3` = 'Advanced
Degree' OR `t4`.`k3` = 'Secondary') "
- + "AND (`t4`.`k4` = 1 OR `t4`.`k4` = 3))) "
- + "AND ((`t3`.`k1` = 'D' AND `t4`.`k3` = '2 yr Degree' "
- + "AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 150 AND `t4`.`k4` =
3) "
- + "OR (`t3`.`k1` = 'S' AND `t4`.`k3` = 'Secondary' AND
`t1`.`k4` >= 50 "
- + "AND `t1`.`k4` <= 100 AND `t4`.`k4` = 1) OR (`t3`.`k1` = 'W'
AND `t4`.`k3` = 'Advanced Degree' "
- + "AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 200 AND `t4`.`k4` =
1)))";
- String rewritedFragment2 = "((`t1`.`k1` = `t5`.`k1` AND `t5`.`k2` =
'United States' "
- + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 300) "
- + "AND `t5`.`k3` IN ('CO', 'IL', 'MN', 'OH', 'MT', 'NM', 'TX',
'MO', 'MI'))) "
- + "AND ((`t5`.`k3` IN ('CO', 'IL', 'MN') AND `t1`.`k4` >= 100
AND `t1`.`k4` <= 200) "
- + "OR (`t5`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150
AND `t1`.`k4` <= 300) OR (`t5`.`k3` IN "
- + "('TX', 'MO', 'MI') AND `t1`.`k4` >= 50 AND `t1`.`k4` <=
250)))";
- System.out.println(stmt.toSql());
- Assert.assertTrue(stmt.toSql().contains(rewritedFragment1));
- Assert.assertTrue(stmt.toSql().contains(rewritedFragment2));
+ String commonExpr1 = "`t1`.`k2` = `t4`.`k2`";
+ String commonExpr2 = "`t3`.`k3` = `t1`.`k3`";
+ String commonExpr3 = "`t1`.`k1` = `t5`.`k1`";
+ String commonExpr4 = "t5`.`k2` = 'United States'";
+ String betweenExpanded1 = "`t1`.`k4` >= 100 AND `t1`.`k4` <= 150";
+ String betweenExpanded2 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 100";
+ String betweenExpanded3 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 250";
+
+ String rewrittenSql = stmt.toSql();
+ System.out.println(rewrittenSql);
+ Assert.assertTrue(rewrittenSql.contains(commonExpr1));
+ Assert.assertEquals(rewrittenSql.indexOf(commonExpr1),
rewrittenSql.lastIndexOf(commonExpr1));
+ Assert.assertTrue(rewrittenSql.contains(commonExpr2));
+ Assert.assertEquals(rewrittenSql.indexOf(commonExpr2),
rewrittenSql.lastIndexOf(commonExpr2));
+ Assert.assertTrue(rewrittenSql.contains(commonExpr3));
+ Assert.assertEquals(rewrittenSql.indexOf(commonExpr3),
rewrittenSql.lastIndexOf(commonExpr3));
+ Assert.assertTrue(rewrittenSql.contains(commonExpr4));
+ Assert.assertEquals(rewrittenSql.indexOf(commonExpr4),
rewrittenSql.lastIndexOf(commonExpr4));
+ Assert.assertTrue(rewrittenSql.contains(betweenExpanded1));
+ Assert.assertTrue(rewrittenSql.contains(betweenExpanded2));
+ Assert.assertTrue(rewrittenSql.contains(betweenExpanded3));
String sql2 = "select\n"
+ " avg(t1.k4)\n"
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
index d52314518a..5b50e1972f 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
@@ -422,11 +422,11 @@ public class PlannerTest extends TestWithFeService {
compare.accept("select * from db1.tbl2 where k1 != 2.0", "select *
from db1.tbl2 where k1 != 2");
compare.accept("select * from db1.tbl2 where k1 != 2.1", "select *
from db1.tbl2");
compare.accept("select * from db1.tbl2 where k1 <= 2.0", "select *
from db1.tbl2 where k1 <= 2");
- compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select *
from db1.tbl2 where k1 < 3");
+ compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select *
from db1.tbl2 where k1 <= 2");
compare.accept("select * from db1.tbl2 where k1 >= 2.0", "select *
from db1.tbl2 where k1 >= 2");
compare.accept("select * from db1.tbl2 where k1 >= 2.1", "select *
from db1.tbl2 where k1 > 2");
compare.accept("select * from db1.tbl2 where k1 < 2.0", "select * from
db1.tbl2 where k1 < 2");
- compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from
db1.tbl2 where k1 < 3");
+ compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from
db1.tbl2 where k1 <= 2");
compare.accept("select * from db1.tbl2 where k1 > 2.0", "select * from
db1.tbl2 where k1 > 2");
compare.accept("select * from db1.tbl2 where k1 > 2.1", "select * from
db1.tbl2 where k1 > 2");
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
new file mode 100644
index 0000000000..0b06502e11
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.rewrite;
+
+import org.apache.doris.analysis.BinaryPredicate;
+import org.apache.doris.analysis.BinaryPredicate.Operator;
+import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.LiteralExpr;
+import org.apache.doris.analysis.SelectStmt;
+import org.apache.doris.catalog.PrimitiveType;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+/** Checks the WHERE clause produced for {@code id <op> <decimal>} queries on table1, whose key column {@code id} is SMALLINT. */ public class RewriteBinaryPredicatesRuleTest extends TestWithFeService {
+ @Override
+ protected void runBeforeAll() throws Exception { // creates db.table1: SMALLINT aggregate key id plus a SUM-aggregated BIGINT cost column
+ connectContext = createDefaultCtx();
+ createDatabase("db");
+ useDatabase("db");
+ String createTable = "create table table1(id smallint, cost bigint
sum) "
+ + "aggregate key(`id`) distributed by hash (`id`) buckets 4 "
+ + "properties (\"replication_num\"=\"1\");";
+ createTable(createTable);
+ }
+
+ @Test
+ public void testNormal() throws Exception { // in-range literals: exact values keep op and value; fractional ones fold (EQ/NE) or adjust op around the truncated value
+ testBase(Operator.EQ, "2.0", Operator.EQ, 2L);
+ testBoolean(Operator.EQ, "2.5", false);
+
+ testBase(Operator.NE, "2.0", Operator.NE, 2L);
+ testBoolean(Operator.NE, "2.5", true);
+
+ testBase(Operator.LE, "2.0", Operator.LE, 2L);
+ testBase(Operator.LE, "-2.5", Operator.LT, -2L);
+ testBase(Operator.LE, "2.5", Operator.LE, 2L);
+
+ testBase(Operator.GE, "2.0", Operator.GE, 2L);
+ testBase(Operator.GE, "-2.5", Operator.GE, -2L);
+ testBase(Operator.GE, "2.5", Operator.GT, 2L);
+
+ testBase(Operator.LT, "2.0", Operator.LT, 2L);
+ testBase(Operator.LT, "-2.5", Operator.LT, -2L);
+ testBase(Operator.LT, "2.5", Operator.LE, 2L);
+
+ testBase(Operator.GT, "2.0", Operator.GT, 2L);
+ testBase(Operator.GT, "-2.5", Operator.GE, -2L);
+ testBase(Operator.GT, "2.5", Operator.GT, 2L);
+ }
+
+ @Test
+ public void testOutOfRange() throws Exception {
+ // SMALLINT range is [-32768, 32767]; literals outside it are expected to fold to TRUE/FALSE
+ testBoolean(Operator.EQ, "-32769.0", false);
+ testBase(Operator.EQ, "32767.0", Operator.EQ, 32767L);
+
+ testBoolean(Operator.NE, "32768.0", true);
+
+ testBoolean(Operator.LE, "32768.2", true);
+ testBoolean(Operator.LE, "-32769.1", false);
+ testBase(Operator.LE, "32767.0", Operator.LE, 32767L);
+
+ testBoolean(Operator.GE, "32768.1", false);
+ testBoolean(Operator.GE, "-32769.1", true);
+ testBase(Operator.GE, "32767.0", Operator.GE, 32767L);
+
+ testBoolean(Operator.LT, "32768.1", true);
+ testBoolean(Operator.LT, "-32769.1", false);
+ testBase(Operator.LT, "32767.1", Operator.LE, 32767L);
+
+ testBoolean(Operator.GT, "32768.1", false);
+ testBoolean(Operator.GT, "-32769.1", true);
+ testBase(Operator.GT, "32767.0", Operator.GT, 32767L);
+ }
+
+ private void testBase(Operator operator, String queryLiteral, Operator
expectedOperator, long expectedChild1)
+ throws Exception { // expects a BinaryPredicate with expectedOperator, SMALLINT operands, and a right-hand literal equal to expectedChild1
+ Expr expr1 = getExpr(operator, queryLiteral);
+ Assertions.assertTrue(expr1 instanceof BinaryPredicate);
+ Assertions.assertEquals(expectedOperator, ((BinaryPredicate)
expr1).getOp());
+ Assertions.assertEquals(PrimitiveType.SMALLINT,
expr1.getChild(0).getType().getPrimitiveType());
+ Assertions.assertEquals(PrimitiveType.SMALLINT,
expr1.getChild(1).getType().getPrimitiveType());
+ Assertions.assertEquals(expectedChild1, ((LiteralExpr)
expr1.getChild(1)).getLongValue());
+ }
+
+ private void testBoolean(Operator operator, String queryLiteral, boolean
result) throws Exception { // expects the predicate to be folded to a BoolLiteral with the given value
+ Expr expr1 = getExpr(operator, queryLiteral);
+ Assertions.assertTrue(expr1 instanceof BoolLiteral);
+ Assertions.assertEquals(result, ((BoolLiteral) expr1).getValue());
+ }
+
+ private Expr getExpr(Operator operator, String queryLiteral) throws
Exception { // builds "select * from table1 where id <op> <literal>;" and returns the WHERE clause of the executor's parsed SelectStmt
+ String queryFormat = "select * from table1 where id %s %s;";
+ String query = String.format(queryFormat, operator.toString(),
queryLiteral);
+ StmtExecutor executor1 = getSqlStmtExecutor(query);
+ Assertions.assertNotNull(executor1);
+ return ((SelectStmt) executor1.getParsedStmt()).getWhereClause();
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]