This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 216f761bcb12 [SPARK-48357][SQL] Support for LOOP statement
216f761bcb12 is described below

commit 216f761bcb122a253d42793466a9fe97e7ba3336
Author: Dušan Tišma <[email protected]>
AuthorDate: Thu Oct 3 09:12:23 2024 +0200

    [SPARK-48357][SQL] Support for LOOP statement
    
    ### What changes were proposed in this pull request?
    This PR introduces support for the LOOP statement in SQL scripting.
    
    Changes summary:
    
    Grammar/parser changes:
    - `loopStatement` grammar rule
    - `visitLoopStatement` rule visitor
    - `LoopStatement` logical operator
    
    Execution changes:
    - `LoopStatementExec` execution node: an iterator implementation that
      repeatedly executes the body (the only ways to stop the loop are a LEAVE
      statement or an exception)
    - `SqlScriptingInterpreter`: added logic to transform the `LoopStatement`
      logical operator into a `LoopStatementExec` execution node
    
    ### Why are the changes needed?
    This is part of the SQL Scripting support being added to Spark; the LOOP
    statement is a basic control-flow construct in the SQL language.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New tests are added to the scripting test suites: `SqlScriptingParserSuite`,
    `SqlScriptingExecutionNodeSuite`, and `SqlScriptingInterpreterSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48323 from dusantism-db/sql-scripting-loop.
    
    Authored-by: Dušan Tišma <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 docs/sql-ref-ansi-compliance.md                    |   1 +
 .../spark/sql/catalyst/parser/SqlBaseLexer.g4      |   1 +
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   7 +
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  11 ++
 .../parser/SqlScriptingLogicalOperators.scala      |  18 +-
 .../catalyst/parser/SqlScriptingParserSuite.scala  | 205 +++++++++++++++++++++
 .../sql/scripting/SqlScriptingExecutionNode.scala  |  52 ++++++
 .../sql/scripting/SqlScriptingInterpreter.scala    |   6 +-
 .../sql-tests/results/ansi/keywords.sql.out        |   1 +
 .../resources/sql-tests/results/keywords.sql.out   |   1 +
 .../scripting/SqlScriptingExecutionNodeSuite.scala |  17 ++
 .../scripting/SqlScriptingInterpreterSuite.scala   | 152 +++++++++++++++
 .../ThriftServerWithSparkContextSuite.scala        |   2 +-
 13 files changed, 469 insertions(+), 5 deletions(-)

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 12dff1e325c4..b4446b1538cd 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -581,6 +581,7 @@ Below is a list of all the keywords in Spark SQL.
 |LOCKS|non-reserved|non-reserved|non-reserved|
 |LOGICAL|non-reserved|non-reserved|non-reserved|
 |LONG|non-reserved|non-reserved|non-reserved|
+|LOOP|non-reserved|non-reserved|non-reserved|
 |MACRO|non-reserved|non-reserved|non-reserved|
 |MAP|non-reserved|non-reserved|non-reserved|
 |MATCHED|non-reserved|non-reserved|non-reserved|
diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index de28041acd41..7391e8c353de 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -301,6 +301,7 @@ LOCK: 'LOCK';
 LOCKS: 'LOCKS';
 LOGICAL: 'LOGICAL';
 LONG: 'LONG';
+LOOP: 'LOOP';
 MACRO: 'MACRO';
 MAP: 'MAP' {incComplexTypeLevelCounter();};
 MATCHED: 'MATCHED';
diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index e8e2e980135a..644c7e732fbf 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -69,6 +69,7 @@ compoundStatement
     | repeatStatement
     | leaveStatement
     | iterateStatement
+    | loopStatement
     ;
 
 setStatementWithOptionalVarKeyword
@@ -106,6 +107,10 @@ caseStatement
         (ELSE elseBody=compoundBody)? END CASE                
#simpleCaseStatement
     ;
 
+loopStatement
+    : beginLabel? LOOP compoundBody END LOOP endLabel?
+    ;
+
 singleStatement
     : (statement|setResetStatement) SEMICOLON* EOF
     ;
@@ -1658,6 +1663,7 @@ ansiNonReserved
     | LOCKS
     | LOGICAL
     | LONG
+    | LOOP
     | MACRO
     | MAP
     | MATCHED
@@ -2016,6 +2022,7 @@ nonReserved
     | LOCKS
     | LOGICAL
     | LONG
+    | LOOP
     | MACRO
     | MAP
     | MATCHED
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 9ce96ae652fe..f1d211f51778 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -337,6 +337,10 @@ class AstBuilder extends DataTypeAstBuilder
         if Option(c.beginLabel()).isDefined &&
           
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
         => true
+      case c: LoopStatementContext
+        if Option(c.beginLabel()).isDefined &&
+          
c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+        => true
       case _ => false
     }
   }
@@ -373,6 +377,13 @@ class AstBuilder extends DataTypeAstBuilder
         CurrentOrigin.get, labelText, "ITERATE")
     }
 
+  override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = {
+    val labelText = generateLabelText(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+    val body = visitCompoundBody(ctx.compoundBody())
+
+    LoopStatement(body, Some(labelText))
+  }
+
   override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan 
= withOrigin(ctx) {
     Option(ctx.statement().asInstanceOf[ParserRuleContext])
       .orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index ed40a5fd734b..9fd87f51bd57 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -81,7 +81,7 @@ case class IfElseStatement(
  *                 Body is executed as long as the condition evaluates to true
  * @param body Compound body is a collection of statements that are executed 
if condition is true.
  * @param label An optional label for the loop which is unique amongst all 
labels for statements
- *              within which the LOOP statement is contained.
+ *              within which the WHILE statement is contained.
  *              If an end label is specified it must match the beginning label.
  *              The label can be used to LEAVE or ITERATE the loop.
  */
@@ -97,7 +97,7 @@ case class WhileStatement(
  * @param body Compound body is a collection of statements that are executed 
once no matter what,
  *             and then as long as condition is false.
  * @param label An optional label for the loop which is unique amongst all 
labels for statements
- *              within which the LOOP statement is contained.
+ *              within which the REPEAT statement is contained.
  *              If an end label is specified it must match the beginning label.
  *              The label can be used to LEAVE or ITERATE the loop.
  */
@@ -106,7 +106,6 @@ case class RepeatStatement(
     body: CompoundBody,
     label: Option[String]) extends CompoundPlanStatement
 
-
 /**
  * Logical operator for LEAVE statement.
  * The statement can be used both for compounds or any kind of loops.
@@ -138,3 +137,16 @@ case class CaseStatement(
     elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
   assert(conditions.length == conditionalBodies.length)
 }
+
+/**
+ * Logical operator for LOOP statement.
+ * @param body Compound body is a collection of statements that are executed 
until the
+ *             LOOP statement is terminated by using the LEAVE statement.
+ * @param label An optional label for the loop which is unique amongst all 
labels for statements
+ *              within which the LOOP statement is contained.
+ *              If an end label is specified it must match the beginning label.
+ *              The label can be used to LEAVE or ITERATE the loop.
+ */
+case class LoopStatement(
+    body: CompoundBody,
+    label: Option[String]) extends CompoundPlanStatement
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index ba634333e06f..2972ba2db21d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -1400,6 +1400,211 @@ class SqlScriptingParserSuite extends SparkFunSuite 
with SQLHelper {
       .getText == "SELECT 42")
   }
 
+  test("loop statement") {
+    val sqlScriptText =
+      """BEGIN
+        |lbl: LOOP
+        |  SELECT 1;
+        |  SELECT 2;
+        |END LOOP lbl;
+        |END
+      """.stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val whileStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(whileStmt.body.isInstanceOf[CompoundBody])
+    assert(whileStmt.body.collection.length == 2)
+    assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
+    
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == 
"SELECT 1")
+
+    assert(whileStmt.label.contains("lbl"))
+  }
+
+  test("loop with if else block") {
+    val sqlScriptText =
+      """BEGIN
+        |lbl: LOOP
+        | IF 1 = 1 THEN
+        |   SELECT 1;
+        | ELSE
+        |   SELECT 2;
+        | END IF;
+        |END LOOP lbl;
+        |END
+      """.stripMargin
+
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 1)
+    assert(loopStmt.body.collection.head.isInstanceOf[IfElseStatement])
+    val ifStmt = loopStmt.body.collection.head.asInstanceOf[IfElseStatement]
+
+    assert(ifStmt.conditions.length == 1)
+    assert(ifStmt.conditionalBodies.length == 1)
+    assert(ifStmt.elseBody.isDefined)
+
+    assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
+    assert(ifStmt.conditions.head.getText == "1 = 1")
+
+    
assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
+    
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 1")
+
+    assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
+    assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 2")
+
+    assert(loopStmt.label.contains("lbl"))
+  }
+
+  test("nested loop") {
+    val sqlScriptText =
+      """BEGIN
+        |lbl: LOOP
+        |  LOOP
+        |    SELECT 42;
+        |  END LOOP;
+        |END LOOP lbl;
+        |END
+      """.stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 1)
+    assert(loopStmt.body.collection.head.isInstanceOf[LoopStatement])
+    val nestedLoopStmt = 
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+    assert(nestedLoopStmt.body.isInstanceOf[CompoundBody])
+    assert(nestedLoopStmt.body.collection.length == 1)
+    assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(nestedLoopStmt.body.collection.
+      head.asInstanceOf[SingleStatement].getText == "SELECT 42")
+
+    assert(loopStmt.label.contains("lbl"))
+  }
+
+  test("leave loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: LOOP
+        |    SELECT 1;
+        |    LEAVE lbl;
+        |  END LOOP;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 2)
+
+    assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText 
== "SELECT 1")
+
+    assert(loopStmt.body.collection(1).isInstanceOf[LeaveStatement])
+    assert(loopStmt.body.collection(1).asInstanceOf[LeaveStatement].label == 
"lbl")
+  }
+
+  test("iterate loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: LOOP
+        |    SELECT 1;
+        |    ITERATE lbl;
+        |  END LOOP;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 2)
+
+    assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText 
== "SELECT 1")
+
+    assert(loopStmt.body.collection(1).isInstanceOf[IterateStatement])
+    assert(loopStmt.body.collection(1).asInstanceOf[IterateStatement].label == 
"lbl")
+  }
+
+  test("leave outer loop from nested loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: LOOP
+        |    lbl2: LOOP
+        |      SELECT 1;
+        |      LEAVE lbl;
+        |    END LOOP;
+        |  END LOOP;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 1)
+
+    val nestedLoopStmt = 
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+    assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(
+      
nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == 
"SELECT 1")
+
+    assert(nestedLoopStmt.body.collection(1).isInstanceOf[LeaveStatement])
+    
assert(nestedLoopStmt.body.collection(1).asInstanceOf[LeaveStatement].label == 
"lbl")
+  }
+
+  test("iterate outer loop from nested loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: LOOP
+        |    lbl2: LOOP
+        |      SELECT 1;
+        |      ITERATE lbl;
+        |    END LOOP;
+        |  END LOOP;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[LoopStatement])
+
+    val loopStmt = tree.collection.head.asInstanceOf[LoopStatement]
+
+    assert(loopStmt.body.isInstanceOf[CompoundBody])
+    assert(loopStmt.body.collection.length == 1)
+
+    val nestedLoopStmt = 
loopStmt.body.collection.head.asInstanceOf[LoopStatement]
+
+    assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(
+      
nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == 
"SELECT 1")
+
+    assert(nestedLoopStmt.body.collection(1).isInstanceOf[IterateStatement])
+    
assert(nestedLoopStmt.body.collection(1).asInstanceOf[IterateStatement].label 
== "lbl")
+  }
+
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {
     statementStr
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index af9fd5464277..9fdb9626556f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -592,3 +592,55 @@ class IterateStatementExec(val label: String) extends 
LeafStatementExec {
   var hasBeenMatched: Boolean = false
   override def reset(): Unit = hasBeenMatched = false
 }
+
+class LoopStatementExec(
+    body: CompoundBodyExec,
+    val label: Option[String]) extends NonLeafStatementExec {
+
+  /**
+   * Loop can be interrupted by LeaveStatementExec
+   */
+  private var interrupted: Boolean = false
+
+  /**
+   * Loop can be iterated by IterateStatementExec
+   */
+  private var iterated: Boolean = false
+
+  private lazy val treeIterator =
+    new Iterator[CompoundStatementExec] {
+      override def hasNext: Boolean = !interrupted
+
+      override def next(): CompoundStatementExec = {
+        if (!body.getTreeIterator.hasNext || iterated) {
+          reset()
+        }
+
+        val retStmt = body.getTreeIterator.next()
+
+        retStmt match {
+          case leaveStatementExec: LeaveStatementExec if 
!leaveStatementExec.hasBeenMatched =>
+            if (label.contains(leaveStatementExec.label)) {
+              leaveStatementExec.hasBeenMatched = true
+            }
+            interrupted = true
+          case iterStatementExec: IterateStatementExec if 
!iterStatementExec.hasBeenMatched =>
+            if (label.contains(iterStatementExec.label)) {
+              iterStatementExec.hasBeenMatched = true
+            }
+            iterated = true
+          case _ =>
+        }
+
+        retStmt
+      }
+    }
+
+  override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+  override def reset(): Unit = {
+    interrupted = false
+    iterated = false
+    body.reset()
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 917b4d6f45ee..78ef715e1898 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, 
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, 
RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, 
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, 
LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, 
DropVariable, LogicalPlan}
 import org.apache.spark.sql.catalyst.trees.Origin
 
@@ -120,6 +120,10 @@ case class SqlScriptingInterpreter() {
           transformTreeIntoExecutable(body, 
session).asInstanceOf[CompoundBodyExec]
         new RepeatStatementExec(conditionExec, bodyExec, label, session)
 
+      case LoopStatement(body, label) =>
+        val bodyExec = transformTreeIntoExecutable(body, 
session).asInstanceOf[CompoundBodyExec]
+        new LoopStatementExec(bodyExec, label)
+
       case leaveStatement: LeaveStatement =>
         new LeaveStatementExec(leaveStatement.label)
 
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index 7c694503056a..d9d266e8a674 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -187,6 +187,7 @@ LOCK        false
 LOCKS  false
 LOGICAL        false
 LONG   false
+LOOP   false
 MACRO  false
 MAP    false
 MATCHED        false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out 
b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index 2c16d961b131..cd93a811d64f 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -187,6 +187,7 @@ LOCK        false
 LOCKS  false
 LOGICAL        false
 LONG   false
+LOOP   false
 MACRO  false
 MAP    false
 MATCHED        false
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 83d8191d01ec..baad5702f4f2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -97,6 +97,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite 
with SharedSparkSessi
       case TestLeafStatement(testVal) => testVal
       case TestIfElseCondition(_, description) => description
       case TestLoopCondition(_, _, description) => description
+      case loopStmt: LoopStatementExec => loopStmt.label.get
       case leaveStmt: LeaveStatementExec => leaveStmt.label
       case iterateStmt: IterateStatementExec => iterateStmt.label
       case _ => fail("Unexpected statement type")
@@ -669,4 +670,20 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite 
with SharedSparkSessi
     val statements = iter.map(extractStatementValue).toSeq
     assert(statements === Seq("con1", "con2"))
   }
+
+  test("loop statement with leave") {
+    val iter = new CompoundBodyExec(
+      statements = Seq(
+        new LoopStatementExec(
+          body = new CompoundBodyExec(Seq(
+            TestLeafStatement("body1"),
+            new LeaveStatementExec("lbl"))
+          ),
+          label = Some("lbl")
+        )
+      )
+    ).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("body1", "lbl"))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index ac190eb48d1f..3551608a1ee8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -1383,4 +1383,156 @@ class SqlScriptingInterpreterSuite extends QueryTest 
with SharedSparkSession {
     )
     verifySqlScriptResult(sqlScriptText, expected)
   }
+
+  test("loop statement with leave") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE x INT;
+        |  SET x = 0;
+        |  lbl: LOOP
+        |    SET x = x + 1;
+        |    SELECT x;
+        |    IF x > 2
+        |    THEN
+        |     LEAVE lbl;
+        |    END IF;
+        |  END LOOP;
+        |  SELECT x;
+        |END""".stripMargin
+    val expected = Seq(
+      Seq.empty[Row], // declare
+      Seq.empty[Row], // set x = 0
+      Seq.empty[Row], // set x = 1
+      Seq(Row(1)), // select x
+      Seq.empty[Row], // set x = 2
+      Seq(Row(2)), // select x
+      Seq.empty[Row], // set x = 3
+      Seq(Row(3)), // select x
+      Seq(Row(3)), // select x
+      Seq.empty[Row] // drop
+    )
+    verifySqlScriptResult(sqlScriptText, expected)
+  }
+
+  test("nested loop statement with leave") {
+    val commands =
+      """
+        |BEGIN
+        | DECLARE x = 0;
+        | DECLARE y = 0;
+        | lbl1: LOOP
+        |   SET VAR y = 0;
+        |   lbl2: LOOP
+        |     SELECT x, y;
+        |     SET VAR y = y + 1;
+        |     IF y >= 2 THEN
+        |       LEAVE lbl2;
+        |     END IF;
+        |   END LOOP;
+        |   SET VAR x = x + 1;
+        |   IF x >= 2 THEN
+        |     LEAVE lbl1;
+        |   END IF;
+        | END LOOP;
+        |END
+        |""".stripMargin
+
+    val expected = Seq(
+      Seq.empty[Row], // declare x
+      Seq.empty[Row], // declare y
+      Seq.empty[Row], // set y to 0
+      Seq(Row(0, 0)), // select x, y
+      Seq.empty[Row], // increase y
+      Seq(Row(0, 1)), // select x, y
+      Seq.empty[Row], // increase y
+      Seq.empty[Row], // increase x
+      Seq.empty[Row], // set y to 0
+      Seq(Row(1, 0)), // select x, y
+      Seq.empty[Row], // increase y
+      Seq(Row(1, 1)), // select x, y
+      Seq.empty[Row], // increase y
+      Seq.empty[Row], // increase x
+      Seq.empty[Row], // drop y
+      Seq.empty[Row] // drop x
+    )
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("iterate loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE x INT;
+        |  SET x = 0;
+        |  lbl: LOOP
+        |    SET x = x + 1;
+        |    IF x > 1 THEN
+        |     LEAVE lbl;
+        |    END IF;
+        |    ITERATE lbl;
+        |    SET x = x + 2;
+        |  END LOOP;
+        |  SELECT x;
+        |END""".stripMargin
+    val expected = Seq(
+      Seq.empty[Row], // declare
+      Seq.empty[Row], // set x = 0
+      Seq.empty[Row], // set x = 1
+      Seq.empty[Row], // set x = 2
+      Seq(Row(2)), // select x
+      Seq.empty[Row] // drop
+    )
+    verifySqlScriptResult(sqlScriptText, expected)
+  }
+
+  test("leave outer loop from nested loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: LOOP
+        |    lbl2: LOOP
+        |      SELECT 1;
+        |      LEAVE lbl;
+        |    END LOOP;
+        |  END LOOP;
+        |END""".stripMargin
+    val expected = Seq(
+      Seq(Row(1)) // select 1
+    )
+    verifySqlScriptResult(sqlScriptText, expected)
+  }
+
+  test("iterate outer loop from nested loop statement") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE x INT;
+        |  SET x = 0;
+        |  lbl: LOOP
+        |    SET x = x + 1;
+        |    IF x > 2 THEN
+        |     LEAVE lbl;
+        |    END IF;
+        |    lbl2: LOOP
+        |      SELECT 1;
+        |      ITERATE lbl;
+        |      SET x = 10;
+        |    END LOOP;
+        |  END LOOP;
+        |  SELECT x;
+        |END""".stripMargin
+    val expected = Seq(
+      Seq.empty[Row], // declare
+      Seq.empty[Row], // set x = 0
+      Seq.empty[Row], // set x = 1
+      Seq(Row(1)), // select 1
+      Seq.empty[Row], // set x = 2
+      Seq(Row(1)), // select 1
+      Seq.empty[Row], // set x = 3
+      Seq(Row(3)), // select x
+      Seq.empty[Row] // drop
+    )
+    verifySqlScriptResult(sqlScriptText, expected)
+  }
 }
diff --git 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index dcf3bd8c7173..60c49619552e 100644
--- 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends 
SharedThriftServer {
       val sessionHandle = client.openSession(user, "")
       val infoValue = client.getInfo(sessionHandle, 
GetInfoType.CLI_ODBC_KEYWORDS)
       // scalastyle:off line.size.limit
-      assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURREN
 [...]
+      assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURREN
 [...]
       // scalastyle:on line.size.limit
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to