This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 80cd86758a06 [SPARK-48353][FOLLOW UP][SQL] Exception handling improvements
80cd86758a06 is described below

commit 80cd86758a06dc466101eeec37669d19a85fc92a
Author: David Milicevic <[email protected]>
AuthorDate: Thu May 29 21:14:37 2025 +0800

    [SPARK-48353][FOLLOW UP][SQL] Exception handling improvements
    
    ### What changes were proposed in this pull request?
    
    This pull request proposes an improvement in how conditions are matched 
with exception handlers. Previously, a condition would match an exception 
handler only if the match was complete, i.e. the condition 
`<MAIN>.<SUBCLASS>` would match only the handler defined for 
`<MAIN>.<SUBCLASS>`. With this improvement, a `<MAIN>.<SUBCLASS>` condition 
also matches handlers defined for the `<MAIN>` condition, with the full 
condition taking precedence over the main error class.
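
A minimal sketch of the lookup order described above, assuming a plain `Map`
in place of the real handler map. `ConditionLookupSketch`, `Handlers` and
`findHandler` are hypothetical names used only for illustration, and the later
SQLSTATE-based fallbacks are left out:

```scala
// Illustrative only: a plain Map stands in for the real handler map.
object ConditionLookupSketch {
  // Maps an upper-cased condition name to some handler identifier.
  type Handlers = Map[String, String]

  def findHandler(handlers: Handlers, condition: String): Option[String] = {
    val upper = condition.toUpperCase(java.util.Locale.ROOT)
    // 1. Try the full condition first, e.g. MAIN.SUBCLASS.
    handlers.get(upper).orElse {
      // 2. Otherwise fall back to the main error class (the part before the first dot).
      if (upper.contains('.')) handlers.get(upper.split('.').head) else None
    }
  }
}
```

With handlers registered for both `UNRESOLVED_COLUMN` and
`UNRESOLVED_COLUMN.WITH_SUGGESTION`, a `WITH_SUGGESTION` error resolves to the
more specific handler, while a `WITHOUT_SUGGESTION` error falls back to the
handler for the main error class.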
    
    This pull request also proposes fixes for a number of missing 
pieces:
    - `CompoundBodyExec.reset()` is not resetting the `scopeStatus` field.
    - Exception handler body is never reset (this includes internal iterator 
reset). This causes issues if the handler is defined in a loop and gets matched 
multiple times.
    - When searching for a place to inject `LEAVE` statement, after the `EXIT` 
handler has been executed, the logic considered only `CompoundBodyExec` nodes, 
whereas it should have included all non-leaf statements.
    - Exception handling for exceptions that happen in conditions (for each 
statement type - if/else, while, case, etc) is not working properly because the 
injected `LEAVE` statement is not recognized properly.
    - Exception handling for exceptions that happen in the last statement of 
the body of if/else and searched case statement is not working properly because 
the injected `LEAVE` statement is not recognized properly.
    - `hasNext()` in `ForStatementExec` is executing the query. This causes 
issues when `FOR` is the first statement in the compound and the query 
fails: the exception is thrown during the `hasNext()` check 
instead of the actual iteration. In that case, the parent compound (of the 
`FOR` statement) is not entered (because that happens during the iteration) and 
the call stack is not properly set up for exception handling.
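
The `LEAVE` injection item in the list above can be pictured with a small
sketch, assuming a toy statement tree (`InjectLeaveSketch`, `Stmt`, `Leaf`,
`Leave` and `NonLeaf` are hypothetical stand-ins, not the real execution
nodes). The descent follows `curr` through every non-leaf node rather than
through one specific node type:

```scala
// Illustrative only: a toy tree in which every non-leaf node exposes `curr`.
object InjectLeaveSketch {
  sealed trait Stmt
  final case class Leaf(name: String) extends Stmt
  final case class Leave(label: String) extends Stmt
  // Stands in for any non-leaf statement (compound body, IF/ELSE, WHILE, ...).
  final class NonLeaf(var curr: Option[Stmt]) extends Stmt

  def injectLeave(root: NonLeaf, label: String): Unit = {
    var node = root
    var descending = true
    // Go as deep as possible while the current statement is itself a non-leaf node.
    while (descending) {
      node.curr match {
        case Some(child: NonLeaf) => node = child
        case _ => descending = false
      }
    }
    // Replace whatever would have been executed next with the LEAVE statement.
    node.curr = Some(Leave(label))
  }
}
```

If the descent stopped at the first node that is not a compound body, a
`LEAVE` meant for a statement nested inside, say, an IF/ELSE would land one
level too high.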
    
    Changes corresponding to these problems, in the same order:
    - Reset the `scopeStatus` field in `CompoundBodyExec.reset()`.
    - Call `reset()` on handler before starting its execution in 
`SqlScriptingExecution.handleException()`.
    - Move `curr` pointer to `NonLeafStatementExec` and use that in 
`SqlScriptingExecution.getNextStatement()` in the case when an `EXIT` handler 
has finished and is removed from the stack.
    - Add special case handling for `LeaveStatementExec` in `*.Condition` cases 
in all of the relevant statement types. When exception happens during the 
condition execution, the `LeaveStatementExec` is injected into `curr` field, 
but the state hasn't changed. This means that when the `LEAVE` statement is to 
be executed, the state would correspond to the condition. **I don't know how to 
do this better, so any suggestions are more than welcome!**
    - Add special case handling for `LeaveStatementExec` in `*Body` cases of 
if/else and searched case statement - equivalent to the previous bullet.
    - Reorder conditions in `ForStatementExec.hasNext()`.
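
A minimal sketch of why the order of conditions in `hasNext()` matters,
assuming a lazily evaluated query (`HasNextOrderSketch`, `LazyQuery` and both
helper methods are hypothetical names). Because `||` short-circuits, putting
the first-iteration flag first means the query is not run by the check itself:

```scala
// Illustrative only: models a query whose result is computed on first access.
object HasNextOrderSketch {
  final class LazyQuery(run: () => Iterator[Int]) {
    // Records whether anything tried to execute the query.
    var executed = false
    private lazy val result: Iterator[Int] = { executed = true; run() }
    def hasNext: Boolean = result.hasNext
  }

  // Query first: the check itself runs the query and may throw.
  def hasNextQueryFirst(q: LazyQuery, firstIteration: Boolean): Boolean =
    q.hasNext || firstIteration

  // Flag first: on the first iteration the query is not touched by the check.
  def hasNextFlagFirst(q: LazyQuery, firstIteration: Boolean): Boolean =
    firstIteration || q.hasNext
}
```

With a failing query and `firstIteration = true`, the flag-first variant
returns `true` without executing anything, so the failure surfaces later,
during the actual iteration.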
    
    ### Why are the changes needed?
    
    These changes are fixing wrong logic and improving some of the already 
existing exception handling mechanisms.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests are added to test/guard all of the introduced 
improvements.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51034 from davidm-db/exception_handling_improvements.
    
    Authored-by: David Milicevic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../plans/logical/SqlScriptingLogicalPlans.scala   |   1 +
 .../sql/scripting/SqlScriptingExecution.scala      |  24 +-
 .../scripting/SqlScriptingExecutionContext.scala   |   9 +
 .../sql/scripting/SqlScriptingExecutionNode.scala  | 319 ++++++++++-----
 .../sql/scripting/SqlScriptingExecutionSuite.scala | 454 +++++++++++++++++++++
 5 files changed, 696 insertions(+), 111 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index 2073b296a0dc..fdd5524b536e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -292,6 +292,7 @@ case class SimpleCaseStatement(
     conditionExpressions: Seq[Expression],
     conditionalBodies: Seq[CompoundBody],
     elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+  assert(conditionExpressions.nonEmpty)
   assert(conditionExpressions.length == conditionalBodies.length)
 
   override def output: Seq[Attribute] = Seq.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 362f3e51d3df..826b7a8834cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -71,6 +71,21 @@ class SqlScriptingExecution(
     contextManagerHandle.runWith(f)
   }
 
+  /**
+   * Helper method to inject leave statement into the execution plan.
+   * @param executionPlan Execution plan to inject leave statement into.
+   * @param label Label of the leave statement.
+   */
+  private def injectLeaveStatement(executionPlan: NonLeafStatementExec, label: String): Unit = {
+    // Go as deep as possible, to find a leaf node. Instead of a statement that
+    //   should be executed next, inject LEAVE statement in its place.
+    var currExecPlan = executionPlan
+    while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
+      currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
+    }
+    currExecPlan.curr = Some(new LeaveStatementExec(label))
+  }
+
   /** Helper method to iterate get next statements from the first available frame. */
   private def getNextStatement: Option[CompoundStatementExec] = {
     // Remove frames that are already executed.
@@ -94,12 +109,8 @@ class SqlScriptingExecution(
           && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
           context.firstHandlerScopeLabel = None
         }
-
-        var execPlan: CompoundBodyExec = context.frames.last.executionPlan
-        while (execPlan.curr.exists(_.isInstanceOf[CompoundBodyExec])) {
-          execPlan = execPlan.curr.get.asInstanceOf[CompoundBodyExec]
-        }
-        execPlan.curr = Some(new LeaveStatementExec(lastFrame.scopeLabel.get))
+        // Inject leave statement into the execution plan of the last frame.
+        injectLeaveStatement(context.frames.last.executionPlan, lastFrame.scopeLabel.get)
       }
     }
     // If there are still frames available, get the next statement.
@@ -164,6 +175,7 @@ class SqlScriptingExecution(
         context.frames.append(
           handlerFrame
         )
+        handler.reset()
         handlerFrame.executionPlan.enterScope()
       case None =>
         throw e.asInstanceOf[Throwable]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index bfd5a4b43711..e1c139addd34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -205,6 +205,15 @@ class SqlScriptingExecutionScope(
 
     errorHandler = triggerToExceptionHandlerMap.getHandlerForCondition(uppercaseCondition)
 
+    if (errorHandler.isEmpty) {
+      if (uppercaseCondition.contains('.')) {
+        // If the condition contains a dot, it has a main error class and a subclass.
+        // Check if the error class is defined in the triggerToExceptionHandlerMap.
+        val errorClass = uppercaseCondition.split('.').head
+        errorHandler = triggerToExceptionHandlerMap.getHandlerForCondition(errorClass)
+      }
+    }
+
     if (errorHandler.isEmpty) {
       // Check if there is a specific handler for the given SQLSTATE.
       errorHandler = triggerToExceptionHandlerMap.getHandlerForSqlState(uppercaseSqlState)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 73b463cdd249..c95ef72a2b31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -58,6 +58,9 @@ trait LeafStatementExec extends CompoundStatementExec
  */
 trait NonLeafStatementExec extends CompoundStatementExec {
 
+  /** Pointer to the current statement - i.e. the statement that should be iterated next. */
+  protected[scripting] var curr: Option[CompoundStatementExec]
+
   /**
    * Construct the iterator to traverse the tree rooted at this node in an in-order traversal.
    * @return
@@ -244,7 +247,7 @@ class CompoundBodyExec(
   }
 
   private var localIterator = statements.iterator
-  private[scripting] var curr: Option[CompoundStatementExec] =
+  protected[scripting] var curr: Option[CompoundStatementExec] =
     if (localIterator.hasNext) Some(localIterator.next()) else None
   private var scopeStatus = ScopeStatus.NOT_ENTERED
 
@@ -406,7 +409,7 @@ class IfElseStatementExec(
   }
 
   private var state = IfElseState.Condition
-  private var curr: Option[CompoundStatementExec] = Some(conditions.head)
+  protected[scripting] var curr: Option[CompoundStatementExec] = Some(conditions.head)
 
   private var clauseIdx: Int = 0
   private val conditionsCount = conditions.length
@@ -416,36 +419,47 @@ class IfElseStatementExec(
     new Iterator[CompoundStatementExec] {
       override def hasNext: Boolean = curr.nonEmpty
 
-      override def next(): CompoundStatementExec = state match {
-        case IfElseState.Condition =>
-          val condition = curr.get.asInstanceOf[SingleStatementExec]
-          if (evaluateBooleanCondition(session, condition)) {
-            state = IfElseState.Body
-            curr = Some(conditionalBodies(clauseIdx))
-          } else {
-            clauseIdx += 1
-            if (clauseIdx < conditionsCount) {
-              // There are ELSEIF clauses remaining.
-              state = IfElseState.Condition
-              curr = Some(conditions(clauseIdx))
-            } else if (elseBody.isDefined) {
-              // ELSE clause exists.
+      override def next(): CompoundStatementExec = {
+        if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
+          // Handling two cases when an exception is thrown:
+          //   1. During condition evaluation - exception handling mechanism will replace condition
+          //     with the appropriate LEAVE statement if the relevant condition handler was found.
+          //   2. In the last statement of the body - curr would already be set to None when
+          //     LEAVE statement is injected to it (i.e. LEAVE statement would replace None).
+          return curr.get
+        }
+
+        state match {
+          case IfElseState.Condition =>
+            val condition = curr.get.asInstanceOf[SingleStatementExec]
+            if (evaluateBooleanCondition(session, condition)) {
               state = IfElseState.Body
-              curr = Some(elseBody.get)
+              curr = Some(conditionalBodies(clauseIdx))
             } else {
-              // No remaining clauses.
+              clauseIdx += 1
+              if (clauseIdx < conditionsCount) {
+                // There are ELSEIF clauses remaining.
+                state = IfElseState.Condition
+                curr = Some(conditions(clauseIdx))
+              } else if (elseBody.isDefined) {
+                // ELSE clause exists.
+                state = IfElseState.Body
+                curr = Some(elseBody.get)
+              } else {
+                // No remaining clauses.
+                curr = None
+              }
+            }
+            condition
+          case IfElseState.Body =>
+            assert(curr.get.isInstanceOf[CompoundBodyExec])
+            val currBody = curr.get.asInstanceOf[CompoundBodyExec]
+            val retStmt = currBody.getTreeIterator.next()
+            if (!currBody.getTreeIterator.hasNext) {
               curr = None
             }
-          }
-          condition
-        case IfElseState.Body =>
-          assert(curr.get.isInstanceOf[CompoundBodyExec])
-          val currBody = curr.get.asInstanceOf[CompoundBodyExec]
-          val retStmt = currBody.getTreeIterator.next()
-          if (!currBody.getTreeIterator.hasNext) {
-            curr = None
-          }
-          retStmt
+            retStmt
+        }
       }
     }
 
@@ -479,7 +493,7 @@ class WhileStatementExec(
   }
 
   private var state = WhileState.Condition
-  private var curr: Option[CompoundStatementExec] = Some(condition)
+  protected[scripting] var curr: Option[CompoundStatementExec] = Some(condition)
 
   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
@@ -487,25 +501,32 @@ class WhileStatementExec(
 
       override def next(): CompoundStatementExec = state match {
           case WhileState.Condition =>
-            val condition = curr.get.asInstanceOf[SingleStatementExec]
-            if (evaluateBooleanCondition(session, condition)) {
-              state = WhileState.Body
-              curr = Some(body)
-              body.reset()
-            } else {
-              curr = None
+            curr match {
+              case Some(leaveStatement: LeaveStatementExec) =>
+                // Handling the case when condition evaluation throws an exception. Exception
+                //   handling mechanism will replace condition with the appropriate LEAVE statement
+                //   if the relevant condition handler was found.
+                handleLeaveStatement(leaveStatement)
+                leaveStatement
+              case Some(condition: SingleStatementExec) =>
+                if (evaluateBooleanCondition(session, condition)) {
+                  state = WhileState.Body
+                  curr = Some(body)
+                  body.reset()
+                } else {
+                  curr = None
+                }
+                condition
+              case _ =>
+                throw SparkException.internalError("Unexpected statement type in WHILE condition.")
             }
-            condition
           case WhileState.Body =>
             val retStmt = body.getTreeIterator.next()
 
             // Handle LEAVE or ITERATE statement if it has been encountered.
             retStmt match {
               case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
-                if (label.contains(leaveStatementExec.label)) {
-                  leaveStatementExec.hasBeenMatched = true
-                }
-                curr = None
+                handleLeaveStatement(leaveStatementExec)
                 return retStmt
               case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
                 if (label.contains(iterStatementExec.label)) {
@@ -535,6 +556,13 @@ class WhileStatementExec(
     condition.reset()
     body.reset()
   }
+
+  private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
+    if (label.contains(leaveStatement.label)) {
+      leaveStatement.hasBeenMatched = true
+    }
+    curr = None
+  }
 }
 
 /**
@@ -555,7 +583,7 @@ class SearchedCaseStatementExec(
   }
 
   private var state = CaseState.Condition
-  private var curr: Option[CompoundStatementExec] = Some(conditions.head)
+  protected[scripting] var curr: Option[CompoundStatementExec] = Some(conditions.head)
 
   private var clauseIdx: Int = 0
   private val conditionsCount = conditions.length
@@ -564,36 +592,47 @@ class SearchedCaseStatementExec(
     new Iterator[CompoundStatementExec] {
       override def hasNext: Boolean = curr.nonEmpty
 
-      override def next(): CompoundStatementExec = state match {
-        case CaseState.Condition =>
-          val condition = curr.get.asInstanceOf[SingleStatementExec]
-          if (evaluateBooleanCondition(session, condition)) {
-            state = CaseState.Body
-            curr = Some(conditionalBodies(clauseIdx))
-          } else {
-            clauseIdx += 1
-            if (clauseIdx < conditionsCount) {
-              // There are WHEN clauses remaining.
-              state = CaseState.Condition
-              curr = Some(conditions(clauseIdx))
-            } else if (elseBody.isDefined) {
-              // ELSE clause exists.
+      override def next(): CompoundStatementExec = {
+        if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
+          // Handling two cases when an exception is thrown:
+          //   1. During condition evaluation - exception handling mechanism will replace condition
+          //     with the appropriate LEAVE statement if the relevant condition handler was found.
+          //   2. In the last statement of the body - curr would already be set to None when
+          //     LEAVE statement is injected to it (i.e. LEAVE statement would replace None).
+          return curr.get
+        }
+
+        state match {
+          case CaseState.Condition =>
+            val condition = curr.get.asInstanceOf[SingleStatementExec]
+            if (evaluateBooleanCondition(session, condition)) {
               state = CaseState.Body
-              curr = Some(elseBody.get)
+              curr = Some(conditionalBodies(clauseIdx))
             } else {
-              // No remaining clauses.
+              clauseIdx += 1
+              if (clauseIdx < conditionsCount) {
+                // There are WHEN clauses remaining.
+                state = CaseState.Condition
+                curr = Some(conditions(clauseIdx))
+              } else if (elseBody.isDefined) {
+                // ELSE clause exists.
+                state = CaseState.Body
+                curr = Some(elseBody.get)
+              } else {
+                // No remaining clauses.
+                curr = None
+              }
+            }
+            condition
+          case CaseState.Body =>
+            assert(curr.get.isInstanceOf[CompoundBodyExec])
+            val currBody = curr.get.asInstanceOf[CompoundBodyExec]
+            val retStmt = currBody.getTreeIterator.next()
+            if (!currBody.getTreeIterator.hasNext) {
               curr = None
             }
-          }
-          condition
-        case CaseState.Body =>
-          assert(curr.get.isInstanceOf[CompoundBodyExec])
-          val currBody = curr.get.asInstanceOf[CompoundBodyExec]
-          val retStmt = currBody.getTreeIterator.next()
-          if (!currBody.getTreeIterator.hasNext) {
-            curr = None
-          }
-          retStmt
+            retStmt
+        }
       }
     }
 
@@ -633,6 +672,8 @@ class SimpleCaseStatementExec(
   private var state = CaseState.Condition
   private var bodyExec: Option[CompoundBodyExec] = None
 
+  protected[scripting] var curr: Option[CompoundStatementExec] = None
+
   private var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _
   private var caseVariableLiteral: Literal = _
 
@@ -661,22 +702,45 @@ class SimpleCaseStatementExec(
   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
       override def hasNext: Boolean = state match {
-        case CaseState.Condition => cachedConditionBodyIterator.hasNext || elseBody.isDefined
+        case CaseState.Condition =>
+          // Equivalent to the "iteration hasn't started yet" - to avoid computing cache
+          //   before the first actual iteration.
+          curr.isEmpty ||
+          // Special case when condition computation throws an exception.
+          curr.exists(_.isInstanceOf[LeaveStatementExec]) ||
+          // Regular conditions.
+          cachedConditionBodyIterator.hasNext ||
+          elseBody.isDefined
         case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
       }
 
       override def next(): CompoundStatementExec = state match {
         case CaseState.Condition =>
-          cachedConditionBodyIterator.nextOption()
+          if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
+            // Handling the case when condition evaluation throws an exception. Exception handling
+            //   mechanism will replace condition with the appropriate LEAVE statement if the
+            //   relevant condition handler was found.
+            return curr.get
+          }
+
+          val nextOption = if (cachedConditionBodyIterator.hasNext) {
+            Some(cachedConditionBodyIterator.next())
+          } else {
+            None
+          }
+          nextOption
             .map { case (condStmt, body) =>
+              curr = Some(condStmt)
               if (evaluateBooleanCondition(session, condStmt)) {
                 bodyExec = Some(body)
+                curr = bodyExec
                 state = CaseState.Body
               }
               condStmt
             }
             .orElse(elseBody.map { body => {
               bodyExec = Some(body)
+              curr = bodyExec
               state = CaseState.Body
               next()
             }})
@@ -714,6 +778,8 @@ class SimpleCaseStatementExec(
 
   override def reset(): Unit = {
     state = CaseState.Condition
+    bodyExec = None
+    curr = None
     isCacheValid = false
     caseVariableExec.reset()
     conditionalBodies.foreach(b => b.reset())
@@ -740,7 +806,7 @@ class RepeatStatementExec(
   }
 
   private var state = RepeatState.Body
-  private var curr: Option[CompoundStatementExec] = Some(body)
+  protected[scripting] var curr: Option[CompoundStatementExec] = Some(body)
 
   private lazy val treeIterator: Iterator[CompoundStatementExec] =
     new Iterator[CompoundStatementExec] {
@@ -748,24 +814,31 @@ class RepeatStatementExec(
 
       override def next(): CompoundStatementExec = state match {
         case RepeatState.Condition =>
-          val condition = curr.get.asInstanceOf[SingleStatementExec]
-          if (!evaluateBooleanCondition(session, condition)) {
-            state = RepeatState.Body
-            curr = Some(body)
-            body.reset()
-          } else {
-            curr = None
+          curr match {
+            case Some(leaveStatement: LeaveStatementExec) =>
+              // Handling the case when condition evaluation throws an exception. Exception
+              //   handling mechanism will replace condition with the appropriate LEAVE statement
+              //   if the relevant condition handler was found.
+              handleLeaveStatement(leaveStatement)
+              leaveStatement
+            case Some(condition: SingleStatementExec) =>
+              if (!evaluateBooleanCondition(session, condition)) {
+                state = RepeatState.Body
+                curr = Some(body)
+                body.reset()
+              } else {
+                curr = None
+              }
+              condition
+            case _ =>
+              throw SparkException.internalError("Unexpected statement type in REPEAT condition.")
           }
-          condition
         case RepeatState.Body =>
           val retStmt = body.getTreeIterator.next()
 
           retStmt match {
             case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
-              if (label.contains(leaveStatementExec.label)) {
-                leaveStatementExec.hasBeenMatched = true
-              }
-              curr = None
+              handleLeaveStatement(leaveStatementExec)
               return retStmt
             case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
               if (label.contains(iterStatementExec.label)) {
@@ -795,6 +868,13 @@ class RepeatStatementExec(
     body.reset()
     condition.reset()
   }
+
+  private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
+    if (label.contains(leaveStatement.label)) {
+      leaveStatement.hasBeenMatched = true
+    }
+    curr = None
+  }
 }
 
 /**
@@ -846,6 +926,8 @@ class LoopStatementExec(
     body: CompoundBodyExec,
     val label: Option[String]) extends NonLeafStatementExec {
 
+  protected[scripting] var curr: Option[CompoundStatementExec] = Some(body)
+
   /**
    * Loop can be interrupted by LeaveStatementExec
    */
@@ -927,7 +1009,9 @@ class ForStatementExec(
     queryResult
   }
 
-  private var bodyWithVariables: CompoundBodyExec = null
+  protected[scripting] var curr: Option[CompoundStatementExec] = None
+
+  private var bodyWithVariables: Option[CompoundBodyExec] = None
 
   /**
    * For can be interrupted by LeaveStatementExec
@@ -943,15 +1027,29 @@ class ForStatementExec(
     new Iterator[CompoundStatementExec] {
 
       override def hasNext: Boolean = !interrupted && (state match {
-          case ForState.VariableAssignment => cachedQueryResult().hasNext || firstIteration
-          case ForState.Body => bodyWithVariables.getTreeIterator.hasNext
-        })
+        // `firstIteration` NEEDS to be the first condition! This is to handle edge-cases when
+        //   query fails with an exception. If the `cachedQueryResult().hasNext` is first, this
+        //   would mean that exception would be thrown before the scope of the parent (which is
+        //   of CompoundBodyExec type) of the FOR statement is entered (required for proper
+        //   exception handling). This can happen in a case when FOR statement is a first
+        //   statement in the compound.
+        case ForState.VariableAssignment => firstIteration || cachedQueryResult().hasNext
+        case ForState.Body => bodyWithVariables.exists(_.getTreeIterator.hasNext)
+      })
 
-      @scala.annotation.tailrec
       override def next(): CompoundStatementExec = state match {
 
         case ForState.VariableAssignment =>
-          // If result set is empty and we are on the first iteration, we return NO-OP statement
+          if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
+            // Handling the case when condition evaluation throws an exception. Exception handling
+            //   mechanism will replace condition with the appropriate LEAVE statement if the
+            //   relevant condition handler was found.
+            val leaveStatement = curr.get.asInstanceOf[LeaveStatementExec]
+            handleLeaveStatement(leaveStatement)
+            return leaveStatement
+          }
+
+          // If result set is empty, and we are on the first iteration, we return NO-OP statement
           // to prevent compound statements from not having anything to return. For example,
           // if a FOR statement is nested in REPEAT, REPEAT will assume that FOR has at least
           // one statement to return. In the case the result set is empty, FOR doesn't have
@@ -979,7 +1077,7 @@ class ForStatementExec(
             }
           ).orElse(Some(UUID.randomUUID().toString.toLowerCase(Locale.ROOT)))
 
-          bodyWithVariables = new CompoundBodyExec(
+          bodyWithVariables = Some(new CompoundBodyExec(
             // NoOpStatementExec appended to end of body to prevent
             // dropping variables before last statement is executed.
             // This is necessary because we are calling exitScope before returning the last
@@ -991,26 +1089,23 @@ class ForStatementExec(
             isScope = true,
             context = context,
             triggerToExceptionHandlerMap = TriggerToExceptionHandlerMap.createEmptyMap()
-          )
+          ))
 
           state = ForState.Body
-          bodyWithVariables.reset()
-          bodyWithVariables.enterScope()
+          bodyWithVariables.foreach(_.reset())
+          bodyWithVariables.foreach(_.enterScope())
+          curr = bodyWithVariables
           next()
 
         case ForState.Body =>
-          val retStmt = bodyWithVariables.getTreeIterator.next()
+          // `bodyWithVariables` must be defined at this point.
+          assert(bodyWithVariables.isDefined)
+          val retStmt = bodyWithVariables.get.getTreeIterator.next()
 
           // Handle LEAVE or ITERATE statement if it has been encountered.
           retStmt match {
             case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
-              if (label.contains(leaveStatementExec.label)) {
-                leaveStatementExec.hasBeenMatched = true
-              }
-              interrupted = true
-              // If this for statement encounters LEAVE, we need to exit the scope, as
-              // we will not reach the point where we usually exit it.
-              bodyWithVariables.exitScope()
+              handleLeaveStatement(leaveStatementExec)
               return retStmt
             case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
               if (label.contains(iterStatementExec.label)) {
@@ -1018,21 +1113,32 @@ class ForStatementExec(
               } else {
                 // If an outer loop is being iterated, we need to exit the scope, as
                 // we will not reach the point where we usually exit it.
-                bodyWithVariables.exitScope()
+                bodyWithVariables.foreach(_.exitScope())
               }
               state = ForState.VariableAssignment
               return retStmt
             case _ =>
           }
 
-          if (!bodyWithVariables.getTreeIterator.hasNext) {
-            bodyWithVariables.exitScope()
+          if (!bodyWithVariables.exists(_.getTreeIterator.hasNext)) {
+            bodyWithVariables.foreach(_.exitScope())
+            curr = None
             state = ForState.VariableAssignment
           }
           retStmt
       }
     }
 
+    private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
+      if (label.contains(leaveStatement.label)) {
+        leaveStatement.hasBeenMatched = true
+      }
+      interrupted = true
+      // If this for statement encounters LEAVE, we need to exit the scope, as
+      // we will not reach the point where we usually exit it.
+      bodyWithVariables.foreach(_.exitScope())
+    }
+
   /**
    * Recursively creates a Catalyst expression from Scala value.<br>
    * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings
@@ -1093,7 +1199,8 @@ class ForStatementExec(
     state = ForState.VariableAssignment
     isResultCacheValid = false
     interrupted = false
-    bodyWithVariables = null
+    curr = None
+    bodyWithVariables = None
     firstIteration = true
   }
 }
@@ -1109,6 +1216,8 @@ class ExceptionHandlerExec(
     val handlerType: ExceptionHandlerType,
     val scopeLabel: Option[String]) extends NonLeafStatementExec {
 
+  protected[scripting] var curr: Option[CompoundStatementExec] = body.curr
+
   override def getTreeIterator: Iterator[CompoundStatementExec] = body.getTreeIterator
 
   override def reset(): Unit = body.reset()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 68c6d9607d32..3c0bb4020419 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -625,6 +625,264 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
       parameters = Map("sqlState" -> "X22012"))
   }
 
+  test("handler - correct handler is chosen based on the full error condition") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN.WITHOUT_SUGGESTION
+        |  BEGIN
+        |    SELECT 2;
+        |  END;
+        |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN.WITH_SUGGESTION
+        |  BEGIN
+        |    SELECT 3;
+        |  END;
+        |  SELECT X;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(2)) // select
+    )
+    verifySqlScriptResult(sqlScript, expected)
+  }
+
+  test("handler - full condition takes precedence before main error class") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN
+          |  BEGIN
+          |    SELECT 1;
+          |  END;
+          |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN.WITH_SUGGESTION
+          |  BEGIN
+          |    SELECT 2;
+          |  END;
+          |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |  SELECT d FROM t;
+          |END
+          |""".stripMargin
+      val expected = Seq(
+        Seq(Row(2)) // select
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
+
+  test("handler - catch the main error class without subclass") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          |  DECLARE EXIT HANDLER FOR UNRESOLVED_COLUMN
+          |  BEGIN
+          |    SELECT 1;
+          |  END;
+          |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+          |  SELECT d FROM t;
+          |END
+          |""".stripMargin
+      val expected = Seq(
+        Seq(Row(1)) // select
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
+
+  test("handler - exit resolve when if condition fails") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    IF 1 > 1/0 THEN
+        |      SELECT 10;
+        |    END IF;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when simple case variable computation fails") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    CASE 1/0
+        |      WHEN flag THEN SELECT 10;
+        |    END CASE;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when simple case condition computation fails") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    CASE flag
+        |      WHEN 1/0 THEN SELECT 10;
+        |    END CASE;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when simple case condition types are mismatch") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR CAST_INVALID_INPUT
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    CASE flag
+        |      WHEN 'teststr' THEN SELECT 10;
+        |    END CASE;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when searched case condition fails") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    CASE
+        |      WHEN flag = 1/0 THEN SELECT 10;
+        |    END CASE;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when while condition fails") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    WHILE 1 > 1/0 DO
+        |      SELECT 10;
+        |    END WHILE;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("handler - exit resolve when select fails in FOR statement") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE VARIABLE flag INT = -1;
+        |  scope_to_exit: BEGIN
+        |    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+        |    BEGIN
+        |      SELECT flag;
+        |      SET flag = 1;
+        |    END;
+        |    FOR iter AS (SELECT 1/0) DO
+        |      SELECT 10;
+        |    END FOR;
+        |    SELECT 4;
+        |    SELECT 5;
+        |  END;
+        |  SELECT flag;
+        |END
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(-1)),   // select flag
+      Seq(Row(1))     // select flag from the outer body
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
   // Tests
   test("multi statement - simple") {
     withTable("t") {
@@ -2485,4 +2743,200 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
     )
     verifySqlScriptResult(sqlScript, expected = expected)
   }
+
+  test("Exception handler in a FOR loop - with condition") {
+    withTable("t") {
+      withView("v") {
+        val sqlScript =
+          """
+            |BEGIN
+            |  DECLARE cnt = 0;
+            |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+            |  CREATE VIEW v AS SELECT a, b FROM t WHERE c > 0;
+            |  FOR tables AS (SELECT * FROM VALUES ('v'), ('v') AS tbl(name)) DO
+            |    lbl: BEGIN
+            |      DECLARE EXIT HANDLER FOR EXPECT_TABLE_NOT_VIEW.NO_ALTERNATIVE
+            |         BEGIN SET cnt = cnt + 1; END;
+            |      ALTER TABLE IDENTIFIER(tables.name) DEFAULT COLLATION UTF8_LCASE;
+            |    END;
+            |  END FOR;
+            |  SELECT cnt;
+            |END
+            |""".stripMargin
+        val expected = Seq(
+          Seq(Row(2)) // select cnt
+        )
+        verifySqlScriptResult(sqlScript, expected)
+      }
+    }
+  }
+
+  test("Exception handler in a FOR loop - with SQL state") {
+    withTable("t") {
+      withView("v") {
+        val sqlScript =
+          """
+            |BEGIN
+            |  DECLARE cnt = 0;
+            |  CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+            |  CREATE VIEW v AS SELECT a, b FROM t WHERE c > 0;
+            |  FOR tables AS (SELECT * FROM VALUES ('v'), ('v') AS tbl(name)) DO
+            |    BEGIN
+            |      DECLARE EXIT HANDLER FOR SQLSTATE '42809'
+            |         BEGIN SET cnt = cnt + 1; END;
+            |      ALTER TABLE IDENTIFIER(tables.name) DEFAULT COLLATION UTF8_LCASE;
+            |    END;
+            |  END FOR;
+            |  SELECT cnt;
+            |END
+            |""".stripMargin
+        val expected = Seq(
+          Seq(Row(2)) // select cnt
+        )
+        verifySqlScriptResult(sqlScript, expected)
+      }
+    }
+  }
+
+  test("Exception in a last statement in if/else") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  IF true THEN
+        |    SELECT 1/0;
+        |  END IF;
+        |  SELECT 2;
+        |END;
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(1))     // select 1
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("Exception in a last statement in simple case") {
+    val commands =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  CASE 1
+        |    WHEN 1 THEN
+        |      SELECT 1/0;
+        |  END CASE;
+        |  SELECT 2;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(1)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("Exception in a last statement in searched case") {
+    val commands =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  CASE
+        |    WHEN 1=1 THEN
+        |      SELECT 1/0;
+        |  END CASE;
+        |  SELECT 2;
+        |END
+        |""".stripMargin
+    val expected = Seq(Seq(Row(1)))
+    verifySqlScriptResult(commands, expected)
+  }
+
+  test("Exception in a last statement in - while") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE i INT DEFAULT 0;
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  WHILE i < 2 DO
+        |    SELECT 1/0;
+        |  END WHILE;
+        |  SELECT 2;
+        |END;
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(1))     // select 1
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("Exception in a last statement in - repeat") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE i INT DEFAULT 0;
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  REPEAT
+        |    SELECT 1/0;
+        |  UNTIL i = 2
+        |  END REPEAT;
+        |  SELECT 2;
+        |END;
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(1))     // select 1
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("Exception in a last statement in loop") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  LOOP
+        |    SELECT 1/0;
+        |  END LOOP;
+        |  SELECT 2;
+        |END;
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(1))     // select 1
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
+
+  test("Exception in a last statement in for") {
+    val sqlScript =
+      """
+        |BEGIN
+        |  DECLARE i INT DEFAULT 0;
+        |  DECLARE EXIT HANDLER FOR SQLEXCEPTION
+        |  BEGIN
+        |    SELECT 1;
+        |  END;
+        |  FOR row AS (SELECT * FROM VALUES (1), (2), (3) AS tbl(i)) DO
+        |    SELECT 1/0;
+        |  END FOR;
+        |  SELECT 2;
+        |END;
+        |""".stripMargin
+    val expected = Seq(
+      Seq(Row(1))     // select 1
+    )
+    verifySqlScriptResult(sqlScript, expected = expected)
+  }
 }


