This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d5729f0d52c9 [SPARK-53621][CORE] Adding Support for Executing CONTINUE HANDLER
d5729f0d52c9 is described below
commit d5729f0d52c98fc1e62162c38c9d99037d7c17db
Author: Teodor Djelic <[email protected]>
AuthorDate: Wed Oct 1 01:18:02 2025 +0800
[SPARK-53621][CORE] Adding Support for Executing CONTINUE HANDLER
### What changes were proposed in this pull request?
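- Added support for executing CONTINUE exception handlers in SQL scripting.
- Extended the existing exception-handling framework to support both EXIT and
CONTINUE handler types.
- Added a `ConditionalStatementExec` trait with an `interrupted` flag. The IF,
WHILE, REPEAT and FOR nodes, and both simple and searched CASE nodes, now
extend it. This lets a CONTINUE handler skip a conditional statement entirely
when that statement itself (e.g. its condition) raised the exception.
- Replaced the single `HANDLER` frame type with `EXIT_HANDLER` and
`CONTINUE_HANDLER` so the two handler kinds can be distinguished.
- Extended test coverage with CONTINUE handler scenarios.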
The feature is gated behind the `spark.sql.scripting.continueHandlerEnabled`
feature flag (`SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED`), defined in
`SQLConf.scala`.
### Why are the changes needed?
This PR is part of a series adding support for `CONTINUE HANDLER`s. A
follow-up PR will add more tests.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
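- Added unit tests in `SqlScriptingExecutionSuite.scala` covering various
CONTINUE handler scenarios.
- Added an E2E test in `SqlScriptingE2eSuite.scala` demonstrating CONTINUE
handler functionality.
- Tests cover duplicate-handler detection for all EXIT/CONTINUE pairings
(EXIT/EXIT, EXIT/CONTINUE, CONTINUE/EXIT, CONTINUE/CONTINUE), both by
condition name and by SQLSTATE.
- Tests verify that execution resumes at the correct statement after a
CONTINUE handler runs. This includes exceptions raised by IF, simple CASE,
searched CASE, WHILE and REPEAT conditions, by a FOR query, and inside a FOR
body.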
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52371 from TeodorDjelic/executing-continue-handlers.
Authored-by: Teodor Djelic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/scripting/SqlScriptingExecution.scala | 47 +-
.../scripting/SqlScriptingExecutionContext.scala | 9 +-
.../sql/scripting/SqlScriptingExecutionNode.scala | 42 +-
.../sql/scripting/SqlScriptingInterpreter.scala | 12 +-
.../spark/sql/scripting/SqlScriptingE2eSuite.scala | 43 +-
.../sql/scripting/SqlScriptingExecutionSuite.scala | 509 ++++++++++++++++++++-
6 files changed, 618 insertions(+), 44 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 826b7a8834cf..096ad11dd065 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkThrowable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.SqlScriptingContextManager
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CommandResult,
CompoundBody, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult,
CompoundBody, ExceptionHandlerType, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.types.StructType
@@ -78,7 +78,7 @@ class SqlScriptingExecution(
*/
private def injectLeaveStatement(executionPlan: NonLeafStatementExec, label:
String): Unit = {
// Go as deep as possible, to find a leaf node. Instead of a statement that
- // should be executed next, inject LEAVE statement in its place.
+ // should be executed next, inject LEAVE statement in its place.
var currExecPlan = executionPlan
while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
@@ -86,6 +86,27 @@ class SqlScriptingExecution(
currExecPlan.curr = Some(new LeaveStatementExec(label))
}
+ /**
+ * Helper method to execute interrupts to ConditionalStatements.
+ * This method should only interrupt when the statement that throws is a
conditional statement.
+ * @param executionPlan Execution plan.
+ */
+ private def interruptConditionalStatements(executionPlan:
NonLeafStatementExec): Unit = {
+ // Go as deep as possible into the execution plan children nodes, to find
a leaf node.
+ // That leaf node is the next statement that is to be executed. If the
parent node of that
+ // leaf node is a conditional statement, skip the conditional statement
entirely.
+ var currExecPlan = executionPlan
+ while (currExecPlan.curr.exists(_.isInstanceOf[NonLeafStatementExec])) {
+ currExecPlan = currExecPlan.curr.get.asInstanceOf[NonLeafStatementExec]
+ }
+
+ currExecPlan match {
+ case exec: ConditionalStatementExec =>
+ exec.interrupted = true
+ case _ =>
+ }
+ }
+
/** Helper method to iterate get next statements from the first available
frame. */
private def getNextStatement: Option[CompoundStatementExec] = {
// Remove frames that are already executed.
@@ -103,15 +124,29 @@ class SqlScriptingExecution(
// If the last frame is a handler, set leave statement to be the next
one in the
// innermost scope that should be exited.
- if (lastFrame.frameType == SqlScriptingFrameType.HANDLER &&
context.frames.nonEmpty) {
+ if (lastFrame.frameType == SqlScriptingFrameType.EXIT_HANDLER
+ && context.frames.nonEmpty) {
// Remove the scope if handler is executed.
if (context.firstHandlerScopeLabel.isDefined
&& lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
context.firstHandlerScopeLabel = None
}
+
// Inject leave statement into the execution plan of the last frame.
injectLeaveStatement(context.frames.last.executionPlan,
lastFrame.scopeLabel.get)
}
+
+ if (lastFrame.frameType == SqlScriptingFrameType.CONTINUE_HANDLER
+ && context.frames.nonEmpty) {
+ // Remove the scope if handler is executed.
+ if (context.firstHandlerScopeLabel.isDefined
+ && lastFrame.scopeLabel.get == context.firstHandlerScopeLabel.get) {
+ context.firstHandlerScopeLabel = None
+ }
+
+ // Interrupt conditional statements
+ interruptConditionalStatements(context.frames.last.executionPlan)
+ }
}
// If there are still frames available, get the next statement.
if (context.frames.nonEmpty) {
@@ -169,7 +204,11 @@ class SqlScriptingExecution(
case Some(handler) =>
val handlerFrame = new SqlScriptingExecutionFrame(
handler.body,
- SqlScriptingFrameType.HANDLER,
+ if (handler.handlerType == ExceptionHandlerType.CONTINUE) {
+ SqlScriptingFrameType.CONTINUE_HANDLER
+ } else {
+ SqlScriptingFrameType.EXIT_HANDLER
+ },
handler.scopeLabel
)
context.frames.append(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index e1c139addd34..08ba54e6e4e4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -59,7 +59,8 @@ class SqlScriptingExecutionContext extends
SqlScriptingExecutionContextExtension
}
// If the last frame is a handler, try to find a handler in its body first.
- if (frames.last.frameType == SqlScriptingFrameType.HANDLER) {
+ if (frames.last.frameType == SqlScriptingFrameType.EXIT_HANDLER
+ || frames.last.frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
val handler = frames.last.findHandler(condition, sqlState,
firstHandlerScopeLabel)
if (handler.isDefined) {
return handler
@@ -83,7 +84,7 @@ class SqlScriptingExecutionContext extends
SqlScriptingExecutionContextExtension
object SqlScriptingFrameType extends Enumeration {
type SqlScriptingFrameType = Value
- val SQL_SCRIPT, HANDLER = Value
+ val SQL_SCRIPT, EXIT_HANDLER, CONTINUE_HANDLER = Value
}
/**
@@ -141,7 +142,9 @@ class SqlScriptingExecutionFrame(
sqlState: String,
firstHandlerScopeLabel: Option[String]): Option[ExceptionHandlerExec] = {
- val searchScopes = if (frameType == SqlScriptingFrameType.HANDLER) {
+ val searchScopes =
+ if (frameType == SqlScriptingFrameType.EXIT_HANDLER
+ || frameType == SqlScriptingFrameType.CONTINUE_HANDLER) {
// If the frame is a handler, search for the handler in its body. Don't
skip any scopes.
scopes.reverseIterator
} else if (firstHandlerScopeLabel.isEmpty) {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index efc44f84cd2c..598e379c73ac 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -106,6 +106,13 @@ trait NonLeafStatementExec extends CompoundStatementExec {
}
}
+/**
+ * Conditional node in the execution tree. It is a conditional non-leaf node.
+ */
+trait ConditionalStatementExec extends NonLeafStatementExec {
+ protected[scripting] var interrupted: Boolean = false
+}
+
/**
* Executable node for SingleStatement.
* @param parsedPlan
@@ -401,7 +408,7 @@ class IfElseStatementExec(
conditions: Seq[SingleStatementExec],
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
- session: SparkSession) extends NonLeafStatementExec {
+ session: SparkSession) extends ConditionalStatementExec {
private object IfElseState extends Enumeration {
val Condition, Body = Value
}
@@ -415,7 +422,7 @@ class IfElseStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
- override def hasNext: Boolean = curr.nonEmpty
+ override def hasNext: Boolean = !interrupted && curr.nonEmpty
override def next(): CompoundStatementExec = {
if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
@@ -467,6 +474,7 @@ class IfElseStatementExec(
state = IfElseState.Condition
curr = Some(conditions.head)
clauseIdx = 0
+ interrupted = false
conditions.foreach(c => c.reset())
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
@@ -484,7 +492,7 @@ class WhileStatementExec(
condition: SingleStatementExec,
body: CompoundBodyExec,
label: Option[String],
- session: SparkSession) extends NonLeafStatementExec {
+ session: SparkSession) extends ConditionalStatementExec {
private object WhileState extends Enumeration {
val Condition, Body = Value
@@ -495,7 +503,7 @@ class WhileStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
- override def hasNext: Boolean = curr.nonEmpty
+ override def hasNext: Boolean = !interrupted && curr.nonEmpty
override def next(): CompoundStatementExec = state match {
case WhileState.Condition =>
@@ -551,6 +559,7 @@ class WhileStatementExec(
override def reset(): Unit = {
state = WhileState.Condition
curr = Some(condition)
+ interrupted = false
condition.reset()
body.reset()
}
@@ -575,7 +584,7 @@ class SearchedCaseStatementExec(
conditions: Seq[SingleStatementExec],
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
- session: SparkSession) extends NonLeafStatementExec {
+ session: SparkSession) extends ConditionalStatementExec {
private object CaseState extends Enumeration {
val Condition, Body = Value
}
@@ -588,7 +597,7 @@ class SearchedCaseStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
- override def hasNext: Boolean = curr.nonEmpty
+ override def hasNext: Boolean = !interrupted && curr.nonEmpty
override def next(): CompoundStatementExec = {
if (curr.exists(_.isInstanceOf[LeaveStatementExec])) {
@@ -640,6 +649,7 @@ class SearchedCaseStatementExec(
state = CaseState.Condition
curr = Some(conditions.head)
clauseIdx = 0
+ interrupted = false
conditions.foreach(c => c.reset())
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
@@ -662,7 +672,7 @@ class SimpleCaseStatementExec(
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
session: SparkSession,
- context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+ context: SqlScriptingExecutionContext) extends ConditionalStatementExec {
private object CaseState extends Enumeration {
val Condition, Body = Value
}
@@ -699,7 +709,7 @@ class SimpleCaseStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
- override def hasNext: Boolean = state match {
+ override def hasNext: Boolean = !interrupted && (state match {
case CaseState.Condition =>
// Equivalent to the "iteration hasn't started yet" - to avoid
computing cache
// before the first actual iteration.
@@ -710,7 +720,7 @@ class SimpleCaseStatementExec(
cachedConditionBodyIterator.hasNext ||
elseBody.isDefined
case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
- }
+ })
override def next(): CompoundStatementExec = state match {
case CaseState.Condition =>
@@ -779,6 +789,7 @@ class SimpleCaseStatementExec(
bodyExec = None
curr = None
isCacheValid = false
+ interrupted = false
caseVariableExec.reset()
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
@@ -797,7 +808,7 @@ class RepeatStatementExec(
condition: SingleStatementExec,
body: CompoundBodyExec,
label: Option[String],
- session: SparkSession) extends NonLeafStatementExec {
+ session: SparkSession) extends ConditionalStatementExec {
private object RepeatState extends Enumeration {
val Condition, Body = Value
@@ -808,7 +819,7 @@ class RepeatStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
- override def hasNext: Boolean = curr.nonEmpty
+ override def hasNext: Boolean = !interrupted && curr.nonEmpty
override def next(): CompoundStatementExec = state match {
case RepeatState.Condition =>
@@ -863,6 +874,7 @@ class RepeatStatementExec(
override def reset(): Unit = {
state = RepeatState.Body
curr = Some(body)
+ interrupted = false
body.reset()
condition.reset()
}
@@ -989,7 +1001,7 @@ class ForStatementExec(
statements: Seq[CompoundStatementExec],
val label: Option[String],
session: SparkSession,
- context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+ context: SqlScriptingExecutionContext) extends ConditionalStatementExec {
private object ForState extends Enumeration {
val VariableAssignment, Body = Value
@@ -1015,11 +1027,6 @@ class ForStatementExec(
private var bodyWithVariables: Option[CompoundBodyExec] = None
- /**
- * For can be interrupted by LeaveStatementExec
- */
- private var interrupted: Boolean = false
-
/**
* Whether this iteration of the FOR loop is the first one.
*/
@@ -1028,6 +1035,7 @@ class ForStatementExec(
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
+ // Variable interrupted is being used by both EXIT and CONTINUE handlers
override def hasNext: Boolean = !interrupted && (state match {
// `firstIteration` NEEDS to be the first condition! This is to handle
edge-cases when
// query fails with an exception. If the
`cachedQueryResult().hasNext` is first, this
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index e0e11183d321..eebdb681f62c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody,
CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement,
IterateStatement, LeaveStatement, LoopStatement, OneRowRelation, Project,
RepeatStatement, SearchedCaseStatement, SimpleCaseStatement, SingleStatement,
WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody,
CompoundPlanStatement, ForStatement, IfElseStatement, IterateStatement,
LeaveStatement, LoopStatement, OneRowRelation, Project, RepeatStatement,
SearchedCaseStatement, SimpleCaseStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -87,17 +87,11 @@ case class SqlScriptingInterpreter(session: SparkSession) {
args,
context)
- // Execution node of handler.
- val handlerScopeLabel = if (handler.handlerType ==
ExceptionHandlerType.EXIT) {
- Some(compoundBody.label.get)
- } else {
- None
- }
-
+ // Scope label should be Some(compoundBody.label.get) for both handler
types
val handlerExec = new ExceptionHandlerExec(
handlerBodyExec,
handler.handlerType,
- handlerScopeLabel)
+ Some(compoundBody.label.get))
// For each condition handler is defined for, add corresponding key
value pair
// to the conditionHandlerMap.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
index e24407912eb0..eb1342a780b0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
@@ -35,6 +35,17 @@ import org.apache.spark.sql.types.{IntegerType, StructField,
StructType}
* For full functionality tests, see SqlScriptingParserSuite and
SqlScriptingInterpreterSuite.
*/
class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
+ }
+
+ protected override def afterAll(): Unit = {
+ conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
+ super.afterAll()
+ }
+
// Helpers
private def verifySqlScriptResult(
sqlText: String,
@@ -77,7 +88,7 @@ class SqlScriptingE2eSuite extends QueryTest with
SharedSparkSession {
}
}
- test("Scripting with exception handlers") {
+ test("Scripting with exit exception handlers") {
val sqlScript =
"""
|BEGIN
@@ -104,6 +115,36 @@ class SqlScriptingE2eSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, Seq(Row(2)))
}
+ test("Scripting with continue exception handlers") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE flag1 INT = -1;
+ | DECLARE flag2 INT = -1;
+ | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
+ | BEGIN
+ | SELECT flag1;
+ | SET flag1 = 1;
+ | END;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag1;
+ | SET flag1 = 2;
+ | END;
+ | SELECT 5;
+ | SET flag2 = 1;
+ | SELECT 1/0;
+ | SELECT 6;
+ | SET flag2 = 2;
+ | END;
+ | SELECT 7;
+ | SELECT flag1, flag2;
+ |END
+ |""".stripMargin
+ verifySqlScriptResult(sqlScript, Seq(Row(2, 2)))
+ }
+
test("single select") {
val sqlText = "SELECT 1;"
verifySqlScriptResult(sqlText, Seq(Row(1)))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 7e6de2b990ff..04b59e2f0817 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -36,6 +36,16 @@ import org.apache.spark.sql.test.SharedSparkSession
*/
class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
+ }
+
+ protected override def afterAll(): Unit = {
+ conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
+ super.afterAll()
+ }
+
// Tests setup
override protected def sparkConf: SparkConf = {
super.sparkConf
@@ -74,7 +84,7 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
}
// Handler tests
- test("duplicate handler for the same condition") {
+ test("duplicate EXIT/EXIT handler for the same condition") {
val sqlScript =
"""
|BEGIN
@@ -100,7 +110,110 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
parameters = Map("condition" -> "DUPLICATE_CONDITION"))
}
- test("duplicate handler for the same sqlState") {
+ test("duplicate EXIT/CONTINUE handler for the same condition") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE duplicate_condition CONDITION FOR SQLSTATE '12345';
+ | DECLARE flag INT = -1;
+ | DECLARE EXIT HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 1;
+ | END;
+ | DECLARE CONTINUE HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.CONDITION",
+ parameters = Map("condition" -> "DUPLICATE_CONDITION"))
+ }
+
+ test("duplicate CONTINUE/EXIT handler for the same condition") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE duplicate_condition CONDITION FOR SQLSTATE '12345';
+ | DECLARE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 1;
+ | END;
+ | DECLARE EXIT HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.CONDITION",
+ parameters = Map("condition" -> "DUPLICATE_CONDITION"))
+ }
+
+ test("duplicate CONTINUE/CONTINUE handler for the same condition") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE duplicate_condition CONDITION FOR SQLSTATE '12345';
+ | DECLARE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 1;
+ | END;
+ | DECLARE CONTINUE HANDLER FOR duplicate_condition
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.CONDITION",
+ parameters = Map("condition" -> "DUPLICATE_CONDITION"))
+ }
+
+ test("duplicate EXIT/EXIT handler for the same sqlState") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE flag INT = -1;
+ | DECLARE EXIT HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 1;
+ | END;
+ | DECLARE EXIT HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.SQLSTATE",
+ parameters = Map("sqlState" -> "12345"))
+ }
+
+ test("duplicate EXIT/CONTINUE handler for the same sqlState") {
val sqlScript =
"""
|BEGIN
@@ -109,6 +222,31 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
| BEGIN
| SET flag = 1;
| END;
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.SQLSTATE",
+ parameters = Map("sqlState" -> "12345"))
+ }
+
+ test("duplicate CONTINUE/EXIT handler for the same sqlState") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 1;
+ | END;
| DECLARE EXIT HANDLER FOR SQLSTATE '12345'
| BEGIN
| SET flag = 2;
@@ -125,6 +263,31 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
parameters = Map("sqlState" -> "12345"))
}
+ test("duplicate CONTINUE/CONTINUE handler for the same sqlState") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 1;
+ | END;
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '12345'
+ | BEGIN
+ | SET flag = 2;
+ | END;
+ | SELECT 1/0;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ },
+ condition = "DUPLICATE_EXCEPTION_HANDLER.SQLSTATE",
+ parameters = Map("sqlState" -> "12345"))
+ }
+
test("Specific condition takes precedence over sqlState") {
val sqlScript =
"""
@@ -181,7 +344,7 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve in the same block") {
+ test("exit handler - exit resolve in the same block") {
val sqlScript =
"""
|BEGIN
@@ -210,6 +373,35 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
+ test("continue handler - continue after the statement that caused the
exception") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | SELECT 2;
+ | SELECT 3;
+ | SELECT 1/0;
+ | SELECT 4;
+ | SELECT 5;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(2)), // select
+ Seq(Row(3)), // select
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
test("handler - exit resolve in the same block when if condition fails") {
val sqlScript =
"""
@@ -241,6 +433,37 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
+ test("continue handler - continue after the if statement when if condition
fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | DECLARE CONTINUE HANDLER FOR DIVIDE_BY_ZERO
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | SELECT 2;
+ | SELECT 3;
+ | IF (1 > 1/0) THEN
+ | SELECT 4;
+ | END IF;
+ | SELECT 5;
+ | SELECT 6;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(2)), // select
+ Seq(Row(3)), // select
+ Seq(Row(-1)), // select flag
+ Seq(Row(5)), // select
+ Seq(Row(6)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
test("handler - exit resolve in outer block") {
val sqlScript =
"""
@@ -693,7 +916,7 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
}
}
- test("handler - exit resolve when if condition fails") {
+ test("exit handler - exit resolve when if condition fails") {
val sqlScript =
"""
|BEGIN
@@ -720,7 +943,36 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when simple case variable computation fails") {
+ test("continue handler - continue when if condition fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | IF 1 > 1/0 THEN
+ | SELECT 10;
+ | END IF;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when simple case variable computation
fails") {
val sqlScript =
"""
|BEGIN
@@ -747,7 +999,36 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when simple case condition computation fails") {
+ test("continue handler - continue when simple case variable computation
fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | CASE 1/0
+ | WHEN flag THEN SELECT 10;
+ | END CASE;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when simple case condition computation
fails") {
val sqlScript =
"""
|BEGIN
@@ -774,7 +1055,36 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when simple case condition types are mismatch")
{
+ test("continue handler - continue when simple case condition computation
fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | CASE flag
+ | WHEN 1/0 THEN SELECT 10;
+ | END CASE;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when simple case condition types are
mismatch") {
val sqlScript =
"""
|BEGIN
@@ -801,7 +1111,36 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when searched case condition fails") {
+ test("continue handler - continue when simple case condition types are
mismatch") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR CAST_INVALID_INPUT
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | CASE flag
+ | WHEN 'teststr' THEN SELECT 10;
+ | END CASE;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when searched case condition fails") {
val sqlScript =
"""
|BEGIN
@@ -828,7 +1167,36 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when while condition fails") {
+ test("continue handler - continue when searched case condition fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | CASE
+ | WHEN flag = 1/0 THEN SELECT 10;
+ | END CASE;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when while condition fails") {
val sqlScript =
"""
|BEGIN
@@ -855,7 +1223,70 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
- test("handler - exit resolve when select fails in FOR statement") {
+ test("continue handler - continue when while condition fails") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | WHILE 1 > 1/0 DO
+ | SELECT 10;
+ | END WHILE;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("continue handler - continue when select fails in REPEAT statement") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | SELECT 2;
+ | REPEAT
+ | SELECT 3;
+ | UNTIL
+ | 1 = 1/0
+ | END REPEAT;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(2)), // select
+ Seq(Row(3)), // select
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("exit handler - exit resolve when select fails in FOR statement") {
val sqlScript =
"""
|BEGIN
@@ -882,6 +1313,64 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(sqlScript, expected = expected)
}
+ test("continue handler - continue when select fails in FOR statement") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | FOR iter AS (SELECT 1/0) DO
+ | SELECT 10;
+ | END FOR;
+ | SELECT 4;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
+ test("continue handler - continue when select fails inside FOR statement
body") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE VARIABLE flag INT = -1;
+ | BEGIN
+ | DECLARE CONTINUE HANDLER FOR SQLSTATE '22012'
+ | BEGIN
+ | SELECT flag;
+ | SET flag = 1;
+ | END;
+ | FOR iter AS SELECT 1 DO
+ | SELECT 1/0;
+ | SELECT 4;
+ | END FOR;
+ | SELECT 5;
+ | END;
+ | SELECT flag;
+ |END
+ |""".stripMargin
+ val expected = Seq(
+ Seq(Row(-1)), // select flag
+ Seq(Row(4)), // select
+ Seq(Row(5)), // select
+ Seq(Row(1)) // select flag from the outer body
+ )
+ verifySqlScriptResult(sqlScript, expected = expected)
+ }
+
// Tests
test("multi statement - simple") {
withTable("t") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]