This is an automated email from the ASF dual-hosted git repository.

sunnianjun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bca9739289 Support mysql case when then statement parse (#22540)
2bca9739289 is described below

commit 2bca9739289b3ae54a95086518004b4b05dd41a7
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Nov 30 20:41:24 2022 +0800

    Support mysql case when then statement parse (#22540)
---
 .../statement/impl/MySQLStatementSQLVisitor.java   |  25 +++-
 .../SQLNodeConverterEngineParameterizedTest.java   |   2 +-
 .../segment/expression/ExpressionAssert.java       |   6 +-
 .../main/resources/case/dml/select-expression.xml  | 129 +++++++++++++--------
 .../sql/supported/dml/select-expression.xml        |   8 +-
 5 files changed, 107 insertions(+), 63 deletions(-)

diff --git a/sql-parser/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/impl/MySQLStatementSQLVisitor.java b/sql-parser/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/impl/MySQLStatementSQLVisitor.java
index 341ee708eec..812ae8e85a4 100644
--- a/sql-parser/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/impl/MySQLStatementSQLVisitor.java
+++ b/sql-parser/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/impl/MySQLStatementSQLVisitor.java
@@ -36,6 +36,8 @@ import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BitValu
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BlobValueContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BooleanLiteralsContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.BooleanPrimaryContext;
+import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CaseExpressionContext;
+import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CaseWhenContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CastFunctionContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CharFunctionContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.MySQLStatementParser.CollateClauseContext;
@@ -154,6 +156,7 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.OnDupl
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BetweenExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CaseWhenExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.CollateExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExistsSubqueryExpression;
 import 
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
@@ -991,9 +994,7 @@ public abstract class MySQLStatementSQLVisitor extends MySQLStatementBaseVisitor
     
     private ASTNode visitRemainSimpleExpr(final SimpleExprContext ctx) {
         if (null != ctx.caseExpression()) {
-            visit(ctx.caseExpression());
-            String text = ctx.start.getInputStream().getText(new 
Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
-            return new CommonExpressionSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), text);
+            return visit(ctx.caseExpression());
         }
         if (null != ctx.BINARY()) {
             return visit(ctx.simpleExpr(0));
@@ -1008,6 +1009,19 @@ public abstract class MySQLStatementSQLVisitor extends MySQLStatementBaseVisitor
         return new CommonExpressionSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), text);
     }
     
+    @Override
+    public ASTNode visitCaseExpression(final CaseExpressionContext ctx) {
+        Collection<ExpressionSegment> whenExprs = new LinkedList<>();
+        Collection<ExpressionSegment> thenExprs = new LinkedList<>();
+        for (CaseWhenContext each : ctx.caseWhen()) {
+            whenExprs.add((ExpressionSegment) visit(each.expr(0)));
+            thenExprs.add((ExpressionSegment) visit(each.expr(1)));
+        }
+        ExpressionSegment caseExpr = null == ctx.simpleExpr() ? null : 
(ExpressionSegment) visit(ctx.simpleExpr());
+        ExpressionSegment elseExpr = null == ctx.caseElse() ? null : 
(ExpressionSegment) visit(ctx.caseElse().expr());
+        return new CaseWhenExpression(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), caseExpr, whenExprs, thenExprs, elseExpr);
+    }
+    
     @Override
     public final ASTNode visitMatchExpression(final MatchExpressionContext 
ctx) {
         visit(ctx.expr());
@@ -1503,6 +1517,11 @@ public abstract class MySQLStatementSQLVisitor extends MySQLStatementBaseVisitor
             result.setAlias(alias);
             return projection;
         }
+        if (projection instanceof CaseWhenExpression) {
+            ExpressionProjectionSegment result = new 
ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), 
getOriginalText(ctx.expr()), (CaseWhenExpression) projection);
+            result.setAlias(alias);
+            return result;
+        }
         LiteralExpressionSegment column = (LiteralExpressionSegment) 
