This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 216f761bcb12 [SPARK-48357][SQL] Support for LOOP statement
216f761bcb12 is described below
commit 216f761bcb122a253d42793466a9fe97e7ba3336
Author: Dušan Tišma <[email protected]>
AuthorDate: Thu Oct 3 09:12:23 2024 +0200
[SPARK-48357][SQL] Support for LOOP statement
### What changes were proposed in this pull request?
In this PR, support for LOOP statement in SQL scripting is introduced.
Changes summary:
Grammar/parser changes:
- `loopStatement` grammar rule
- `visitLoopStatement` rule visitor
- `LoopStatement` logical operator
- `LoopStatementExec` execution node
- Iterator implementation - repeatedly execute body (only way to stop the
loop is with LEAVE, or if an exception occurs)
- `SqlScriptingInterpreter` - added logic to transform LoopStatement logical
operator to LoopStatementExec execution node
### Why are the changes needed?
This is a part of SQL Scripting introduced to Spark; the LOOP statement is a
basic control flow construct in the SQL language.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New tests are introduced to scripting test suites:
`SqlScriptingParserSuite`, `SqlScriptingExecutionNodeSuite` and
`SqlScriptingInterpreterSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48323 from dusantism-db/sql-scripting-loop.
Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
docs/sql-ref-ansi-compliance.md | 1 +
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 7 +
.../spark/sql/catalyst/parser/AstBuilder.scala | 11 ++
.../parser/SqlScriptingLogicalOperators.scala | 18 +-
.../catalyst/parser/SqlScriptingParserSuite.scala | 205 +++++++++++++++++++++
.../sql/scripting/SqlScriptingExecutionNode.scala | 52 ++++++
.../sql/scripting/SqlScriptingInterpreter.scala | 6 +-
.../sql-tests/results/ansi/keywords.sql.out | 1 +
.../resources/sql-tests/results/keywords.sql.out | 1 +
.../scripting/SqlScriptingExecutionNodeSuite.scala | 17 ++
.../scripting/SqlScriptingInterpreterSuite.scala | 152 +++++++++++++++
.../ThriftServerWithSparkContextSuite.scala | 2 +-
13 files changed, 469 insertions(+), 5 deletions(-)
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 12dff1e325c4..b4446b1538cd 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -581,6 +581,7 @@ Below is a list of all the keywords in Spark SQL.
|LOCKS|non-reserved|non-reserved|non-reserved|
|LOGICAL|non-reserved|non-reserved|non-reserved|
|LONG|non-reserved|non-reserved|non-reserved|
+|LOOP|non-reserved|non-reserved|non-reserved|
|MACRO|non-reserved|non-reserved|non-reserved|
|MAP|non-reserved|non-reserved|non-reserved|
|MATCHED|non-reserved|non-reserved|non-reserved|
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index de28041acd41..7391e8c353de 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -301,6 +301,7 @@ LOCK: 'LOCK';
LOCKS: 'LOCKS';
LOGICAL: 'LOGICAL';
LONG: 'LONG';
+LOOP: 'LOOP';
MACRO: 'MACRO';
MAP: 'MAP' {incComplexTypeLevelCounter();};
MATCHED: 'MATCHED';
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index e8e2e980135a..644c7e732fbf 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -69,6 +69,7 @@ compoundStatement
| repeatStatement
| leaveStatement
| iterateStatement
+ | loopStatement
;
setStatementWithOptionalVarKeyword
@@ -106,6 +107,10 @@ caseStatement
(ELSE elseBody=compoundBody)? END CASE
#simpleCaseStatement
;
+loopStatement
+ : beginLabel? LOOP compoundBody END LOOP endLabel?
+ ;
+
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
@@ -1658,6 +1663,7 @@ ansiNonReserved
| LOCKS
| LOGICAL
| LONG
+ | LOOP
| MACRO
| MAP
| MATCHED
@@ -2016,6 +2022,7 @@ nonReserved
| LOCKS
| LOGICAL
| LONG
+ | LOOP
| MACRO
| MAP
| MATCHED
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 9ce96ae652fe..f1d211f51778 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -337,6 +337,10 @@ class AstBuilder extends DataTypeAstBuilder
if Option(c.beginLabel()).isDefined &&
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
=> true
+ case c: LoopStatementContext
+ if Option(c.beginLabel()).isDefined &&
+
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ => true
case _ => false
}
}
@@ -373,6 +377,13 @@ class AstBuilder extends DataTypeAstBuilder
CurrentOrigin.get, labelText, "ITERATE")
}
+ override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
+ val labelText = generateLabelText(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ val body = visitCompoundBody(ctx.compoundBody())
+
+ LoopStatement(body, Some(labelText))
+ }
+
override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan
= withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index ed40a5fd734b..9fd87f51bd57 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -81,7 +81,7 @@ case class IfElseStatement(
* Body is executed as long as the condition evaluates to true
* @param body Compound body is a collection of statements that are executed
if condition is true.
* @param label An optional label for the loop which is unique amongst all
labels for statements
- * within which the LOOP statement is contained.
+ * within which the WHILE statement is contained.
* If an end label is specified it must match the beginning label.
* The label can be used to LEAVE or ITERATE the loop.
*/
@@ -97,7 +97,7 @@ case class WhileStatement(
* @param body Compound body is a collection of statements that are executed
once no matter what,
* and then as long as condition is false.
* @param label An optional label for the loop which is unique amongst all
labels for statements
- * within which the LOOP statement is contained.
+ * within which the REPEAT statement is contained.
* If an end label is specified it must match the beginning label.
* The label can be used to LEAVE or ITERATE the loop.
*/
@@ -106,7 +106,6 @@ case class RepeatStatement(
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
-
/**
* Logical operator for LEAVE statement.
* The statement can be used both for compounds or any kind of loops.
@@ -138,3 +137,16 @@ case class CaseStatement(
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditions.length == conditionalBodies.length)
}
+
+/**
+ * Logical operator for LOOP statement.
+ * @param body Compound body is a collection of statements that are executed
until the
+ * LOOP statement is terminated by using the LEAVE statement.
+ * @param label An optional label for the loop which is unique amongst all
labels for statements
+ * within which the LOOP statement is contained.
+ * If an end label is specified it must match the beginning label.
+ * The label can be used to LEAVE or ITERATE the loop.
+ */
+case class LoopStatement(
+ body: CompoundBody,
+ label: Option[String]) extends CompoundPlanStatement
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index ba634333e06f..2972ba2db21d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -1400,6 +1400,211 @@ class SqlScriptingParserSuite extends SparkFunSuite
with SQLHelper {
.getText == "SELECT 42")
}
+ test("loop statement") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: LOOP
+ | SELECT 1;
+ | SELECT 2;
+ |END LOOP lbl;
+ |END
+ """.stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 2)
+ assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(whileStmt.label.contains("lbl"))
+ }
+
+ test("loop with if else block") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: LOOP
+ | IF 1 = 1 THEN
+ | SELECT 1;
+ | ELSE
+ | SELECT 2;
+ | END IF;
+ |END LOOP lbl;
+ |END
+ """.stripMargin
+
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 1)
+ assert(loopStmt.body.collection.head.isInstanceOf[IfElseStatement])
+ val ifStmt = loopStmt.body.collection.head.asInstanceOf[IfElseStatement]
+
+ assert(ifStmt.conditions.length == 1)
+ assert(ifStmt.conditionalBodies.length == 1)
+ assert(ifStmt.elseBody.isDefined)
+
+ assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.conditions.head.getText == "1 = 1")
+
+
assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
+
+ assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 2")
+
+ assert(loopStmt.label.contains("lbl"))
+ }
+
+ test("nested loop") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: LOOP
+ | LOOP
+ | SELECT 42;
+ | END LOOP;
+ |END LOOP lbl;
+ |END
+ """.stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 1)
+ assert(loopStmt.body.collection.head.isInstanceOf[LoopStatement])
+ val nestedLoopStmt =
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+ assert(nestedLoopStmt.body.isInstanceOf[CompoundBody])
+ assert(nestedLoopStmt.body.collection.length == 1)
+ assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(nestedLoopStmt.body.collection.
+ head.asInstanceOf[SingleStatement].getText == "SELECT 42")
+
+ assert(loopStmt.label.contains("lbl"))
+ }
+
+ test("leave loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: LOOP
+ | SELECT 1;
+ | LEAVE lbl;
+ | END LOOP;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 2)
+
+ assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText
== "SELECT 1")
+
+ assert(loopStmt.body.collection(1).isInstanceOf[LeaveStatement])
+ assert(loopStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
+ }
+
+ test("iterate loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: LOOP
+ | SELECT 1;
+ | ITERATE lbl;
+ | END LOOP;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 2)
+
+ assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText
== "SELECT 1")
+
+ assert(loopStmt.body.collection(1).isInstanceOf[IterateStatement])
+ assert(loopStmt.body.collection(1).asInstanceOf[IterateStatement].label ==
"lbl")
+ }
+
+ test("leave outer loop from nested loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: LOOP
+ | lbl2: LOOP
+ | SELECT 1;
+ | LEAVE lbl;
+ | END LOOP;
+ | END LOOP;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 1)
+
+ val nestedLoopStmt =
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+ assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(
+
nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(nestedLoopStmt.body.collection(1).isInstanceOf[LeaveStatement])
+
assert(nestedLoopStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
+ }
+
+ test("iterate outer loop from nested loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: LOOP
+ | lbl2: LOOP
+ | SELECT 1;
+ | ITERATE lbl;
+ | END LOOP;
+ | END LOOP;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+ val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+ assert(loopStmt.body.isInstanceOf[CompoundBody])
+ assert(loopStmt.body.collection.length == 1)
+
+ val nestedLoopStmt =
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+ assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(
+
nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(nestedLoopStmt.body.collection(1).isInstanceOf[IterateStatement])
+
assert(nestedLoopStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
+ }
+
// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index af9fd5464277..9fdb9626556f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -592,3 +592,55 @@ class IterateStatementExec(val label: String) extends
LeafStatementExec {
var hasBeenMatched: Boolean = false
override def reset(): Unit = hasBeenMatched = false
}
+
+class LoopStatementExec(
+ body: CompoundBodyExec,
+ val label: Option[String]) extends NonLeafStatementExec {
+
+ /**
+ * Loop can be interrupted by LeaveStatementExec
+ */
+ private var interrupted: Boolean = false
+
+ /**
+ * Loop can be iterated by IterateStatementExec
+ */
+ private var iterated: Boolean = false
+
+ private lazy val treeIterator =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = !interrupted
+
+ override def next(): CompoundStatementExec = {
+ if (!body.getTreeIterator.hasNext || iterated) {
+ reset()
+ }
+
+ val retStmt = body.getTreeIterator.next()
+
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if
!leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ interrupted = true
+ case iterStatementExec: IterateStatementExec if
!iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ }
+ iterated = true
+ case _ =>
+ }
+
+ retStmt
+ }
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ interrupted = false
+ iterated = false
+ body.reset()
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 917b4d6f45ee..78ef715e1898 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable,
DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
@@ -120,6 +120,10 @@ case class SqlScriptingInterpreter() {
transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec]
new RepeatStatementExec(conditionExec, bodyExec, label, session)
+ case LoopStatement(body, label) =>
+ val bodyExec = transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec]
+ new LoopStatementExec(bodyExec, label)
+
case leaveStatement: LeaveStatement =>
new LeaveStatementExec(leaveStatement.label)
diff --git
a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index 7c694503056a..d9d266e8a674 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -187,6 +187,7 @@ LOCK false
LOCKS false
LOGICAL false
LONG false
+LOOP false
MACRO false
MAP false
MATCHED false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index 2c16d961b131..cd93a811d64f 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -187,6 +187,7 @@ LOCK false
LOCKS false
LOGICAL false
LONG false
+LOOP false
MACRO false
MAP false
MATCHED false
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 83d8191d01ec..baad5702f4f2 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -97,6 +97,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
case TestLeafStatement(testVal) => testVal
case TestIfElseCondition(_, description) => description
case TestLoopCondition(_, _, description) => description
+ case loopStmt: LoopStatementExec => loopStmt.label.get
case leaveStmt: LeaveStatementExec => leaveStmt.label
case iterateStmt: IterateStatementExec => iterateStmt.label
case _ => fail("Unexpected statement type")
@@ -669,4 +670,20 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
val statements = iter.map(extractStatementValue).toSeq
assert(statements === Seq("con1", "con2"))
}
+
+ test("loop statement with leave") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ new LoopStatementExec(
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl"))
+ ),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "lbl"))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index ac190eb48d1f..3551608a1ee8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -1383,4 +1383,156 @@ class SqlScriptingInterpreterSuite extends QueryTest
with SharedSparkSession {
)
verifySqlScriptResult(sqlScriptText, expected)
}
+
+ test("loop statement with leave") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: LOOP
+ | SET x = x + 1;
+ | SELECT x;
+ | IF x > 2
+ | THEN
+ | LEAVE lbl;
+ | END IF;
+ | END LOOP;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq(Row(1)), // select x
+ Seq.empty[Row], // set x = 2
+ Seq(Row(2)), // select x
+ Seq.empty[Row], // set x = 3
+ Seq(Row(3)), // select x
+ Seq(Row(3)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("nested loop statement with leave") {
+ val commands =
+ """
+ |BEGIN
+ | DECLARE x = 0;
+ | DECLARE y = 0;
+ | lbl1: LOOP
+ | SET VAR y = 0;
+ | lbl2: LOOP
+ | SELECT x, y;
+ | SET VAR y = y + 1;
+ | IF y >= 2 THEN
+ | LEAVE lbl2;
+ | END IF;
+ | END LOOP;
+ | SET VAR x = x + 1;
+ | IF x >= 2 THEN
+ | LEAVE lbl1;
+ | END IF;
+ | END LOOP;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare x
+ Seq.empty[Row], // declare y
+ Seq.empty[Row], // set y to 0
+ Seq(Row(0, 0)), // select x, y
+ Seq.empty[Row], // increase y
+ Seq(Row(0, 1)), // select x, y
+ Seq.empty[Row], // increase y
+ Seq.empty[Row], // increase x
+ Seq.empty[Row], // set y to 0
+ Seq(Row(1, 0)), // select x, y
+ Seq.empty[Row], // increase y
+ Seq(Row(1, 1)), // select x, y
+ Seq.empty[Row], // increase y
+ Seq.empty[Row], // increase x
+ Seq.empty[Row], // drop y
+ Seq.empty[Row] // drop x
+ )
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("iterate loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: LOOP
+ | SET x = x + 1;
+ | IF x > 1 THEN
+ | LEAVE lbl;
+ | END IF;
+ | ITERATE lbl;
+ | SET x = x + 2;
+ | END LOOP;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq.empty[Row], // set x = 2
+ Seq(Row(2)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("leave outer loop from nested loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: LOOP
+ | lbl2: LOOP
+ | SELECT 1;
+ | LEAVE lbl;
+ | END LOOP;
+ | END LOOP;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select 1
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("iterate outer loop from nested loop statement") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: LOOP
+ | SET x = x + 1;
+ | IF x > 2 THEN
+ | LEAVE lbl;
+ | END IF;
+ | lbl2: LOOP
+ | SELECT 1;
+ | ITERATE lbl;
+ | SET x = 10;
+ | END LOOP;
+ | END LOOP;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq(Row(1)), // select 1
+ Seq.empty[Row], // set x = 2
+ Seq(Row(1)), // select 1
+ Seq.empty[Row], // set x = 3
+ Seq(Row(3)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
}
diff --git
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index dcf3bd8c7173..60c49619552e 100644
---
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends
SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle,
GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURREN
[...]
+ assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURREN
[...]
// scalastyle:on line.size.limit
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]