This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5533c81e3453 [SPARK-48355][SQL] Support for CASE statement
5533c81e3453 is described below
commit 5533c81e34534d43ae90fc2ce5ac1d174d4e8289
Author: Dušan Tišma <[email protected]>
AuthorDate: Fri Sep 13 15:01:09 2024 +0200
[SPARK-48355][SQL] Support for CASE statement
### What changes were proposed in this pull request?
Add support for [case
statements](https://docs.google.com/document/d/1cpSuR3KxRuTSJ4ZMQ73FJ4_-hjouNNU2zfI4vri6yhs/edit#heading=h.ofijhkunigv)
to SQL scripting. There are two types of case statements — simple and searched
(examples below). Proposed changes are:
- Add `caseStatement` grammar rule to SqlBaseParser.g4
- Add visit case statement methods to `AstBuilder`
- Add `SearchedCaseStatement` and `SearchedCaseStatementExec` classes, to
enable them to be run in SQL scripts.
The reason only searched case nodes are added is that, in the current
implementation, a simple case is parsed into a searched case, by creating
internal `EqualTo` expressions to compare the main case expression to the
expressions in the when clauses. This approach is similar to the existing case
**expressions**, which are parsed in the same way. The problem with this
approach is that the main expression is unnecessarily evaluated N times, where
N is the number of WHEN clauses, which c [...]
Simple case compares one expression (case variable) to others, until an
equal one is found. Else clause is optional.
```
BEGIN
CASE 1
WHEN 1 THEN
SELECT 1;
WHEN 2 THEN
SELECT 2;
ELSE
SELECT 3;
END CASE;
END
```
Searched case evaluates boolean expressions. Else clause is optional.
```
BEGIN
CASE
WHEN 1 = 1 THEN
SELECT 1;
WHEN 2 IN (1,2,3) THEN
SELECT 2;
ELSE
SELECT 3;
END CASE;
END
```
### Why are the changes needed?
Case statements are currently not implemented in SQL scripting.
### Does this PR introduce _any_ user-facing change?
Yes, users will now be able to use case statements in their SQL scripts.
### How was this patch tested?
Tests for both simple and searched case statements are added to
SqlScriptingParserSuite, SqlScriptingExecutionNodeSuite and
SqlScriptingInterpreterSuite.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47672 from dusantism-db/sql-scripting-case-statement.
Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 8 +
.../spark/sql/catalyst/parser/AstBuilder.scala | 48 ++-
.../parser/SqlScriptingLogicalOperators.scala | 14 +
.../catalyst/parser/SqlScriptingParserSuite.scala | 297 +++++++++++++++-
.../sql/scripting/SqlScriptingExecutionNode.scala | 72 ++++
.../sql/scripting/SqlScriptingInterpreter.scala | 13 +-
.../scripting/SqlScriptingExecutionNodeSuite.scala | 93 +++++
.../scripting/SqlScriptingInterpreterSuite.scala | 379 ++++++++++++++++++++-
8 files changed, 920 insertions(+), 4 deletions(-)
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 42f0094de351..73d5cb55295a 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -64,6 +64,7 @@ compoundStatement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
| ifElseStatement
+ | caseStatement
| whileStatement
| repeatStatement
| leaveStatement
@@ -98,6 +99,13 @@ iterateStatement
: ITERATE multipartIdentifier
;
+caseStatement
+ : CASE (WHEN conditions+=booleanExpression THEN
conditionalBodies+=compoundBody)+
+ (ELSE elseBody=compoundBody)? END CASE
#searchedCaseStatement
+ | CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN
conditionalBodies+=compoundBody)+
+ (ELSE elseBody=compoundBody)? END CASE
#simpleCaseStatement
+ ;
+
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 924b5c2cfeb1..9620ce13d92e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -261,6 +261,52 @@ class AstBuilder extends DataTypeAstBuilder
WhileStatement(condition, body, Some(labelText))
}
+ override def visitSearchedCaseStatement(ctx: SearchedCaseStatementContext):
CaseStatement = {
+ val conditions = ctx.conditions.asScala.toList.map(boolExpr =>
withOrigin(boolExpr) {
+ SingleStatement(
+ Project(
+ Seq(Alias(expression(boolExpr), "condition")()),
+ OneRowRelation()))
+ })
+ val conditionalBodies =
+ ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
+
+ if (conditions.length != conditionalBodies.length) {
+ throw SparkException.internalError(
+ s"Mismatched number of conditions ${conditions.length} and condition
bodies" +
+ s" ${conditionalBodies.length} in case statement")
+ }
+
+ CaseStatement(
+ conditions = conditions,
+ conditionalBodies = conditionalBodies,
+ elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
+ }
+
+ override def visitSimpleCaseStatement(ctx: SimpleCaseStatementContext):
CaseStatement = {
+ // uses EqualTo to compare the case variable(the main case expression)
+ // to the WHEN clause expressions
+ val conditions = ctx.conditionExpressions.asScala.toList.map(expr =>
withOrigin(expr) {
+ SingleStatement(
+ Project(
+ Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)),
"condition")()),
+ OneRowRelation()))
+ })
+ val conditionalBodies =
+ ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body))
+
+ if (conditions.length != conditionalBodies.length) {
+ throw SparkException.internalError(
+ s"Mismatched number of conditions ${conditions.length} and condition
bodies" +
+ s" ${conditionalBodies.length} in case statement")
+ }
+
+ CaseStatement(
+ conditions = conditions,
+ conditionalBodies = conditionalBodies,
+ elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body)))
+ }
+
override def visitRepeatStatement(ctx: RepeatStatementContext):
RepeatStatement = {
val labelText = generateLabelText(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()
@@ -292,7 +338,7 @@ class AstBuilder extends DataTypeAstBuilder
case c: RepeatStatementContext
if Option(c.beginLabel()).isDefined &&
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
- => true
+ => true
case _ => false
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index 5e7e8b0b4fc9..ed40a5fd734b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -124,3 +124,17 @@ case class LeaveStatement(label: String) extends
CompoundPlanStatement
* @param label Label of the loop to iterate.
*/
case class IterateStatement(label: String) extends CompoundPlanStatement
+
+/**
+ * Logical operator for CASE statement.
+ * @param conditions Collection of conditions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of bodies that have a corresponding
condition,
+ * in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met,
i.e. ELSE branch.
+ */
+case class CaseStatement(
+ conditions: Seq[SingleStatement],
+ conditionalBodies: Seq[CompoundBody],
+ elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+ assert(conditions.length == conditionalBodies.length)
+}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index bf527b9c3bd7..24ad32c5300b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression,
In, Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.CreateVariable
+import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, Project}
import org.apache.spark.sql.exceptions.SqlScriptingException
class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
@@ -1111,6 +1112,287 @@ class SqlScriptingParserSuite extends SparkFunSuite
with SQLHelper {
}
+ test("searched case statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = 1 THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions.head.getText == "1 = 1")
+ }
+
+ test("searched case statement - multi when") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 IN (1,2,3) THEN
+ | SELECT 1;
+ | WHEN (SELECT * FROM t) THEN
+ | SELECT * FROM b;
+ | WHEN 1 = 1 THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 3)
+ assert(caseStmt.conditionalBodies.length == 3)
+ assert(caseStmt.elseBody.isEmpty)
+
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions.head.getText == "1 IN (1,2,3)")
+
+
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
+
+ assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions(1).getText == "(SELECT * FROM t)")
+
+
assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT * FROM b")
+
+ assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions(2).getText == "1 = 1")
+
+
assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 42")
+ }
+
+ test("searched case statement with else") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = 1 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.elseBody.isDefined)
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions.head.getText == "1 = 1")
+
+ assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 43")
+ }
+
+ test("searched case statement nested") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = 1 THEN
+ | CASE
+ | WHEN 2 = 1 THEN
+ | SELECT 41;
+ | ELSE
+ | SELECT 42;
+ | END CASE;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditionalBodies.length == 1)
+ assert(caseStmt.elseBody.isEmpty)
+
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.conditions.head.getText == "1 = 1")
+
+
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+ val nestedCaseStmt =
+
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+
+ assert(nestedCaseStmt.conditions.length == 1)
+ assert(nestedCaseStmt.conditionalBodies.length == 1)
+ assert(nestedCaseStmt.elseBody.isDefined)
+
+ assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(nestedCaseStmt.conditions.head.getText == "2 = 1")
+
+
assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 41")
+
+
assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 42")
+ }
+
+ test("simple case statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 1 THEN
+ | SELECT 1;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ ==
Literal(1), _ == Literal(1))
+ }
+
+
+ test("simple case statement - multi when") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 1 THEN
+ | SELECT 1;
+ | WHEN (SELECT 2) THEN
+ | SELECT * FROM b;
+ | WHEN 3 IN (1,2,3) THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 3)
+ assert(caseStmt.conditionalBodies.length == 3)
+ assert(caseStmt.elseBody.isEmpty)
+
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ ==
Literal(1), _ == Literal(1))
+
+
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
+
+ assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(
+ caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
+
+
assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT * FROM b")
+
+ assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(
+ caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
+
+
assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
+
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 42")
+ }
+
+ test("simple case statement with else") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 1 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.elseBody.isDefined)
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ ==
Literal(1), _ == Literal(1))
+
+ assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+ assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 43")
+ }
+
+ test("simple case statement nested") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | CASE (SELECT 1)
+ | WHEN 1 THEN
+ | CASE 2
+ | WHEN 2 THEN
+ | SELECT 41;
+ | ELSE
+ | SELECT 42;
+ | END CASE;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[CaseStatement])
+
+ val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(caseStmt.conditions.length == 1)
+ assert(caseStmt.conditionalBodies.length == 1)
+ assert(caseStmt.elseBody.isEmpty)
+
+ assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(
+ caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ ==
Literal(1))
+
+
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+ val nestedCaseStmt =
+
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+
+ assert(nestedCaseStmt.conditions.length == 1)
+ assert(nestedCaseStmt.conditionalBodies.length == 1)
+ assert(nestedCaseStmt.elseBody.isDefined)
+
+ assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
+ checkSimpleCaseStatementCondition(
+ nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
+
+
assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 41")
+
+
assert(nestedCaseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedCaseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 42")
+ }
+
// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
@@ -1119,4 +1401,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
.replace("END", "")
.trim
}
+
+ private def checkSimpleCaseStatementCondition(
+ conditionStatement: SingleStatement,
+ predicateLeft: Expression => Boolean,
+ predicateRight: Expression => Boolean): Unit = {
+ assert(conditionStatement.parsedPlan.isInstanceOf[Project])
+ val project = conditionStatement.parsedPlan.asInstanceOf[Project]
+ assert(project.projectList.head.isInstanceOf[Alias])
+
assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
+ val equalTo =
project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
+ assert(predicateLeft(equalTo.left))
+ assert(predicateRight(equalTo.right))
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index cae797614314..af9fd5464277 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -405,6 +405,78 @@ class WhileStatementExec(
}
}
+/**
+ * Executable node for CaseStatement.
+ * @param conditions Collection of executable conditions which correspond to
WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a
corresponding condition,
+ * in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met,
i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ */
+class CaseStatementExec(
+ conditions: Seq[SingleStatementExec],
+ conditionalBodies: Seq[CompoundBodyExec],
+ elseBody: Option[CompoundBodyExec],
+ session: SparkSession) extends NonLeafStatementExec {
+ private object CaseState extends Enumeration {
+ val Condition, Body = Value
+ }
+
+ private var state = CaseState.Condition
+ private var curr: Option[CompoundStatementExec] = Some(conditions.head)
+
+ private var clauseIdx: Int = 0
+ private val conditionsCount = conditions.length
+
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = curr.nonEmpty
+
+ override def next(): CompoundStatementExec = state match {
+ case CaseState.Condition =>
+ val condition = curr.get.asInstanceOf[SingleStatementExec]
+ if (evaluateBooleanCondition(session, condition)) {
+ state = CaseState.Body
+ curr = Some(conditionalBodies(clauseIdx))
+ } else {
+ clauseIdx += 1
+ if (clauseIdx < conditionsCount) {
+ // There are WHEN clauses remaining.
+ state = CaseState.Condition
+ curr = Some(conditions(clauseIdx))
+ } else if (elseBody.isDefined) {
+ // ELSE clause exists.
+ state = CaseState.Body
+ curr = Some(elseBody.get)
+ } else {
+ // No remaining clauses.
+ curr = None
+ }
+ }
+ condition
+ case CaseState.Body =>
+ assert(curr.get.isInstanceOf[CompoundBodyExec])
+ val currBody = curr.get.asInstanceOf[CompoundBodyExec]
+ val retStmt = currBody.getTreeIterator.next()
+ if (!currBody.getTreeIterator.hasNext) {
+ curr = None
+ }
+ retStmt
+ }
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ state = CaseState.Condition
+ curr = Some(conditions.head)
+ clauseIdx = 0
+ conditions.foreach(c => c.reset())
+ conditionalBodies.foreach(b => b.reset())
+ elseBody.foreach(b => b.reset())
+ }
+}
+
/**
* Executable node for RepeatStatement.
* @param condition Executable node for the condition - evaluates to a row
with a single boolean
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 865b33999655..917b4d6f45ee 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable,
DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
@@ -95,6 +95,17 @@ case class SqlScriptingInterpreter() {
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec,
session)
+ case CaseStatement(conditions, conditionalBodies, elseBody) =>
+ val conditionsExec = conditions.map(condition =>
+ // todo: what to put here for isInternal, in case of simple case
statement
+ new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false))
+ val conditionalBodiesExec = conditionalBodies.map(body =>
+ transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec])
+ val unconditionalBodiesExec = elseBody.map(body =>
+ transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec])
+ new CaseStatementExec(
+ conditionsExec, conditionalBodiesExec, unconditionalBodiesExec,
session)
+
case WhileStatement(condition, body, label) =>
val conditionExec =
new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 4b72ca8ecaa9..83d8191d01ec 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -576,4 +576,97 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
"body1", "lbl", "con1",
"body1", "lbl", "con1"))
}
+
+ test("searched case - enter first WHEN clause") {
+ val iter = new CompoundBodyExec(Seq(
+ new CaseStatementExec(
+ conditions = Seq(
+ TestIfElseCondition(condVal = true, description = "con1"),
+ TestIfElseCondition(condVal = false, description = "con2")
+ ),
+ conditionalBodies = Seq(
+ new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+ new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
+ session = spark
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "body1"))
+ }
+
+ test("searched case - enter body of the ELSE clause") {
+ val iter = new CompoundBodyExec(Seq(
+ new CaseStatementExec(
+ conditions = Seq(
+ TestIfElseCondition(condVal = false, description = "con1")
+ ),
+ conditionalBodies = Seq(
+ new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ ),
+ elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body2")))),
+ session = spark
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "body2"))
+ }
+
+ test("searched case - enter second WHEN clause") {
+ val iter = new CompoundBodyExec(Seq(
+ new CaseStatementExec(
+ conditions = Seq(
+ TestIfElseCondition(condVal = false, description = "con1"),
+ TestIfElseCondition(condVal = true, description = "con2")
+ ),
+ conditionalBodies = Seq(
+ new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+ new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = Some(new CompoundBodyExec(Seq(TestLeafStatement("body3")))),
+ session = spark
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "con2", "body2"))
+ }
+
+ test("searched case - without else (successful check)") {
+ val iter = new CompoundBodyExec(Seq(
+ new CaseStatementExec(
+ conditions = Seq(
+ TestIfElseCondition(condVal = false, description = "con1"),
+ TestIfElseCondition(condVal = true, description = "con2")
+ ),
+ conditionalBodies = Seq(
+ new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+ new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = None,
+ session = spark
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "con2", "body2"))
+ }
+
+ test("searched case - without else (unsuccessful checks)") {
+ val iter = new CompoundBodyExec(Seq(
+ new CaseStatementExec(
+ conditions = Seq(
+ TestIfElseCondition(condVal = false, description = "con1"),
+ TestIfElseCondition(condVal = false, description = "con2")
+ ),
+ conditionalBodies = Seq(
+ new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+ new CompoundBodyExec(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = None,
+ session = spark
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "con2"))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 8d9cd1d8c780..4851faf897a0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.scripting
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkNumberFormatException}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest,
Row}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.exceptions.SqlScriptingException
@@ -368,6 +368,383 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
}
}
+ test("searched case") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = 1 THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("searched case nested") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1=1 THEN
+ | CASE
+ | WHEN 2=1 THEN
+ | SELECT 41;
+ | ELSE
+ | SELECT 42;
+ | END CASE;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("searched case second case") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = (SELECT 2) THEN
+ | SELECT 1;
+ | WHEN 2 = 2 THEN
+ | SELECT 42;
+ | WHEN (SELECT * FROM t) THEN
+ | SELECT * FROM b;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("searched case going in else") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 2 = 1 THEN
+ | SELECT 1;
+ | WHEN 3 IN (1,2) THEN
+ | SELECT 2;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(43)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("searched case with count") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ |INSERT INTO t VALUES (1, 'a', 1.0);
+ |INSERT INTO t VALUES (1, 'a', 1.0);
+ |CASE
+ | WHEN (SELECT COUNT(*) > 2 FROM t) THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
+ test("searched case else with count") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ | INSERT INTO t VALUES (1, 'a', 1.0);
+ | INSERT INTO t VALUES (1, 'a', 1.0);
+ | CASE
+ | WHEN (SELECT COUNT(*) > 2 FROM t) THEN
+ | SELECT 42;
+ | WHEN (SELECT COUNT(*) > 1 FROM t) THEN
+ | SELECT 43;
+ | ELSE
+ | SELECT 44;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
+ test("searched case no cases matched no else") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 = 2 THEN
+ | SELECT 42;
+ | WHEN 1 = 3 THEN
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq()
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("searched case when evaluates to null") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a BOOLEAN) USING parquet;
+ | CASE
+ | WHEN (SELECT * FROM t) THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SqlScriptingException] (
+ runSqlScript(commands)
+ ),
+ condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+ parameters = Map("invalidStatement" -> "(SELECT * FROM T)")
+ )
+ }
+ }
+
+ test("searched case with non boolean condition - constant") {
+ val commands =
+ """
+ |BEGIN
+ | CASE
+ | WHEN 1 THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SqlScriptingException] (
+ runSqlScript(commands)
+ ),
+ condition = "INVALID_BOOLEAN_STATEMENT",
+ parameters = Map("invalidStatement" -> "1")
+ )
+ }
+
+ test("searched case with too many rows in subquery condition") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a BOOLEAN) USING parquet;
+ | INSERT INTO t VALUES (true);
+ | INSERT INTO t VALUES (true);
+ | CASE
+ | WHEN (SELECT * FROM t) THEN
+ | SELECT 1;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SparkException] (
+ runSqlScript(commands)
+ ),
+ condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+ parameters = Map.empty,
+ context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 124, stop = 140)
+ )
+ }
+ }
+
+ test("simple case") {
+ val commands =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 1 THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("simple case nested") {
+ val commands =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 1 THEN
+ | CASE 2
+ | WHEN (SELECT 3) THEN
+ | SELECT 41;
+ | ELSE
+ | SELECT 42;
+ | END CASE;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("simple case second case") {
+ val commands =
+ """
+ |BEGIN
+ | CASE (SELECT 2)
+ | WHEN 1 THEN
+ | SELECT 1;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | WHEN (SELECT * FROM t) THEN
+ | SELECT * FROM b;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("simple case going in else") {
+ val commands =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 2 THEN
+ | SELECT 1;
+ | WHEN 3 THEN
+ | SELECT 2;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq(Seq(Row(43)))
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("simple case with count") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ |INSERT INTO t VALUES (1, 'a', 1.0);
+ |INSERT INTO t VALUES (1, 'a', 1.0);
+ |CASE (SELECT COUNT(*) FROM t)
+ | WHEN 1 THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(42)))
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
+ test("simple case else with count") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ | INSERT INTO t VALUES (1, 'a', 1.0);
+ | INSERT INTO t VALUES (2, 'b', 2.0);
+ | CASE (SELECT COUNT(*) FROM t)
+ | WHEN 1 THEN
+ | SELECT 42;
+ | WHEN 3 THEN
+ | SELECT 43;
+ | ELSE
+ | SELECT 44;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(44)))
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
+ test("simple case no cases matched no else") {
+ val commands =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN 2 THEN
+ | SELECT 42;
+ | WHEN 3 THEN
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+ val expected = Seq()
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("simple case mismatched types") {
+ val commands =
+ """
+ |BEGIN
+ | CASE 1
+ | WHEN "one" THEN
+ | SELECT 42;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SparkNumberFormatException] (
+ runSqlScript(commands)
+ ),
+ condition = "CAST_INVALID_INPUT",
+ parameters = Map(
+ "expression" -> "'one'",
+ "sourceType" -> "\"STRING\"",
+ "targetType" -> "\"BIGINT\""),
+ context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27)
+ )
+ }
+
+ test("simple case compare with null") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a INT) USING parquet;
+ | CASE (SELECT COUNT(*) FROM t)
+ | WHEN 1 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(Seq.empty[Row], Seq(Row(43)))
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
test("if's condition must be a boolean statement") {
withTable("t") {
val commands =
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]