projection;
         ExpressionProjectionSegment result = null == alias
                 ? new ExpressionProjectionSegment(column.getStartIndex(), 
column.getStopIndex(), String.valueOf(column.getLiterals()), column)
diff --git a/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java b/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
index b4733cc0cf3..7a5397331c2 100644
--- a/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
+++ b/test/optimize/src/test/java/org/apache/shardingsphere/infra/federation/converter/parameterized/engine/SQLNodeConverterEngineParameterizedTest.java
@@ -138,7 +138,7 @@ public final class SQLNodeConverterEngineParameterizedTest {
         SUPPORTED_SQL_CASE_IDS.add("select_minus");
         SUPPORTED_SQL_CASE_IDS.add("select_minus_order_by");
         SUPPORTED_SQL_CASE_IDS.add("select_minus_order_by_limit");
-        
SUPPORTED_SQL_CASE_IDS.add("select_projections_with_only_expr_for_postgres");
+        SUPPORTED_SQL_CASE_IDS.add("select_projections_with_only_expr");
         SUPPORTED_SQL_CASE_IDS.add("select_natural_join");
         SUPPORTED_SQL_CASE_IDS.add("select_natural_inner_join");
         SUPPORTED_SQL_CASE_IDS.add("select_natural_left_join");
diff --git a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
index 49a9fe1c6a2..53d63bfa78f 100644
--- a/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
+++ b/test/parser/src/main/java/org/apache/shardingsphere/test/sql/parser/internal/asserts/segment/expression/ExpressionAssert.java
@@ -46,7 +46,6 @@ import org.apache.shardingsphere.test.sql.parser.internal.asserts.segment.generi
 import 
org.apache.shardingsphere.test.sql.parser.internal.asserts.segment.owner.OwnerAssert;
 import 
org.apache.shardingsphere.test.sql.parser.internal.asserts.segment.projection.ProjectionAssert;
 import 
org.apache.shardingsphere.test.sql.parser.internal.asserts.statement.dml.impl.SelectStatementAssert;
-import 
org.apache.shardingsphere.test.sql.parser.internal.cases.sql.type.SQLCaseType;
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.expr.ExpectedBetweenExpression;
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.expr.ExpectedBinaryOperationExpression;
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.expr.ExpectedCaseWhenExpression;
@@ -61,6 +60,7 @@ import org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segm
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.expr.simple.ExpectedParameterMarkerExpression;
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.expr.simple.ExpectedSubquery;
 import 
org.apache.shardingsphere.test.sql.parser.internal.cases.parser.jaxb.segment.impl.function.ExpectedFunction;
+import 
org.apache.shardingsphere.test.sql.parser.internal.cases.sql.type.SQLCaseType;
 
 import java.util.Iterator;
 
@@ -335,8 +335,8 @@ public final class ExpressionAssert {
      * @param expected expected case when expression
      */
     public static void assertCaseWhenExpression(final SQLCaseAssertContext 
assertContext, final CaseWhenExpression actual, final 
ExpectedCaseWhenExpression expected) {
-        assertThat(assertContext.getText("When list size is not same!"), 
actual.getWhenExprs().size(), is(expected.getWhenExprs().size()));
-        assertThat(assertContext.getText("Then list size is not same!"), 
actual.getThenExprs().size(), is(expected.getThenExprs().size()));
+        assertThat(assertContext.getText("When exprs size is not same!"), 
actual.getWhenExprs().size(), is(expected.getWhenExprs().size()));
+        assertThat(assertContext.getText("Then exprs size is not same!"), 
actual.getThenExprs().size(), is(expected.getThenExprs().size()));
         Iterator<ExpectedExpression> whenExprsIterator = 
expected.getWhenExprs().iterator();
         for (ExpressionSegment each : actual.getWhenExprs()) {
             assertExpression(assertContext, each, whenExprsIterator.next());
diff --git a/test/parser/src/main/resources/case/dml/select-expression.xml b/test/parser/src/main/resources/case/dml/select-expression.xml
index 92a9899ccbb..f0521221abf 100644
--- a/test/parser/src/main/resources/case/dml/select-expression.xml
+++ b/test/parser/src/main/resources/case/dml/select-expression.xml
@@ -129,50 +129,87 @@
             <column-projection start-index="11" stop-index="30" name="item_id" 
alias="item_id">
                 <owner start-index="11" stop-index="11" name="o" />
             </column-projection>
-            <expression-projection text="case when t.status = 'init' then 
'已启用' when t.status = 'failed' then '已停用' end" start-index="33" 
stop-index="124" alias="stateName" />
+            <expression-projection text="case when t.status = 'init' then 
'已启用' when t.status = 'failed' then '已停用' end" start-index="32" 
stop-index="124" alias="stateName">
+                <expr>
+                    <case-when-expression>
+                        <when-exprs>
+                            <binary-operation-expression start-index="43" 
stop-index="59">
+                                <left>
+                                    <column name="status" start-index="43" 
stop-index="50">
+                                        <owner name="t" start-index="43" 
stop-index="43" />
+                                    </column>
+                                </left>
+                                <operator>=</operator>
+                                <right>
+                                    <literal-expression value="init" 
start-index="54" stop-index="59" />
+                                </right>
+                            </binary-operation-expression>
+                        </when-exprs>
+                        <when-exprs>
+                            <binary-operation-expression start-index="77" 
stop-index="95">
+                                <left>
+                                    <column name="status" start-index="77" 
stop-index="84">
+                                        <owner name="t" start-index="77" 
stop-index="77" />
+                                    </column>
+                                </left>
+                                <operator>=</operator>
+                                <right>
+                                    <literal-expression value="failed" 
start-index="88" stop-index="95" />
+                                </right>
+                            </binary-operation-expression>
+                        </when-exprs>
+                        <then-exprs>
+                            <literal-expression value="已启用" start-index="66" 
stop-index="70" />
+                        </then-exprs>
+                        <then-exprs>
+                            <literal-expression value="已停用" start-index="102" 
stop-index="106" />
+                        </then-exprs>
+                    </case-when-expression>
+                </expr>
+            </expression-projection>
         </projections>
         <from>
             <join-table join-type="LEFT">
                 <left>
-                    <simple-table start-index="135" stop-index="143" 
name="t_order" alias="t" />
+                    <simple-table start-index="131" stop-index="139" 
name="t_order" alias="t" />
                 </left>
                 <right>
-                    <simple-table start-index="155" stop-index="171" 
name="t_order_item" alias="o" />
+                    <simple-table start-index="151" stop-index="167" 
name="t_order_item" alias="o" />
                 </right>
                 <on-condition>
-                    <binary-operation-expression start-index="176" 
stop-index="197">
+                    <binary-operation-expression start-index="172" 
stop-index="193">
                         <left>
-                            <column name="order_id" start-index="176" 
stop-index="185">
-                                <owner name="o" start-index="176" 
stop-index="176" />
+                            <column name="order_id" start-index="172" 
stop-index="181">
+                                <owner name="o" start-index="172" 
stop-index="172" />
                             </column>
                         </left>
                         <operator>=</operator>
                         <right>
-                            <column name="order_id" start-index="188" 
stop-index="197">
-                                <owner name="t" start-index="188" 
stop-index="188" />
+                            <column name="order_id" start-index="184" 
stop-index="193">
+                                <owner name="t" start-index="184" 
stop-index="184" />
                             </column>
                         </right>
                     </binary-operation-expression>
                 </on-condition>
             </join-table>
         </from>
-        <where start-index="199" stop-index="219">
+        <where start-index="195" stop-index="215">
             <expr>
-                <binary-operation-expression start-index="205" 
stop-index="219">
+                <binary-operation-expression start-index="201" 
stop-index="215">
                     <left>
-                        <column name="order_id" start-index="205" 
stop-index="214">
-                            <owner name="t" start-index="205" stop-index="205" 
/>
+                        <column name="order_id" start-index="201" 
stop-index="210">
+                            <owner name="t" start-index="201" stop-index="201" 
/>
                         </column>
                     </left>
                     <operator>=</operator>
                     <right>
-                        <literal-expression value="1000" start-index="216" 
stop-index="219" />
+                        <literal-expression value="1000" start-index="212" 
stop-index="215" />
                     </right>
                 </binary-operation-expression>
             </expr>
         </where>
-        <limit start-index="221" stop-index="227">
-            <row-count value="1" start-index="227" stop-index="227" />
+        <limit start-index="217" stop-index="223">
+            <row-count value="1" start-index="223" stop-index="223" />
         </limit>
     </select>
 
@@ -1703,15 +1740,36 @@
     </select>
 
     <select sql-case-id="select_where_with_simple_expr_with_case" 
parameters="1,'true','false'">
-        <from start-index="14" stop-index="20">
-            <simple-table name="t_order" start-index="14" stop-index="20" />
-        </from>
         <projections start-index="7" stop-index="7">
             <shorthand-projection start-index="7" stop-index="7" />
         </projections>
+        <from start-index="14" stop-index="20">
+            <simple-table name="t_order" start-index="14" stop-index="20" />
+        </from>
         <where start-index="22" stop-index="67" literal-stop-index="78">
             <expr>
-                <common-expression text="CASE WHEN order_id &gt; ? THEN ? ELSE 
? END" literal-text="CASE WHEN order_id > 1 THEN 'true' ELSE 'false' END" 
start-index="28" stop-index="67" literal-stop-index="78" />
+                <case-when-expression>
+                    <when-exprs>
+                        <binary-operation-expression start-index="38" 
stop-index="49">
+                            <left>
+                                <column name="order_id" start-index="38" 
stop-index="45" />
+                            </left>
+                            <operator>&gt;</operator>
+                            <right>
+                                <literal-expression value="1" start-index="49" 
stop-index="49" />
+                                <parameter-marker-expression 
parameter-index="0" start-index="49" stop-index="49" />
+                            </right>
+                        </binary-operation-expression>
+                    </when-exprs>
+                    <then-exprs>
+                        <literal-expression value="true" start-index="56" 
stop-index="61" />
+                        <parameter-marker-expression parameter-index="1" 
start-index="56" stop-index="56" />
+                    </then-exprs>
+                    <else-expr>
+                        <literal-expression value="false" start-index="68" 
stop-index="74" />
+                        <parameter-marker-expression parameter-index="2" 
start-index="63" stop-index="63" />
+                    </else-expr>
+                </case-when-expression>
             </expr>
         </where>
     </select>
@@ -1861,37 +1919,6 @@
     </select>
 
     <select sql-case-id="select_projections_with_expr">
-        <projections start-index="7" stop-index="58">
-            <expression-projection start-index="7" stop-index="11" 
text="10+20">
-                <expr>
-                    <binary-operation-expression start-index="7" 
stop-index="11">
-                        <left>
-                            <literal-expression value="10" start-index="7" 
stop-index="8" />
-                        </left>
-                        <right>
-                            <literal-expression value="20" start-index="10" 
stop-index="11" />
-                        </right>
-                        <operator>+</operator>
-                    </binary-operation-expression>
-                </expr>
-            </expression-projection>
-            <expression-projection start-index="13" stop-index="56" text="CASE 
order_id WHEN 1 THEN '11' ELSE '00' END">
-                <expr>
-                    <common-expression literal-text="CASE order_id WHEN 1 THEN 
'11' ELSE '00' END" start-index="13" stop-index="56" />
-                </expr>
-            </expression-projection>
-            <expression-projection start-index="58" stop-index="58" text="1">
-                <expr>
-                    <literal-expression value="1" start-index="58" 
stop-index="58" />
-                </expr>
-            </expression-projection>
-        </projections>
-        <from>
-            <simple-table name="t_order" start-index="65" stop-index="71" />
-        </from>
-    </select>
-
-    <select sql-case-id="select_projections_with_expr_for_postgres">
         <projections start-index="7" stop-index="58">
             <expression-projection start-index="7" stop-index="11" 
text="10+20" />
             <expression-projection start-index="13" stop-index="56" text="CASE 
order_id WHEN 1 THEN '11' ELSE '00' END">
@@ -1923,7 +1950,7 @@
         </from>
     </select>
 
-    <select sql-case-id="select_projections_with_only_expr_for_postgres">
+    <select sql-case-id="select_projections_with_only_expr">
         <projections start-index="7" stop-index="50">
             <expression-projection start-index="7" stop-index="50" text="CASE 
order_id WHEN 1 THEN '11' ELSE '00' END">
                 <expr>
diff --git a/test/parser/src/main/resources/sql/supported/dml/select-expression.xml b/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
index 7d0390eed1a..e613182edfb 100644
--- a/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
+++ b/test/parser/src/main/resources/sql/supported/dml/select-expression.xml
@@ -21,8 +21,7 @@
     <sql-case id="select_with_expression_for_postgresql" value="SELECT 
o.order_id + 1 * 2 as exp FROM t_order AS o ORDER BY o.order_id" 
db-types="PostgreSQL,openGauss" />
     <sql-case id="select_with_date_function" value="SELECT 
DATE(i.creation_date) AS creation_date FROM `t_order_item` AS i ORDER BY 
DATE(i.creation_date) DESC" db-types="MySQL" />
     <sql-case id="select_with_regexp" value="SELECT * FROM t_order_item t 
WHERE t.status REGEXP ? AND t.item_id IN (?, ?)" db-types="MySQL" />
-    <sql-case id="select_with_case_expression" value="select t.*,o.item_id as 
item_id,(case when t.status = 'init' then '已启用' when t.status = 'failed' then 
'已停用' end) as stateName
-    from t_order t left join t_order_item as o on o.order_id =t.order_id where 
t.order_id=1000 limit 1" db-types="MySQL,H2" />
+    <sql-case id="select_with_case_expression" value="select t.*,o.item_id as 
item_id,(case when t.status = 'init' then '已启用' when t.status = 'failed' then 
'已停用' end) as stateName from t_order t left join t_order_item as o on 
o.order_id =t.order_id where t.order_id=1000 limit 1" db-types="MySQL,H2" />
     <sql-case id="select_where_with_expr_with_or" value="SELECT * FROM t_order 
WHERE t_order.order_id = ? OR ? = t_order.order_id" db-types="MySQL" />
     <sql-case id="select_where_with_expr_with_or_sign" value="SELECT * FROM 
t_order WHERE t_order.order_id = ? || ? = t_order.order_id" db-types="MySQL" />
     <sql-case id="select_where_with_expr_with_xor" value="SELECT * FROM 
t_order WHERE t_order.order_id = ? XOR ? = t_order.order_id" db-types="MySQL" />
@@ -82,9 +81,8 @@
     <sql-case id="select_where_with_expr_with_not_with_order_by" value="SELECT 
last_name, job_id, salary, department_id FROM employees WHERE NOT (job_id = 
'PU_CLERK' AND department_id = 30) ORDER BY last_name" db-types="Oracle" />
     <sql-case id="select_where_with_subquery" value="SELECT last_name, 
department_id FROM employees WHERE department_id = (SELECT department_id FROM 
employees WHERE last_name = 'Lorentz') ORDER BY last_name, department_id" 
db-types="Oracle" />
     <sql-case id="select_where_with_expr_with_not_in" value="SELECT * FROM 
employees WHERE department_id NOT IN (SELECT department_id FROM departments 
WHERE location_id = 1700) ORDER BY last_name" db-types="Oracle" />
-    <sql-case id="select_projections_with_expr" value="SELECT 10+20,CASE 
order_id WHEN 1 THEN '11' ELSE '00' END,1 FROM t_order" db-types="MySQL" />
-    <sql-case id="select_projections_with_expr_for_postgres" value="SELECT 
10+20,CASE order_id WHEN 1 THEN '11' ELSE '00' END,1 FROM t_order" 
db-types="PostgreSQL,openGauss" />
-    <sql-case id="select_projections_with_only_expr_for_postgres" 
value="SELECT CASE order_id WHEN 1 THEN '11' ELSE '00' END FROM t_order" 
db-types="PostgreSQL,openGauss" />
+    <sql-case id="select_projections_with_expr" value="SELECT 10+20,CASE 
order_id WHEN 1 THEN '11' ELSE '00' END,1 FROM t_order" 
db-types="MySQL,PostgreSQL,openGauss" />
+    <sql-case id="select_projections_with_only_expr" value="SELECT CASE 
order_id WHEN 1 THEN '11' ELSE '00' END FROM t_order" 
db-types="MySQL,PostgreSQL,openGauss" />
     <sql-case id="select_with_amp" value="select 1 &amp; 1" 
db-types="PostgreSQL,openGauss" />
     <sql-case id="select_with_vertical_bar" value="select 1 | 1" 
db-types="PostgreSQL,openGauss" />
     <sql-case id="select_with_abs_function" value="SELECT ABS(1) FROM t_order 
WHERE ABS(1) &gt; 1 GROUP BY ABS(1) ORDER BY ABS(1)" db-types="Oracle" />

Reply via email to