This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new da893223353e [SPARK-48358][SQL] Support for REPEAT statement
da893223353e is described below
commit da893223353e009afba5cd7e87921d69de7ddd04
Author: Dušan Tišma <[email protected]>
AuthorDate: Wed Sep 11 09:50:36 2024 +0200
[SPARK-48358][SQL] Support for REPEAT statement
### What changes were proposed in this pull request?
This PR introduces support for the REPEAT statement in SQL scripting.
Changes summary:
Grammar/parser changes
- `repeatStatement` grammar rule
- `visitRepeatStatement` rule visitor
- `RepeatStatement` logical operator
`RepeatStatementExec` execution node
Internal states - `Condition` and `Body`
Iterator implementation - switch between body and condition until condition
evaluates to true
SqlScriptingInterpreter - added logic to transform RepeatStatement logical
operator to RepeatStatementExec execution node
### Why are the changes needed?
This is part of the SQL Scripting support being introduced to Spark; the REPEAT
statement is a basic control flow construct in the SQL language.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New tests are added to all three scripting test suites:
`SqlScriptingParserSuite`, `SqlScriptingExecutionNodeSuite` and
`SqlScriptingInterpreterSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47756 from dusantism-db/sql-scripting-repeat-statement.
Authored-by: Dušan Tišma <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
docs/sql-ref-ansi-compliance.md | 2 +
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 2 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 9 +
.../spark/sql/catalyst/parser/AstBuilder.scala | 18 ++
.../parser/SqlScriptingLogicalOperators.scala | 17 ++
.../catalyst/parser/SqlScriptingParserSuite.scala | 265 +++++++++++++++++++
.../sql/scripting/SqlScriptingExecutionNode.scala | 76 ++++++
.../sql/scripting/SqlScriptingInterpreter.scala | 14 +-
.../sql-tests/results/ansi/keywords.sql.out | 2 +
.../resources/sql-tests/results/keywords.sql.out | 2 +
.../scripting/SqlScriptingExecutionNodeSuite.scala | 207 +++++++++++++--
.../scripting/SqlScriptingInterpreterSuite.scala | 281 ++++++++++++++++++++-
.../ThriftServerWithSparkContextSuite.scala | 2 +-
13 files changed, 872 insertions(+), 25 deletions(-)
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 0ac19e2ae943..fd56a9d4117a 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -646,6 +646,7 @@ Below is a list of all the keywords in Spark SQL.
|REGEXP|non-reserved|non-reserved|not a keyword|
|RENAME|non-reserved|non-reserved|non-reserved|
|REPAIR|non-reserved|non-reserved|non-reserved|
+|REPEAT|non-reserved|non-reserved|non-reserved|
|REPEATABLE|non-reserved|non-reserved|non-reserved|
|REPLACE|non-reserved|non-reserved|non-reserved|
|RESET|non-reserved|non-reserved|non-reserved|
@@ -734,6 +735,7 @@ Below is a list of all the keywords in Spark SQL.
|UNLOCK|non-reserved|non-reserved|non-reserved|
|UNPIVOT|non-reserved|non-reserved|non-reserved|
|UNSET|non-reserved|non-reserved|non-reserved|
+|UNTIL|non-reserved|non-reserved|non-reserved|
|UPDATE|non-reserved|non-reserved|reserved|
|USE|non-reserved|non-reserved|non-reserved|
|USER|reserved|non-reserved|reserved|
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index 6793cb46852b..28ebaeaaed6d 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -364,6 +364,7 @@ REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
RENAME: 'RENAME';
REPAIR: 'REPAIR';
+REPEAT: 'REPEAT';
REPEATABLE: 'REPEATABLE';
REPLACE: 'REPLACE';
RESET: 'RESET';
@@ -453,6 +454,7 @@ UNKNOWN: 'UNKNOWN';
UNLOCK: 'UNLOCK';
UNPIVOT: 'UNPIVOT';
UNSET: 'UNSET';
+UNTIL: 'UNTIL';
UPDATE: 'UPDATE';
USE: 'USE';
USER: 'USER';
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 6a23bd394c8c..e9fc6c3ca4f2 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -65,6 +65,7 @@ compoundStatement
| beginEndCompoundBlock
| ifElseStatement
| whileStatement
+ | repeatStatement
| leaveStatement
| iterateStatement
;
@@ -85,6 +86,10 @@ ifElseStatement
(ELSE elseBody=compoundBody)? END IF
;
+repeatStatement
+ : beginLabel? REPEAT compoundBody UNTIL booleanExpression END REPEAT
endLabel?
+ ;
+
leaveStatement
: LEAVE multipartIdentifier
;
@@ -1660,6 +1665,7 @@ ansiNonReserved
| REFRESH
| RENAME
| REPAIR
+ | REPEAT
| REPEATABLE
| REPLACE
| RESET
@@ -1735,6 +1741,7 @@ ansiNonReserved
| UNLOCK
| UNPIVOT
| UNSET
+ | UNTIL
| UPDATE
| USE
| VALUES
@@ -2023,6 +2030,7 @@ nonReserved
| REFRESH
| RENAME
| REPAIR
+ | REPEAT
| REPEATABLE
| REPLACE
| RESET
@@ -2107,6 +2115,7 @@ nonReserved
| UNLOCK
| UNPIVOT
| UNSET
+ | UNTIL
| UPDATE
| USE
| USER
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f4638920af3c..ab7936179917 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -261,6 +261,20 @@ class AstBuilder extends DataTypeAstBuilder
WhileStatement(condition, body, Some(labelText))
}
+ override def visitRepeatStatement(ctx: RepeatStatementContext):
RepeatStatement = {
+ val labelText = generateLabelText(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ val boolExpr = ctx.booleanExpression()
+
+ val condition = withOrigin(boolExpr) {
+ SingleStatement(
+ Project(
+ Seq(Alias(expression(boolExpr), "condition")()),
+ OneRowRelation()))}
+ val body = visitCompoundBody(ctx.compoundBody())
+
+ RepeatStatement(condition, body, Some(labelText))
+ }
+
private def leaveOrIterateContextHasLabel(
ctx: RuleContext, label: String, isIterate: Boolean): Boolean = {
ctx match {
@@ -275,6 +289,10 @@ class AstBuilder extends DataTypeAstBuilder
if Option(c.beginLabel()).isDefined &&
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
=> true
+ case c: RepeatStatementContext
+ if Option(c.beginLabel()).isDefined &&
+
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ => true
case _ => false
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index dbb29a71323e..5e7e8b0b4fc9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -90,6 +90,23 @@ case class WhileStatement(
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
+/**
+ * Logical operator for REPEAT statement.
+ * @param condition Any expression evaluating to a Boolean.
+ * Body is executed as long as the condition evaluates to false
+ * @param body Compound body is a collection of statements that are executed
once no matter what,
+ * and then as long as condition is false.
+ * @param label An optional label for the loop which is unique amongst all
labels for statements
+ * within which the REPEAT statement is contained.
+ * If an end label is specified it must match the beginning label.
+ * The label can be used to LEAVE or ITERATE the loop.
+ */
+case class RepeatStatement(
+ condition: SingleStatement,
+ body: CompoundBody,
+ label: Option[String]) extends CompoundPlanStatement
+
+
/**
* Logical operator for LEAVE statement.
* The statement can be used both for compounds or any kind of loops.
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 9ae516eb77e6..bf527b9c3bd7 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -708,6 +708,34 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(whileStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
}
+ test("leave repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | SELECT 1;
+ | LEAVE lbl;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "1 = 2")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 2)
+
+ assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(repeatStmt.body.collection(1).isInstanceOf[LeaveStatement])
+ assert(repeatStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
+ }
+
test ("iterate compound block - should fail") {
val sqlScriptText =
"""
@@ -750,6 +778,34 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(whileStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
}
+ test("iterate repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | SELECT 1;
+ | ITERATE lbl;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "1 = 2")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 2)
+
+ assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(repeatStmt.body.collection(1).isInstanceOf[IterateStatement])
+ assert(repeatStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
+ }
+
test("leave with wrong label - should fail") {
val sqlScriptText =
"""
@@ -813,6 +869,42 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(nestedWhileStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
}
+ test("leave outer loop from nested repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | lbl2: REPEAT
+ | SELECT 1;
+ | LEAVE lbl;
+ | UNTIL 2 = 2
+ | END REPEAT;
+ | UNTIL 1 = 1
+ | END REPEAT;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "1 = 1")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 1)
+
+ val nestedRepeatStmt =
repeatStmt.body.collection.head.asInstanceOf[RepeatStatement]
+ assert(nestedRepeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(nestedRepeatStmt.condition.getText == "2 = 2")
+
+ assert(nestedRepeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(
+
nestedRepeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(nestedRepeatStmt.body.collection(1).isInstanceOf[LeaveStatement])
+
assert(nestedRepeatStmt.body.collection(1).asInstanceOf[LeaveStatement].label
== "lbl")
+ }
+
test("iterate outer loop from nested while loop") {
val sqlScriptText =
"""
@@ -846,6 +938,179 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(nestedWhileStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
}
+ test("iterate outer loop from nested repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | lbl2: REPEAT
+ | SELECT 1;
+ | ITERATE lbl;
+ | UNTIL 2 = 2
+ | END REPEAT;
+ | UNTIL 1 = 1
+ | END REPEAT;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "1 = 1")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 1)
+
+ val nestedRepeatStmt =
repeatStmt.body.collection.head.asInstanceOf[RepeatStatement]
+ assert(nestedRepeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(nestedRepeatStmt.condition.getText == "2 = 2")
+
+ assert(nestedRepeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(
+
nestedRepeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(nestedRepeatStmt.body.collection(1).isInstanceOf[IterateStatement])
+
assert(nestedRepeatStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
+ }
+
+ test("repeat") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: REPEAT
+ | SELECT 1;
+ | UNTIL 1 = 1
+ |END REPEAT lbl;
+ |END
+ """.stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "1 = 1")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 1)
+ assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(repeatStmt.label.contains("lbl"))
+ }
+
+ test("repeat with complex condition") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ |REPEAT
+ | SELECT 42;
+ |UNTIL
+ | (SELECT COUNT(*) < 2 FROM t)
+ |END REPEAT;
+ |END
+ |""".stripMargin
+
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 2)
+ assert(tree.collection(1).isInstanceOf[RepeatStatement])
+
+ val repeatStmt = tree.collection(1).asInstanceOf[RepeatStatement]
+ assert(repeatStmt.condition.isInstanceOf[SingleStatement])
+ assert(repeatStmt.condition.getText == "(SELECT COUNT(*) < 2 FROM t)")
+
+ assert(repeatStmt.body.isInstanceOf[CompoundBody])
+ assert(repeatStmt.body.collection.length == 1)
+ assert(repeatStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(repeatStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 42")
+ }
+
+ test("repeat with if else block") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: REPEAT
+ | IF 1 = 1 THEN
+ | SELECT 1;
+ | ELSE
+ | SELECT 2;
+ | END IF;
+ |UNTIL
+ | 1 = 1
+ |END REPEAT lbl;
+ |END
+ """.stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 1)
+ assert(whileStmt.body.collection.head.isInstanceOf[IfElseStatement])
+ val ifStmt = whileStmt.body.collection.head.asInstanceOf[IfElseStatement]
+
+ assert(ifStmt.conditions.length == 1)
+ assert(ifStmt.conditionalBodies.length == 1)
+ assert(ifStmt.elseBody.isDefined)
+
+ assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.conditions.head.getText == "1 = 1")
+
+
assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
+
+ assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+ assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 2")
+
+ assert(whileStmt.label.contains("lbl"))
+ }
+
+ test("nested repeat") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: REPEAT
+ | REPEAT
+ | SELECT 42;
+ | UNTIL
+ | 2 = 2
+ | END REPEAT;
+ |UNTIL
+ | 1 = 1
+ |END REPEAT lbl;
+ |END
+ """.stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[RepeatStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[RepeatStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 1)
+ assert(whileStmt.body.collection.head.isInstanceOf[RepeatStatement])
+ val nestedWhileStmt =
whileStmt.body.collection.head.asInstanceOf[RepeatStatement]
+
+ assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement])
+ assert(nestedWhileStmt.condition.getText == "2 = 2")
+
+ assert(nestedWhileStmt.body.isInstanceOf[CompoundBody])
+ assert(nestedWhileStmt.body.collection.length == 1)
+ assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement])
+ assert(nestedWhileStmt.body.collection.
+ head.asInstanceOf[SingleStatement].getText == "SELECT 42")
+
+ assert(whileStmt.label.contains("lbl"))
+
+ }
+
// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 284ccc5d5bfe..cae797614314 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -405,6 +405,82 @@ class WhileStatementExec(
}
}
+/**
+ * Executable node for RepeatStatement.
+ * @param condition Executable node for the condition - evaluates to a row
with a single boolean
+ * expression, otherwise throws an exception
+ * @param body Executable node for the body.
+ * @param label Label set to RepeatStatement by user, None if not set
+ * @param session Spark session that SQL script is executed within.
+ */
+class RepeatStatementExec(
+ condition: SingleStatementExec,
+ body: CompoundBodyExec,
+ label: Option[String],
+ session: SparkSession) extends NonLeafStatementExec {
+
+ private object RepeatState extends Enumeration {
+ val Condition, Body = Value
+ }
+
+ private var state = RepeatState.Body
+ private var curr: Option[CompoundStatementExec] = Some(body)
+
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = curr.nonEmpty
+
+ override def next(): CompoundStatementExec = state match {
+ case RepeatState.Condition =>
+ val condition = curr.get.asInstanceOf[SingleStatementExec]
+ if (!evaluateBooleanCondition(session, condition)) {
+ state = RepeatState.Body
+ curr = Some(body)
+ body.reset()
+ } else {
+ curr = None
+ }
+ condition
+ case RepeatState.Body =>
+ val retStmt = body.getTreeIterator.next()
+
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if
!leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ curr = None
+ return retStmt
+ case iterStatementExec: IterateStatementExec if
!iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ }
+ state = RepeatState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return retStmt
+ case _ =>
+ }
+
+ if (!body.getTreeIterator.hasNext) {
+ state = RepeatState.Condition
+ curr = Some(condition)
+ condition.reset()
+ }
+ retStmt
+ }
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ state = RepeatState.Body
+ curr = Some(body)
+ body.reset()
+ condition.reset()
+ }
+}
+
/**
* Executable node for LeaveStatement.
* @param label Label of the compound or loop to leave.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 8a5a9774d42f..865b33999655 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable,
DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
@@ -84,6 +84,7 @@ case class SqlScriptingInterpreter() {
new CompoundBodyExec(
collection.map(st => transformTreeIntoExecutable(st, session)) ++
dropVariables,
label)
+
case IfElseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false))
@@ -93,16 +94,27 @@ case class SqlScriptingInterpreter() {
transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec])
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec,
session)
+
case WhileStatement(condition, body, label) =>
val conditionExec =
new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false)
val bodyExec =
transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec]
new WhileStatementExec(conditionExec, bodyExec, label, session)
+
+ case RepeatStatement(condition, body, label) =>
+ val conditionExec =
+ new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false)
+ val bodyExec =
+ transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec]
+ new RepeatStatementExec(conditionExec, bodyExec, label, session)
+
case leaveStatement: LeaveStatement =>
new LeaveStatementExec(leaveStatement.label)
+
case iterateStatement: IterateStatement =>
new IterateStatementExec(iterateStatement.label)
+
case sparkStatement: SingleStatement =>
new SingleStatementExec(
sparkStatement.parsedPlan,
diff --git
a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index b2f3fdda74db..e6a36ac2445c 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -251,6 +251,7 @@ REFERENCES true
REFRESH false
RENAME false
REPAIR false
+REPEAT false
REPEATABLE false
REPLACE false
RESET false
@@ -336,6 +337,7 @@ UNKNOWN true
UNLOCK false
UNPIVOT false
UNSET false
+UNTIL false
UPDATE false
USE false
USER true
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index ce9fd580b2ff..19816c8252c9 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -251,6 +251,7 @@ REFERENCES false
REFRESH false
RENAME false
REPAIR false
+REPEAT false
REPEATABLE false
REPLACE false
RESET false
@@ -336,6 +337,7 @@ UNKNOWN false
UNLOCK false
UNPIVOT false
UNSET false
+UNTIL false
UPDATE false
USE false
USER false
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 97a21c505fdd..4b72ca8ecaa9 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -45,24 +45,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
override def output: Seq[Attribute] = Seq.empty
}
- case class TestWhileCondition(
+ case class TestLoopCondition(
condVal: Boolean, reps: Int, description: String)
extends SingleStatementExec(
parsedPlan = DummyLogicalPlan(),
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
isInternal = false)
- case class TestWhile(
- condition: TestWhileCondition,
- body: CompoundBodyExec,
- label: Option[String] = None)
- extends WhileStatementExec(condition, body, label, spark) {
-
+ class LoopBooleanConditionEvaluator(condition: TestLoopCondition) {
private var callCount: Int = 0
- override def evaluateBooleanCondition(
- session: SparkSession,
- statement: LeafStatementExec): Boolean = {
+ def evaluateLoopBooleanCondition(): Boolean = {
if (callCount < condition.reps) {
callCount += 1
condition.condVal
@@ -73,11 +66,37 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
}
}
+ case class TestWhile(
+ condition: TestLoopCondition,
+ body: CompoundBodyExec,
+ label: Option[String] = None)
+ extends WhileStatementExec(condition, body, label, spark) {
+
+ private val evaluator = new LoopBooleanConditionEvaluator(condition)
+
+ override def evaluateBooleanCondition(
+ session: SparkSession,
+ statement: LeafStatementExec): Boolean =
evaluator.evaluateLoopBooleanCondition()
+ }
+
+ case class TestRepeat(
+ condition: TestLoopCondition,
+ body: CompoundBodyExec,
+ label: Option[String] = None)
+ extends RepeatStatementExec(condition, body, label, spark) {
+
+ private val evaluator = new LoopBooleanConditionEvaluator(condition)
+
+ override def evaluateBooleanCondition(
+ session: SparkSession,
+ statement: LeafStatementExec): Boolean =
evaluator.evaluateLoopBooleanCondition()
+ }
+
private def extractStatementValue(statement: CompoundStatementExec): String =
statement match {
case TestLeafStatement(testVal) => testVal
case TestIfElseCondition(_, description) => description
- case TestWhileCondition(_, _, description) => description
+ case TestLoopCondition(_, _, description) => description
case leaveStmt: LeaveStatementExec => leaveStmt.label
case iterateStmt: IterateStatementExec => iterateStmt.label
case _ => fail("Unexpected statement type")
@@ -265,7 +284,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
test("while - doesn't enter body") {
val iter = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 0, description =
"con1"),
+ condition = TestLoopCondition(condVal = true, reps = 0, description =
"con1"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
)
)).getTreeIterator
@@ -276,7 +295,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
test("while - enters body once") {
val iter = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 1, description =
"con1"),
+ condition = TestLoopCondition(condVal = true, reps = 1, description =
"con1"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
)
)).getTreeIterator
@@ -287,7 +306,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
test("while - enters body with multiple statements multiple times") {
val iter = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description =
"con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description =
"con1"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("statement1"),
TestLeafStatement("statement2")))
@@ -301,10 +320,10 @@ class SqlScriptingExecutionNodeSuite extends
SparkFunSuite with SharedSparkSessi
test("nested while - 2 times outer 2 times inner") {
val iter = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description =
"con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description =
"con1"),
body = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2,
description = "con2"),
+ condition = TestLoopCondition(condVal = true, reps = 2,
description = "con2"),
body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
))
)
@@ -317,6 +336,64 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
"con2", "body1", "con2", "con1"))
}
+ test("repeat - true condition") {
+ val iter = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 0, description =
"con1"),
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "con1"))
+ }
+
+ test("repeat - condition false once") {
+ val iter = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 1, description =
"con1"),
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "con1", "body1", "con1"))
+ }
+
+ test("repeat - enters body with multiple statements multiple times") {
+ val iter = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description =
"con1"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("statement1"),
+ TestLeafStatement("statement2")))
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("statement1", "statement2", "con1",
"statement1", "statement2",
+ "con1", "statement1", "statement2", "con1"))
+ }
+
+ test("nested repeat") {
+ val iter = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description =
"con1"),
+ body = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2,
description = "con2"),
+ body = new CompoundBodyExec(Seq(TestLeafStatement("body1")))
+ ))
+ )
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "con2", "body1",
+ "con2", "body1", "con2",
+ "con1", "body1", "con2",
+ "body1", "con2", "body1",
+ "con2", "con1", "body1",
+ "con2", "body1", "con2",
+ "body1", "con2", "con1"))
+ }
+
test("leave compound block") {
val iter = new CompoundBodyExec(
statements = Seq(
@@ -333,7 +410,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description
= "con1"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
new LeaveStatementExec("lbl"))
@@ -346,11 +423,28 @@ class SqlScriptingExecutionNodeSuite extends
SparkFunSuite with SharedSparkSessi
assert(statements === Seq("con1", "body1", "lbl"))
}
+ test("leave repeat loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl"))
+ ),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "lbl"))
+ }
+
test("iterate while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description
= "con1"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
new IterateStatementExec("lbl"),
@@ -364,14 +458,33 @@ class SqlScriptingExecutionNodeSuite extends
SparkFunSuite with SharedSparkSessi
assert(statements === Seq("con1", "body1", "lbl", "con1", "body1", "lbl",
"con1"))
}
+ test("iterate repeat loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl"),
+ TestLeafStatement("body2"))
+ ),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(
+ statements === Seq("body1", "lbl", "con1", "body1", "lbl", "con1",
"body1", "lbl", "con1"))
+ }
+
test("leave outer loop from nested while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description
= "con1"),
body = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2,
description = "con2"),
+ condition = TestLoopCondition(condVal = true, reps = 2,
description = "con2"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
new LeaveStatementExec("lbl"))
@@ -387,14 +500,37 @@ class SqlScriptingExecutionNodeSuite extends
SparkFunSuite with SharedSparkSessi
assert(statements === Seq("con1", "con2", "body1", "lbl"))
}
+ test("leave outer loop from nested repeat loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2,
description = "con2"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl"))
+ ),
+ label = Some("lbl2")
+ )
+ )),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("body1", "lbl"))
+ }
+
test("iterate outer loop from nested while loop") {
val iter = new CompoundBodyExec(
statements = Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ condition = TestLoopCondition(condVal = true, reps = 2, description
= "con1"),
body = new CompoundBodyExec(Seq(
TestWhile(
- condition = TestWhileCondition(condVal = true, reps = 2,
description = "con2"),
+ condition = TestLoopCondition(condVal = true, reps = 2,
description = "con2"),
body = new CompoundBodyExec(Seq(
TestLeafStatement("body1"),
new IterateStatementExec("lbl"),
@@ -413,4 +549,31 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
"con1", "con2", "body1", "lbl",
"con1"))
}
+
+ test("iterate outer loop from nested repeat loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestRepeat(
+ condition = TestLoopCondition(condVal = false, reps = 2,
description = "con2"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl"),
+ TestLeafStatement("body2"))
+ ),
+ label = Some("lbl2")
+ )
+ )),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "body1", "lbl", "con1",
+ "body1", "lbl", "con1",
+ "body1", "lbl", "con1"))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index a45cd0bf010b..b703e77d4d73 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -537,6 +537,195 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
}
}
+ test("repeat") {
+ val commands =
+ """
+ |BEGIN
+ | DECLARE i = 0;
+ | REPEAT
+ | SELECT i;
+ | SET VAR i = i + 1;
+ | UNTIL
+ | i = 3
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare i
+ Seq(Row(0)), // select i
+ Seq.empty[Row], // set i
+ Seq(Row(1)), // select i
+ Seq.empty[Row], // set i
+ Seq(Row(2)), // select i
+ Seq.empty[Row], // set i
+ Seq.empty[Row] // drop var
+ )
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("repeat: enters body only once") {
+ val commands =
+ """
+ |BEGIN
+ | DECLARE i = 3;
+ | REPEAT
+ | SELECT i;
+ | SET VAR i = i + 1;
+ | UNTIL
+ | 1 = 1
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare i
+ Seq(Row(3)), // select i
+ Seq.empty[Row], // set i
+ Seq.empty[Row] // drop i
+ )
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("nested repeat") {
+ val commands =
+ """
+ |BEGIN
+ | DECLARE i = 0;
+ | DECLARE j = 0;
+ | REPEAT
+ | SET VAR j = 0;
+ | REPEAT
+ | SELECT i, j;
+ | SET VAR j = j + 1;
+ | UNTIL j >= 2
+ | END REPEAT;
+ | SET VAR i = i + 1;
+ | UNTIL i >= 2
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // declare i
+ Seq.empty[Row], // declare j
+ Seq.empty[Row], // set j to 0
+ Seq(Row(0, 0)), // select i, j
+ Seq.empty[Row], // increase j
+ Seq(Row(0, 1)), // select i, j
+ Seq.empty[Row], // increase j
+ Seq.empty[Row], // increase i
+ Seq.empty[Row], // set j to 0
+ Seq(Row(1, 0)), // select i, j
+ Seq.empty[Row], // increase j
+ Seq(Row(1, 1)), // select i, j
+ Seq.empty[Row], // increase j
+ Seq.empty[Row], // increase i
+ Seq.empty[Row], // drop j
+ Seq.empty[Row] // drop i
+ )
+ verifySqlScriptResult(commands, expected)
+ }
+
+ test("repeat with count") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ |REPEAT
+ | SELECT 42;
+ | INSERT INTO t VALUES (1, 'a', 1.0);
+ |UNTIL (SELECT COUNT(*) >= 2 FROM t)
+ |END REPEAT;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq(Row(42)), // select
+ Seq.empty[Row], // insert
+ Seq(Row(42)), // select
+ Seq.empty[Row] // insert
+ )
+ verifySqlScriptResult(commands, expected)
+ }
+ }
+
+ test("repeat with non boolean condition - constant") {
+ val commands =
+ """
+ |BEGIN
+ | DECLARE i = 0;
+ | REPEAT
+ | SELECT i;
+ | SET VAR i = i + 1;
+ | UNTIL
+ | 1
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SqlScriptingException] (
+ runSqlScript(commands)
+ ),
+ errorClass = "INVALID_BOOLEAN_STATEMENT",
+ parameters = Map("invalidStatement" -> "1")
+ )
+ }
+
+ test("repeat with empty subquery condition") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a BOOLEAN) USING parquet;
+ | REPEAT
+ | SELECT 1;
+ | UNTIL
+ | (SELECT * FROM t)
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SqlScriptingException] (
+ runSqlScript(commands)
+ ),
+ errorClass = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+ parameters = Map("invalidStatement" -> "(SELECT * FROM T)")
+ )
+ }
+ }
+
+ test("repeat with too many rows in subquery condition") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ | CREATE TABLE t (a BOOLEAN) USING parquet;
+ | INSERT INTO t VALUES (true);
+ | INSERT INTO t VALUES (true);
+ | REPEAT
+ | SELECT 1;
+ | UNTIL
+ | (SELECT * FROM t)
+ | END REPEAT;
+ |END
+ |""".stripMargin
+
+ checkError(
+ exception = intercept[SparkException] (
+ runSqlScript(commands)
+ ),
+ errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+ parameters = Map.empty,
+ context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 141,
stop = 157)
+ )
+ }
+ }
+
test("leave compound block") {
val sqlScriptText =
"""
@@ -565,6 +754,22 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScriptText, expected)
}
+ test("leave repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | SELECT 1;
+ | LEAVE lbl;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select 1
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
test("iterate compound block - should fail") {
val sqlScriptText =
"""
@@ -604,6 +809,31 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScriptText, expected)
}
+ test("iterate repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: REPEAT
+ | SET x = x + 1;
+ | ITERATE lbl;
+ | SET x = x + 2;
+ | UNTIL x > 1
+ | END REPEAT;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq.empty[Row], // set x = 2
+ Seq(Row(2)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
test("leave with wrong label - should fail") {
val sqlScriptText =
"""
@@ -634,6 +864,25 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
parameters = Map("labelName" -> "RANDOMLBL", "statementType" ->
"ITERATE"))
}
+ test("leave outer loop from nested repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: REPEAT
+ | lbl2: REPEAT
+ | SELECT 1;
+ | LEAVE lbl;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select 1
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
test("leave outer loop from nested while loop") {
val sqlScriptText =
"""
@@ -671,7 +920,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
Seq.empty[Row], // set x = 0
Seq.empty[Row], // set x = 1
Seq(Row(1)), // select 1
- Seq.empty[Row], // set x= 2
+ Seq.empty[Row], // set x = 2
Seq(Row(1)), // select 1
Seq(Row(2)), // select x
Seq.empty[Row] // drop
@@ -712,4 +961,34 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
)
verifySqlScriptResult(sqlScriptText, expected)
}
+
+ test("iterate outer loop from nested repeat loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: REPEAT
+ | SET x = x + 1;
+ | lbl2: REPEAT
+ | SELECT 1;
+ | ITERATE lbl;
+ | UNTIL 1 = 2
+ | END REPEAT;
+ | UNTIL x > 1
+ | END REPEAT;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq(Row(1)), // select 1
+ Seq.empty[Row], // set x = 2
+ Seq(Row(1)), // select 1
+ Seq(Row(2)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
}
diff --git
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 2e3457dab09b..6f0b6bccac30 100644
---
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends
SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle,
GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DAT
[...]
+ assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DAT
[...]
// scalastyle:on line.size.limit
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]