This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch dev-1.1.2
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/dev-1.1.2 by this push:
     new cb1f53052f [improvement][fix](planner) Add a rewrite rule to optimize 
InPredicate. (#9739) (#11559)
cb1f53052f is described below

commit cb1f53052f23ec5f50dd06afb8ce9e72941c69a5
Author: morrySnow <[email protected]>
AuthorDate: Fri Aug 5 16:43:44 2022 +0800

    [improvement][fix](planner) Add a rewrite rule to optimize InPredicate. 
(#9739) (#11559)
    
    1. Convert child expressions in InPredicate to the column type and discard 
those that cannot be converted exactly.
    2. Fix the bug of ColumnRange exception caused by InPredicate child 
expressions type conversion.
    3. Fix the problem that the tablet could not be hit due to 
InPredicate child expressions type conversion.
    
    Co-authored-by: luozenglin <[email protected]>
---
 .../java/org/apache/doris/analysis/Analyzer.java   |   2 +
 .../org/apache/doris/analysis/DecimalLiteral.java  |   2 +
 .../org/apache/doris/analysis/FloatLiteral.java    |   3 +-
 .../doris/rewrite/RewriteInPredicateRule.java      | 115 ++++++++++++++
 .../doris/rewrite/RewriteInPredicateRuleTest.java  | 175 +++++++++++++++++++++
 5 files changed, 296 insertions(+), 1 deletion(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
index ce9cfea77a..932b0b6a0f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java
@@ -51,6 +51,7 @@ import org.apache.doris.rewrite.RewriteBinaryPredicatesRule;
 import org.apache.doris.rewrite.RewriteDateLiteralRule;
 import org.apache.doris.rewrite.RewriteEncryptKeyRule;
 import org.apache.doris.rewrite.RewriteFromUnixTimeRule;
+import org.apache.doris.rewrite.RewriteInPredicateRule;
 import org.apache.doris.rewrite.mvrewrite.CountDistinctToBitmap;
 import org.apache.doris.rewrite.mvrewrite.CountDistinctToBitmapOrHLLRule;
 import org.apache.doris.rewrite.mvrewrite.CountFieldToSum;
@@ -339,6 +340,7 @@ public class Analyzer {
             rules.add(CompoundPredicateWriteRule.INSTANCE);
             rules.add(RewriteDateLiteralRule.INSTANCE);
             rules.add(RewriteEncryptKeyRule.INSTANCE);
+            rules.add(RewriteInPredicateRule.INSTANCE);
             rules.add(RewriteAliasFunctionRule.INSTANCE);
             List<ExprRewriteRule> onceRules = Lists.newArrayList();
             onceRules.add(ExtractCommonFactorsRule.INSTANCE);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
index 1cf3fcd7de..397acac355 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
@@ -250,6 +250,8 @@ public class DecimalLiteral extends LiteralExpr {
             return new IntLiteral(value.longValue(), targetType);
         } else if (targetType.isStringType()) {
             return new StringLiteral(value.toString());
+        } else if (targetType.isLargeIntType()) {
+            return new LargeIntLiteral(value.toBigInteger().toString());
         }
         return super.uncheckedCastTo(targetType);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
index 0b70f12454..0071d8f5ad 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
@@ -168,7 +168,8 @@ public class FloatLiteral extends LiteralExpr {
             }
             return this;
         } else if (targetType.isDecimalV2()) {
-            return new DecimalLiteral(new BigDecimal(value));
+            // the double constructor does an exact translation, use valueOf() 
instead.
+            return new DecimalLiteral(BigDecimal.valueOf(value));
         }
         return this;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java
new file mode 100644
index 0000000000..b37377e72a
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteInPredicateRule.java
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.rewrite;
+
+import org.apache.doris.analysis.Analyzer;
+import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.InPredicate;
+import org.apache.doris.analysis.LiteralExpr;
+import org.apache.doris.analysis.SlotRef;
+import org.apache.doris.analysis.Subquery;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.common.AnalysisException;
+import org.apache.doris.rewrite.ExprRewriter.ClauseType;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+
+/**
+ * Optimize the InPredicate when the child expr type is integerType, 
largeIntType, floatingPointType, decimalV2,
+ * char, varchar, string and the column type is integerType, largeIntType: 
convert the child expr type to the column
+ * type and discard the expressions that cannot be converted exactly.
+ *
+ * <p>For example:<br>
+ * column type is integerType or largeIntType, then:<br>
+ * col in (1, 2.5, 2.0, "3.0", "4.6") -> col in (1, 2, 3)<br>
+ * col in (2.5, "4.6") -> false<br>
+ * column type is tinyType, then:<br>
+ * col in (1, 2.0, 128, "1000") -> col in (1, 2)
+ */
+public class RewriteInPredicateRule implements ExprRewriteRule {
+    public static ExprRewriteRule INSTANCE = new RewriteInPredicateRule();
+
+    @Override
+    public Expr apply(Expr expr, Analyzer analyzer, ClauseType clauseType) 
throws AnalysisException {
+        if (!(expr instanceof InPredicate)) {
+            return expr;
+        }
+        InPredicate inPredicate = (InPredicate) expr;
+        SlotRef slotRef;
+        if (inPredicate.contains(Subquery.class) || 
!inPredicate.isLiteralChildren() || inPredicate.isNotIn()
+                || !(inPredicate.getChild(0).unwrapExpr(false) instanceof 
SlotRef)
+                || (slotRef = inPredicate.getChild(0).getSrcSlotRef()) == null 
|| slotRef.getColumn() == null) {
+            return expr;
+        }
+        Type columnType = slotRef.getColumn().getType();
+        if (!columnType.isFixedPointType()) {
+            return expr;
+        }
+
+        Expr newColumnExpr = expr.getChild(0).getType().getPrimitiveType() == 
columnType.getPrimitiveType()
+                ? expr.getChild(0) : expr.getChild(0).castTo(columnType);
+        List<Expr> newInList = Lists.newArrayList();
+        boolean isCast = false;
+        for (int i = 1; i < inPredicate.getChildren().size(); ++i) {
+            LiteralExpr childExpr = (LiteralExpr) inPredicate.getChild(i);
+            if (!(childExpr.getType().isNumericType() || 
childExpr.getType().getPrimitiveType().isCharFamily())) {
+                return expr;
+            }
+            if 
(childExpr.getType().getPrimitiveType().equals(columnType.getPrimitiveType())) {
+                newInList.add(childExpr);
+                continue;
+            }
+
+            // StringLiteral "2.0" cannot be directly converted to IntLiteral 
or LargeIntLiteral, and FloatLiteral
+            // cannot be directly converted to LargeIntLiteral, so it is 
converted to decimal first.
+            if (childExpr.getType().getPrimitiveType().isCharFamily() || 
childExpr.getType().isFloatingPointType()) {
+                try {
+                    childExpr = (LiteralExpr) childExpr.castTo(Type.DECIMALV2);
+                } catch (AnalysisException e) {
+                    continue;
+                }
+            }
+
+            try {
+                // Convert childExpr to column type and compare the converted 
values. There are 3 possible situations:
+                // 1. The value of childExpr exceeds the range of the column 
type, then castTo() will throw an
+                //   exception. For example, the value of childExpr is 128 and 
the column type is tinyint.
+                // 2. childExpr is converted to column type, but the value of 
childExpr loses precision.
+                //   For example, 2.1 is converted to 2;
+                // 3. childExpr is precisely converted to column type. For 
example, 2.0 is converted to 2.
+                // In cases 1 and 2 above, childExpr should be discarded.
+                LiteralExpr newExpr = (LiteralExpr) 
childExpr.castTo(columnType);
+                if (childExpr.compareLiteral(newExpr) == 0) {
+                    isCast = true;
+                    newInList.add(newExpr);
+                }
+            } catch (AnalysisException ignored) {
+                // pass
+            }
+        }
+        if (newInList.isEmpty()) {
+            return new BoolLiteral(false);
+        }
+        // Expr rewriting if there is childExpr discarded or type is converted.
+        return newInList.size() + 1 < expr.getChildren().size() || isCast
+                ? new InPredicate(newColumnExpr, newInList, false) : expr;
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteInPredicateRuleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteInPredicateRuleTest.java
new file mode 100644
index 0000000000..9400256fad
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteInPredicateRuleTest.java
@@ -0,0 +1,175 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.rewrite;
+
+import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.CreateDbStmt;
+import org.apache.doris.analysis.CreateTableStmt;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.InPredicate;
+import org.apache.doris.analysis.IntLiteral;
+import org.apache.doris.analysis.LargeIntLiteral;
+import org.apache.doris.analysis.LiteralExpr;
+import org.apache.doris.analysis.SelectStmt;
+import org.apache.doris.catalog.Catalog;
+import org.apache.doris.catalog.PrimitiveType;
+import org.apache.doris.cluster.ClusterNamespace;
+import org.apache.doris.common.FeConstants;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.QueryState;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.system.SystemInfoService;
+import org.apache.doris.utframe.UtFrameUtils;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.Lists;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.List;
+import java.util.UUID;
+
+public class RewriteInPredicateRuleTest {
+    private static String runningDir = "fe/mocked/RewriteInPredicateRuleTest/" 
+ UUID.randomUUID() + "/";
+    private static ConnectContext ctx;
+    private static final String DB_NAME = "testdb";
+    private static final String TABLE_SMALL = "table_small";
+    private static final String TABLE_LARGE = "table_large";
+
+    @BeforeClass
+    public static void runBeforeAll() throws Exception {
+        FeConstants.runningUnitTest = true;
+        UtFrameUtils.createDorisCluster(runningDir, 2);
+        ctx = UtFrameUtils.createDefaultCtx();
+        String createDbStmtStr = "CREATE DATABASE " + DB_NAME;
+        CreateDbStmt createDbStmt = (CreateDbStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
+        Catalog.getCurrentCatalog().createDb(createDbStmt);
+        
ctx.setDatabase(ClusterNamespace.getFullName(SystemInfoService.DEFAULT_CLUSTER, 
DB_NAME));
+        String createTableFormat = "create table %s(id %s, `date` datetime, 
cost bigint sum) "
+                + "aggregate key(`id`, `date`) distributed by hash (`id`) 
buckets 4 "
+                + "properties (\"replication_num\"=\"1\");";
+        String createTableSmall = String.format(createTableFormat, 
TABLE_SMALL, PrimitiveType.SMALLINT);
+        String createTableLarge = String.format(createTableFormat, 
TABLE_LARGE, PrimitiveType.LARGEINT);
+        CreateTableStmt stmt = (CreateTableStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createTableSmall, ctx);
+        Catalog.getCurrentCatalog().createTable(stmt);
+        stmt = (CreateTableStmt) 
UtFrameUtils.parseAndAnalyzeStmt(createTableLarge, ctx);
+        Catalog.getCurrentCatalog().createTable(stmt);
+    }
+
+    public StmtExecutor getSqlStmtExecutor(String queryStr) throws Exception {
+        ctx.getState().reset();
+        StmtExecutor stmtExecutor = new StmtExecutor(ctx, queryStr);
+        stmtExecutor.execute();
+        if (ctx.getState().getStateType() != QueryState.MysqlStateType.ERR) {
+            return stmtExecutor;
+        } else {
+            return null;
+        }
+    }
+
+    @Test
+    public void testIntLiteralAndLargeIntLiteral() throws Exception {
+        // id in (TINY_INT_MIN, SMALL_INT_MAX, INT_MIN, BIG_INT_MAX, 
LARGE_INT_MAX)
+        // => id in (TINY_INT_MIN, SMALL_INT_MAX)
+        testBase(3, PrimitiveType.SMALLINT, IntLiteral.TINY_INT_MIN, 
TABLE_SMALL,
+                String.valueOf(IntLiteral.TINY_INT_MIN), 
String.valueOf(IntLiteral.SMALL_INT_MAX),
+                String.valueOf(IntLiteral.INT_MIN), 
String.valueOf(IntLiteral.BIG_INT_MAX),
+                LargeIntLiteral.LARGE_INT_MAX.toString());
+
+        // id in (TINY_INT_MIN, SMALL_INT_MAX, INT_MIN, BIG_INT_MAX, 
LARGE_INT_MAX)
+        // => id in (TINY_INT_MIN, SMALL_INT_MAX, INT_MIN, BIG_INT_MAX, 
LARGE_INT_MAX)
+        testBase(6, PrimitiveType.LARGEINT, IntLiteral.TINY_INT_MIN, 
TABLE_LARGE,
+                String.valueOf(IntLiteral.TINY_INT_MIN), 
String.valueOf(IntLiteral.SMALL_INT_MAX),
+                String.valueOf(IntLiteral.INT_MIN), 
String.valueOf(IntLiteral.BIG_INT_MAX),
+                LargeIntLiteral.LARGE_INT_MAX.toString());
+    }
+
+    @Test
+    public void testDecimalLiteral() throws Exception {
+        // type of id is smallint: id in (2.0, 3.5) => id in (2)
+        testBase(2, PrimitiveType.SMALLINT, 2, TABLE_SMALL, "2.0", "3.5");
+
+        testBase(2, PrimitiveType.SMALLINT, 3, TABLE_SMALL, "2.1", "3.0", 
"3.5");
+
+        // type of id is largeint: id in (2.0, 3.5) => id in (2)
+        testBase(2, PrimitiveType.LARGEINT, 2, TABLE_LARGE, "2.0", "3.5");
+    }
+
+    @Test
+    public void testStringLiteral() throws Exception {
+        // type of id is smallint: id in ("2.0", "3.5") => id in (2)
+        testBase(2, PrimitiveType.SMALLINT, 2, TABLE_SMALL, "\"2.0\"", 
"\"3.5\"");
+
+        // type of id is largeint: id in ("2.0", "3.5") => id in (2)
+        testBase(2, PrimitiveType.LARGEINT, 2, TABLE_LARGE, "\"2.0\"", 
"\"3.5\"");
+    }
+
+    @Test
+    public void testBooleanLiteral() throws Exception {
+        // type of id is smallint: id in (true, false) => id in (1, 0)
+        testBase(3, PrimitiveType.SMALLINT, 0, TABLE_SMALL, "false", "true");
+
+        // type of id is largeint: id in (true, false) => id in (1, 0)
+        testBase(3, PrimitiveType.LARGEINT, 1, TABLE_LARGE, "true", "false");
+    }
+
+    @Test
+    public void testMixedLiteralExpr() throws Exception {
+        // type of id is smallint: id in (1, 2.0, 3.3) -> id in (1, 2)
+        testBase(3, PrimitiveType.SMALLINT, 1, TABLE_SMALL, "1", "2.0", "3.3");
+        // type of id is smallint: id in (1, 1.0, 1.1) => id in (1, 1)
+        testBase(3, PrimitiveType.SMALLINT, 1, TABLE_SMALL, "1", "1.0", "1.1");
+        // type of id is smallint: id in ("1.0", 2.0, 3.3, "5.2") => id in (1, 
2)
+        testBase(3, PrimitiveType.SMALLINT, 1, TABLE_SMALL, "\"1.0\"", "2.0", 
"3.3", "\"5.2\"");
+        // type of id is smallint: id in (false, 2.0, 3.3, "5.2", true) => id 
in (0, 2, 1)
+        testBase(4, PrimitiveType.SMALLINT, 0, TABLE_SMALL, "false", "2.0", 
"3.3", "\"5.2\"", "true");
+
+        // largeint
+        testBase(3, PrimitiveType.LARGEINT, 1, TABLE_LARGE, "1", "2.0", "3.3");
+        testBase(3, PrimitiveType.LARGEINT, 1, TABLE_LARGE, "1", "1.0", "1.1");
+        testBase(3, PrimitiveType.LARGEINT, 1, TABLE_LARGE, "\"1.0\"", "2.0", 
"3.3", "\"5.2\"");
+        testBase(4, PrimitiveType.LARGEINT, 0, TABLE_LARGE, "false", "2.0", 
"3.3", "\"5.2\"", "true");
+    }
+
+    @Test
+    public void testEmpty() throws Exception {
+        // type of id is smallint: id in (5.5, "6.2") => false
+        String query = "select * from table_small where id in (5.5, \"6.2\");";
+        StmtExecutor executor1 = getSqlStmtExecutor(query);
+        Expr expr1 = ((SelectStmt) executor1.getParsedStmt()).getWhereClause();
+        Assert.assertTrue(expr1 instanceof BoolLiteral);
+        Assert.assertFalse(((BoolLiteral) expr1).getValue());
+    }
+
+    private void testBase(int childrenNum, PrimitiveType type, long 
expectedOfChild1, String... literals)
+            throws Exception {
+        List<String> list = Lists.newArrayList();
+        Lists.newArrayList(literals).forEach(e -> list.add("%s"));
+        list.remove(list.size() - 1);
+        String queryFormat = "select * from %s where id in (" + Joiner.on(", 
").join(list) + ");";
+        String query = String.format(queryFormat, literals);
+        StmtExecutor executor1 = getSqlStmtExecutor(query);
+        Expr expr1 = ((SelectStmt) executor1.getParsedStmt()).getWhereClause();
+        Assert.assertTrue(expr1 instanceof InPredicate);
+        Assert.assertEquals(childrenNum, expr1.getChildren().size());
+        Assert.assertEquals(type, 
expr1.getChild(0).getType().getPrimitiveType());
+        Assert.assertEquals(type, 
expr1.getChild(1).getType().getPrimitiveType());
+        Assert.assertEquals(expectedOfChild1, ((LiteralExpr) 
expr1.getChild(1)).getLongValue());
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to