This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d60bb81  [SQL Function] Calculate 'case when expr' when possible 
(#3396)
d60bb81 is described below

commit d60bb81cb02bd12be6142c975ed78ecfcb5e518d
Author: wangbo <[email protected]>
AuthorDate: Thu May 7 22:04:09 2020 +0800

    [SQL Function] Calculate 'case when expr' when possible (#3396)
    
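    Fold 'case when' expressions to constants in the FE when possible: when the
    CASE operand and the WHEN expressions are literals, the expression is replaced by
    the matching THEN/ELSE branch (or NULL), and WHEN branches already known not to
    match are dropped; float and decimal operands are left for BE to evaluate.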
---
 .../java/org/apache/doris/analysis/CaseExpr.java   |  82 +++++++++++
 .../apache/doris/rewrite/FoldConstantsRule.java    |   6 +
 .../org/apache/doris/planner/QueryPlanTest.java    | 156 +++++++++++++++++----
 3 files changed, 217 insertions(+), 27 deletions(-)

diff --git a/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java 
b/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java
index 602aabd..ffe0d88 100644
--- a/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java
+++ b/fe/src/main/java/org/apache/doris/analysis/CaseExpr.java
@@ -26,6 +26,7 @@ import org.apache.doris.thrift.TExprNodeType;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -101,6 +102,7 @@ public class CaseExpr extends Expr {
         CaseExpr expr = (CaseExpr) obj;
         return hasCaseExpr == expr.hasCaseExpr && hasElseExpr == 
expr.hasElseExpr;
     }
+
+    // True when child 0 is the CASE operand ("CASE x WHEN ..." form) rather than the first WHEN expression.
     public boolean hasCaseExpr() {
         return hasCaseExpr;
     }
@@ -251,4 +253,84 @@ public class CaseExpr extends Expr {
         }
         return exprs;
     }
+
+    /**
+     * Tries to fold a CASE expression during planning.
+     *
+     * <p>Only literal operands are compared here, so the result is deliberately not fully
+     * consistent with BE in two respects:
+     * <ol>
+     *   <li>{@link FloatLiteral} and {@link DecimalLiteral} operands are never compared, to avoid
+     *   floating-point computation in Java; such expressions are left for BE to evaluate.</li>
+     *   <li>Only literals of the same class are compared. For example, BE evaluates
+     *   {@code select case when 123 then '1' else '2' end} to {@code '1'} because BE only
+     *   regards 0 as false, whereas {@link LiteralExpr#compareLiteral} does not treat {@code 123}
+     *   as equal to {@code TRUE}. Expressions whose operands have different types are therefore
+     *   left for BE to evaluate.</li>
+     * </ol>
+     *
+     * @param expr the CASE expression to fold
+     * @return the THEN child of the first matching WHEN; the ELSE child (or a NULL literal
+     *         carrying {@code expr}'s type when there is no ELSE) if no branch can match; a copy
+     *         of {@code expr} without the leading branches already proven not to match when a
+     *         later WHEN cannot be evaluated here; or {@code expr} itself when nothing can be folded
+     */
+    public static Expr computeCaseExpr(CaseExpr expr) {
+        int startIndex = 0;
+        int endIndex = expr.getChildren().size();
+        LiteralExpr caseExpr;
+        if (expr.hasCaseExpr()) {
+            // "CASE x WHEN ...": child 0 is the operand and must itself be a foldable literal
+            Expr caseChildExpr = expr.getChild(0);
+            if (!isFoldableLiteral(caseChildExpr)) {
+                return expr;
+            }
+            caseExpr = (LiteralExpr) caseChildExpr;
+            startIndex++;
+        } else {
+            // "CASE WHEN c THEN ..." picks the first WHEN expression that compares equal to TRUE
+            caseExpr = new BoolLiteral(true);
+        }
+
+        // "CASE NULL WHEN ..." can never match a branch
+        if (caseExpr instanceof NullLiteral) {
+            return elseOrTypedNull(expr);
+        }
+
+        if (expr.hasElseExpr) {
+            // the last child is the ELSE expression, not part of a WHEN/THEN pair
+            endIndex--;
+        }
+
+        // early return when the first WHEN expression can't be evaluated here
+        Expr startExpr = expr.getChild(startIndex);
+        if (!isFoldableLiteral(startExpr)
+                || (!(startExpr instanceof NullLiteral) && startExpr.getClass() != caseExpr.getClass())) {
+            return expr;
+        }
+
+        // children from startIndex on are WHEN/THEN pairs: WHEN at i, THEN at i + 1
+        for (int i = startIndex; i < endIndex; i += 2) {
+            Expr currentWhenExpr = expr.getChild(i);
+            // a NULL WHEN expression never matches
+            if (currentWhenExpr instanceof NullLiteral) {
+                continue;
+            }
+            // stop folding when the WHEN expression is not a literal, is a float/decimal literal,
+            // or is a different kind of literal than the CASE operand
+            if (!isFoldableLiteral(currentWhenExpr) || currentWhenExpr.getClass() != caseExpr.getClass()) {
+                return withoutLeadingBranches(expr, caseExpr, i);
+            }
+            if (caseExpr.compareLiteral((LiteralExpr) currentWhenExpr) == 0) {
+                return expr.getChild(i + 1);
+            }
+        }
+        return elseOrTypedNull(expr);
+    }
+
+    // Whether e is a literal that may be compared in FE; float and decimal literals are left to BE.
+    private static boolean isFoldableLiteral(Expr e) {
+        return e.isLiteral() && !(e instanceof DecimalLiteral) && !(e instanceof FloatLiteral);
+    }
+
+    // Returns the ELSE child of expr, or a NULL literal carrying expr's type when there is no ELSE,
+    // so the folded result keeps the type the CASE expression was analyzed with.
+    private static Expr elseOrTypedNull(CaseExpr expr) {
+        if (expr.hasElseExpr) {
+            return expr.getChild(expr.getChildren().size() - 1);
+        }
+        return NullLiteral.create(expr.getType());
+    }
+
+    // Returns a copy of expr that keeps the CASE operand (if any) plus every child from
+    // firstKeptIndex on, i.e. drops the WHEN/THEN pairs already proven not to match.
+    private static Expr withoutLeadingBranches(CaseExpr expr, LiteralExpr caseExpr, int firstKeptIndex) {
+        List<Expr> exprLeft = new ArrayList<>();
+        if (expr.hasCaseExpr()) {
+            exprLeft.add(caseExpr);
+        }
+        for (int j = firstKeptIndex; j < expr.getChildren().size(); j++) {
+            exprLeft.add(expr.getChild(j));
+        }
+        Expr retCaseExpr = expr.clone();
+        retCaseExpr.getChildren().clear();
+        retCaseExpr.addChildren(exprLeft);
+        return retCaseExpr;
+    }
+
 }
diff --git a/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java 
b/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java
index 1c4298e..f697150 100644
--- a/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java
+++ b/fe/src/main/java/org/apache/doris/rewrite/FoldConstantsRule.java
@@ -19,6 +19,7 @@ package org.apache.doris.rewrite;
 
 
 import org.apache.doris.analysis.Analyzer;
+import org.apache.doris.analysis.CaseExpr;
 import org.apache.doris.analysis.CastExpr;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.NullLiteral;
@@ -48,6 +49,11 @@ public class FoldConstantsRule implements ExprRewriteRule {
 
     @Override
     public Expr apply(Expr expr, Analyzer analyzer) throws AnalysisException {
+        // evaluate `case when expr` when possible
+        if (expr instanceof CaseExpr) {
+            return CaseExpr.computeCaseExpr((CaseExpr) expr);
+        }
+
         // Avoid calling Expr.isConstant() because that would lead to repeated 
traversals
         // of the Expr tree. Assumes the bottom-up application of this rule. 
Constant
         // children should have been folded at this point.
diff --git a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java 
b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java
index dccd979..759d85d 100644
--- a/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+++ b/fe/src/test/java/org/apache/doris/planner/QueryPlanTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.planner;
 
+import org.apache.commons.lang3.StringUtils;
 import org.apache.doris.analysis.CreateDbStmt;
 import org.apache.doris.analysis.CreateTableStmt;
 import org.apache.doris.analysis.DropDbStmt;
@@ -242,31 +243,31 @@ public class QueryPlanTest {
                 "PROPERTIES (\n" +
                 " \"replication_num\" = \"1\"\n" +
                 ");");
-        
+
         createTable("CREATE TABLE test.`pushdown_test` (\n" +
-                "  `k1` tinyint(4) NULL COMMENT \"\",\n" + 
-                "  `k2` smallint(6) NULL COMMENT \"\",\n" + 
-                "  `k3` int(11) NULL COMMENT \"\",\n" + 
-                "  `k4` bigint(20) NULL COMMENT \"\",\n" + 
-                "  `k5` decimal(9, 3) NULL COMMENT \"\",\n" + 
-                "  `k6` char(5) NULL COMMENT \"\",\n" + 
-                "  `k10` date NULL COMMENT \"\",\n" + 
-                "  `k11` datetime NULL COMMENT \"\",\n" + 
-                "  `k7` varchar(20) NULL COMMENT \"\",\n" + 
-                "  `k8` double MAX NULL COMMENT \"\",\n" + 
-                "  `k9` float SUM NULL COMMENT \"\"\n" + 
-                ") ENGINE=OLAP\n" + 
-                "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, 
`k11`, `k7`)\n" + 
-                "COMMENT \"OLAP\"\n" + 
-                "PARTITION BY RANGE(`k1`)\n" + 
-                "(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" + 
-                "PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" + 
-                "PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" + 
-                "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" + 
-                "PROPERTIES (\n" + 
-                "\"replication_num\" = \"1\",\n" + 
-                "\"in_memory\" = \"false\",\n" + 
-                "\"storage_format\" = \"DEFAULT\"\n" + 
+                "  `k1` tinyint(4) NULL COMMENT \"\",\n" +
+                "  `k2` smallint(6) NULL COMMENT \"\",\n" +
+                "  `k3` int(11) NULL COMMENT \"\",\n" +
+                "  `k4` bigint(20) NULL COMMENT \"\",\n" +
+                "  `k5` decimal(9, 3) NULL COMMENT \"\",\n" +
+                "  `k6` char(5) NULL COMMENT \"\",\n" +
+                "  `k10` date NULL COMMENT \"\",\n" +
+                "  `k11` datetime NULL COMMENT \"\",\n" +
+                "  `k7` varchar(20) NULL COMMENT \"\",\n" +
+                "  `k8` double MAX NULL COMMENT \"\",\n" +
+                "  `k9` float SUM NULL COMMENT \"\"\n" +
+                ") ENGINE=OLAP\n" +
+                "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`, `k10`, 
`k11`, `k7`)\n" +
+                "COMMENT \"OLAP\"\n" +
+                "PARTITION BY RANGE(`k1`)\n" +
+                "(PARTITION p1 VALUES [(\"-128\"), (\"-64\")),\n" +
+                "PARTITION p2 VALUES [(\"-64\"), (\"0\")),\n" +
+                "PARTITION p3 VALUES [(\"0\"), (\"64\")))\n" +
+                "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n" +
+                "PROPERTIES (\n" +
+                "\"replication_num\" = \"1\",\n" +
+                "\"in_memory\" = \"false\",\n" +
+                "\"storage_format\" = \"DEFAULT\"\n" +
                 ");");
     }
 
@@ -711,18 +712,119 @@ public class QueryPlanTest {
         Assert.assertTrue(explainString.contains("PREDICATES: `join1`.`id` > 
1"));
         Assert.assertFalse(explainString.contains("PREDICATES: `join2`.`id` > 
1"));
     }
-    
+
+    @Test
+    public void testConvertCaseWhenToConstant() throws Exception {
+        // basic test: after folding, no CASE WHEN may remain in the plan
+        String caseWhenSql = "select "
+                + "case when date_format(now(),'%H%i')  < 123 then 1 else 0 end as col "
+                + "from test.test1 "
+                + "where time = case when date_format(now(),'%H%i')  < 123 "
+                + "then date_format(date_sub(now(),2),'%Y%m%d') "
+                + "else date_format(date_sub(now(),1),'%Y%m%d') end";
+        Assert.assertFalse(StringUtils.containsIgnoreCase(getCaseWhenExplainString(caseWhenSql), "CASE WHEN"));
+
+        // test 1: CASE WHEN ... THEN
+        // 1.1 several WHEN branches, all foldable to constants
+        String sql11 = "select case when false then 2 when true then 3 else 0 end as col11;";
+        assertCaseWhenExplainContains(sql11, "constant exprs: \n         3");
+
+        // 1.2.1 a WHEN expression that cannot be folded follows a foldable one
+        String sql121 = "select case when false then 2 when substr(k7,2,1) then 3 else 0 end as col121"
+                + " from test.baseall";
+        assertCaseWhenExplainContains(sql121, "OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 3 ELSE 0 END");
+
+        // 1.2.2 the first WHEN expression cannot be folded
+        String sql122 = "select case when substr(k7,2,1) then 2 when false then 3 else 0 end as col122"
+                + " from test.baseall";
+        assertCaseWhenExplainContains(sql122,
+                "OUTPUT EXPRS:CASE WHEN substr(`k7`, 2, 1) THEN 2 WHEN FALSE THEN 3 ELSE 0 END");
+
+        // 1.2.3 the THEN expression of a WHEN in the middle is returned
+        String sql123 = "select case when false then 1 when true then 2 when false then 3 else 'other' end as col123";
+        assertCaseWhenExplainContains(sql123, "constant exprs: \n         '2'");
+
+        // 1.3 no WHEN matches and there is no ELSE: NULL is returned
+        String sql13 = "select case when false then 2 end as col13";
+        assertCaseWhenExplainContains(sql13, "constant exprs: \n         NULL");
+
+        // 1.3.1 no WHEN matches: the ELSE expression is returned
+        String sql131 = "select case when false then 2 when false then 3 else 4 end as col131";
+        assertCaseWhenExplainContains(sql131, "constant exprs: \n         4");
+
+        // 1.4 nested CASE WHEN that can be folded to a constant
+        String sql14 = "select case when (case when true then true else false end) then 2 when false then 3 else 0 end"
+                + " as col";
+        assertCaseWhenExplainContains(sql14, "constant exprs: \n         2");
+
+        // 1.5 nested CASE WHEN that cannot be folded
+        String sql15 = "select case when case when substr(k7,2,1) then true else false end then 2 when false then 3"
+                + " else 0 end as col from test.baseall";
+        assertCaseWhenExplainContains(sql15, "OUTPUT EXPRS:CASE WHEN CASE WHEN substr(`k7`, 2, 1) THEN TRUE"
+                + " ELSE FALSE END THEN 2 WHEN FALSE THEN 3 ELSE 0 END");
+
+        // 1.6 a NULL WHEN expression never matches
+        String sql16 = "select case when null then 1 else 2 end as col16;";
+        assertCaseWhenExplainContains(sql16, "constant exprs: \n         2");
+
+        // test 2: CASE expr WHEN ... THEN
+        // 2.1 the CASE operand equals the first WHEN expression
+        String sql2 = "select case 1 when 1 then 'a' when 2 then 'b' else 'other' end as col2;";
+        assertCaseWhenExplainContains(sql2, "constant exprs: \n         'a'");
+
+        // 2.1.2 the first WHEN expression does not match, the second one does
+        String sql212 = "select case 'a' when 1 then 'a' when 'a' then 'b' else 'other' end as col212;";
+        assertCaseWhenExplainContains(sql212, "constant exprs: \n         'b'");
+
+        // 2.2 no WHEN matches and there is no ELSE: NULL is returned
+        String sql22 = "select case 'a' when 1 then 'a' when 'b' then 'b' end as col22;";
+        assertCaseWhenExplainContains(sql22, "constant exprs: \n         NULL");
+
+        // 2.2.2 no WHEN matches: the ELSE expression is returned
+        String sql222 = "select case 1 when 2 then 'a' when 3 then 'b' else 'other' end as col222;";
+        assertCaseWhenExplainContains(sql222, "constant exprs: \n         'other'");
+
+        // 2.3 cannot fold completely: a WHEN expression in the middle is not constant
+        String sql23 = "select case 'a' when 'b' then 'a' when substr(k7,2,1) then 2 when false then 3 else 0 end"
+                + " as col23 from test.baseall";
+        assertCaseWhenExplainContains(sql23,
+                "OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '0' THEN '3' ELSE '0' END");
+
+        // 2.3.1 cannot fold: the first WHEN expression is not constant
+        String sql231 = "select case 'a' when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end"
+                + " as col231 from test.baseall";
+        assertCaseWhenExplainContains(sql231,
+                "OUTPUT EXPRS:CASE'a' WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END");
+
+        // 2.3.2 cannot fold: the CASE operand is not constant
+        String sql232 = "select case k1 when substr(k7,2,1) then 2 when 1 then 'a' when false then 3 else 0 end"
+                + " as col232 from test.baseall";
+        assertCaseWhenExplainContains(sql232,
+                "OUTPUT EXPRS:CASE`k1` WHEN substr(`k7`, 2, 1) THEN '2' WHEN '1' THEN 'a' WHEN '0' THEN '3' ELSE '0' END");
+
+        // 3.1 a float CASE operand is left for BE to evaluate
+        String sql31 = "select case cast(100 as float) when 1 then 'a' when 2 then 'b' else 'other' end as col31;";
+        assertCaseWhenExplainContains(sql31,
+                "constant exprs: \n         CASE100.0 WHEN 1.0 THEN 'a' WHEN 2.0 THEN 'b' ELSE 'other' END");
+
+        // 4.1 a NULL CASE operand returns the ELSE expression
+        String sql41 = "select case null when 1 then 'a' when 2 then 'b' else 'other' end as col41";
+        assertCaseWhenExplainContains(sql41, "constant exprs: \n         'other'");
+
+        // 4.1.2 a NULL CASE operand without ELSE returns NULL
+        String sql412 = "select case null when 1 then 'a' when 2 then 'b' end as col412";
+        assertCaseWhenExplainContains(sql412, "constant exprs: \n         NULL");
+
+        // 4.2.1 a NULL WHEN expression never matches
+        String sql421 = "select case 'a' when null then 'a' else 'other' end as col421";
+        assertCaseWhenExplainContains(sql421, "constant exprs: \n         'other'");
+    }
+
+    // Returns UtFrameUtils.getSQLPlanOrErrorMsg output for "explain " + sql.
+    private String getCaseWhenExplainString(String sql) throws Exception {
+        return UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql);
+    }
+
+    // Asserts that the explain output of sql contains expected (ignoring case),
+    // reporting the SQL and the full explain output on failure.
+    private void assertCaseWhenExplainContains(String sql, String expected) throws Exception {
+        String explainString = getCaseWhenExplainString(sql);
+        Assert.assertTrue("explain of [" + sql + "] should contain [" + expected + "] but was:\n" + explainString,
+                StringUtils.containsIgnoreCase(explainString, expected));
+    }
+
+    // A scalar subquery in WHERE (0 < (SELECT MAX(k9) ...)) is expected to be planned as a
+    // CROSS JOIN, and the explain output must contain no PREDICATES entry at all.
+    // NOTE(review): presumably guards against join-predicate transitivity inferring a
+    // predicate from the subquery comparison — confirm against the original issue.
     @Test
     public void testJoinPredicateTransitivityWithSubqueryInWhereClause() 
throws Exception {
         connectContext.setDatabase("default_cluster:test");
-        String sql = "SELECT *\n" + 
+        String sql = "SELECT *\n" +
                 "FROM test.pushdown_test\n" +
                 "WHERE 0 < (\n" +
-                "    SELECT MAX(k9)\n" + 
+                "    SELECT MAX(k9)\n" +
                 "    FROM test.pushdown_test);";
         String explainString = 
UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "explain " + sql);
         Assert.assertTrue(explainString.contains("PLAN FRAGMENT"));
         Assert.assertTrue(explainString.contains("CROSS JOIN"));
         Assert.assertTrue(!explainString.contains("PREDICATES"));
     }
+
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to