miland-db commented on code in PR #47423:
URL: https://github.com/apache/spark/pull/47423#discussion_r1714913284


##########
sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4:
##########
@@ -82,6 +84,23 @@ singleStatement
     : (statement|setResetStatement) SEMICOLON* EOF
     ;
 
+conditionValue
+    : stringLit
+    | multipartIdentifier
+    ;
+
+conditionValueList
+    : ((conditionValues+=conditionValue (COMMA 
conditionValues+=conditionValue)*) | SQLEXCEPTION | NOT FOUND)
+    ;
+
+declareCondition
+    : DECLARE multipartIdentifier CONDITION (FOR stringLit)?
+    ;
+
+declareHandler
+    : DECLARE (CONTINUE | EXIT) HANDLER FOR conditionValueList (BEGIN 
compoundBody END | statement)

Review Comment:
   Done.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala:
##########
@@ -229,6 +273,29 @@ class AstBuilder extends DataTypeAstBuilder
     )
   }
 
+  override def visitDeclareCondition(ctx: DeclareConditionContext): 
ErrorCondition = {
+    val conditionName = ctx.multipartIdentifier().getText
+    val conditionValue = 
Option(ctx.stringLit()).map(_.getText).getOrElse("'45000'").
+      replace("'", "")
+
+    val sqlStateRegex = "^[A-Za-z0-9]{5}$".r
+    assert(sqlStateRegex.findFirstIn(conditionValue).isDefined)
+
+    ErrorCondition(conditionName, conditionValue)
+  }
+
+  override def visitDeclareHandler(ctx: DeclareHandlerContext): ErrorHandler = 
{
+    val conditions = visit(ctx.conditionValueList()).asInstanceOf[Seq[String]]
+    val handlerType = Option(ctx.EXIT()).map(_ => 
HandlerType.EXIT).getOrElse(HandlerType.CONTINUE)
+
+    val body = Option(ctx.compoundBody()).map(visit).getOrElse {
+      val logicalPlan = visit(ctx.statement()).asInstanceOf[LogicalPlan]
+      CompoundBody(Seq(SingleStatement(parsedPlan = logicalPlan)))
+    }.asInstanceOf[CompoundBody]

Review Comment:
   Is it ok now?



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala:
##########
@@ -58,44 +61,108 @@ case class SqlScriptingInterpreter() {
       case _ => None
     }
 
+  private def transformBodyIntoExec(
+      compoundBody: CompoundBody,
+      isExitHandler: Boolean = false,
+      label: String = ""): CompoundBodyExec = {
+    val variables = compoundBody.collection.flatMap {
+      case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
+      case _ => None
+    }
+    val dropVariables = variables
+      .map(varName => DropVariable(varName, ifExists = true))
+      .map(new SingleStatementExec(_, Origin(), isInternal = true))
+      .reverse
+
+    val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]()
+    val handlers = ListBuffer[ErrorHandlerExec]()
+    compoundBody.handlers.foreach(handler => {
+      val handlerBodyExec =
+        transformBodyIntoExec(handler.body,
+          handler.handlerType == HandlerType.EXIT,
+          compoundBody.label.get)
+      val handlerExec = new ErrorHandlerExec(handlerBodyExec)
+
+      handler.conditions.foreach(condition => {
+        val conditionValue = compoundBody.conditions.getOrElse(condition, 
condition)
+        conditionHandlerMap.get(conditionValue) match {
+          case Some(_) =>
+            throw SqlScriptingErrors.duplicateHandlerForSameSqlState(
+              CurrentOrigin.get, conditionValue)
+          case None => conditionHandlerMap.put(conditionValue, handlerExec)
+        }

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None
+    result = None // Should we do this?
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+      label: Option[String] = None,
+      statements: Seq[CompoundStatementExec],
+      conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = 
mutable.HashMap(),
+      session: SparkSession)

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None
+    result = None // Should we do this?
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+      label: Option[String] = None,
+      statements: Seq[CompoundStatementExec],
+      conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = 
mutable.HashMap(),
+      session: SparkSession)
   extends NonLeafStatementExec {
 
-  private var localIterator = collection.iterator
-  private var curr = if (localIterator.hasNext) Some(localIterator.next()) 
else None
+  private def getHandler(condition: String): Option[ErrorHandlerExec] = {
+    conditionHandlerMap.get(condition)
+      .orElse(conditionHandlerMap.get("NOT FOUND") match {
+        case Some(handler) if condition.startsWith("02") => Some(handler)

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala:
##########
@@ -58,44 +61,108 @@ case class SqlScriptingInterpreter() {
       case _ => None
     }
 
+  private def transformBodyIntoExec(
+      compoundBody: CompoundBody,
+      isExitHandler: Boolean = false,
+      label: String = ""): CompoundBodyExec = {
+    val variables = compoundBody.collection.flatMap {
+      case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
+      case _ => None
+    }
+    val dropVariables = variables
+      .map(varName => DropVariable(varName, ifExists = true))
+      .map(new SingleStatementExec(_, Origin(), isInternal = true))
+      .reverse
+
+    val conditionHandlerMap = mutable.HashMap[String, ErrorHandlerExec]()
+    val handlers = ListBuffer[ErrorHandlerExec]()
+    compoundBody.handlers.foreach(handler => {
+      val handlerBodyExec =
+        transformBodyIntoExec(handler.body,
+          handler.handlerType == HandlerType.EXIT,
+          compoundBody.label.get)
+      val handlerExec = new ErrorHandlerExec(handlerBodyExec)
+
+      handler.conditions.foreach(condition => {

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +153,116 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    super.reset()
+    result = None
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+    statements: Seq[CompoundStatementExec],
+    session: SparkSession,
+    label: Option[String] = None,

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None
+    result = None // Should we do this?
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+      label: Option[String] = None,
+      statements: Seq[CompoundStatementExec],
+      conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = 
mutable.HashMap(),
+      session: SparkSession)
   extends NonLeafStatementExec {
 
-  private var localIterator = collection.iterator
-  private var curr = if (localIterator.hasNext) Some(localIterator.next()) 
else None
+  private def getHandler(condition: String): Option[ErrorHandlerExec] = {
+    conditionHandlerMap.get(condition)
+      .orElse(conditionHandlerMap.get("NOT FOUND") match {

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +153,116 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    super.reset()
+    result = None
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+    statements: Seq[CompoundStatementExec],
+    session: SparkSession,
+    label: Option[String] = None,

Review Comment:
   Let's discuss this offline first.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -45,7 +47,28 @@ sealed trait CompoundStatementExec extends Logging {
 /**
  * Leaf node in the execution tree.
  */
-trait LeafStatementExec extends CompoundStatementExec
+trait LeafStatementExec extends CompoundStatementExec {
+
+  /** Whether an error was raised during the execution of this statement. */
+  var raisedError: Boolean = false
+
+  /**
+   * Error state of the statement.
+   */
+  var errorState: Option[String] = None
+
+  /** Error raised during statement execution. */
+  var error: Option[SparkThrowable] = None

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None
+    result = None // Should we do this?
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+      label: Option[String] = None,
+      statements: Seq[CompoundStatementExec],
+      conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = 
mutable.HashMap(),
+      session: SparkSession)
   extends NonLeafStatementExec {
 
-  private var localIterator = collection.iterator
-  private var curr = if (localIterator.hasNext) Some(localIterator.next()) 
else None
+  private def getHandler(condition: String): Option[ErrorHandlerExec] = {

Review Comment:
   Done.



##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -127,20 +148,112 @@ class SingleStatementExec(
     origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 
1)
   }
 
-  override def reset(): Unit = isExecuted = false
+  override def reset(): Unit = {
+    raisedError = false
+    errorState = None
+    error = None
+    rethrow = None
+    result = None // Should we do this?
+  }
+
+  override def execute(session: SparkSession): Unit = {
+    try {
+      val rows = Some(Dataset.ofRows(session, parsedPlan).collect())
+      if (shouldCollectResult) {
+        result = rows
+      }
+    } catch {
+      case e: SparkThrowable =>
+        raisedError = true
+        errorState = Some(e.getSqlState)
+        error = Some(e)
+        e match {
+          case throwable: Throwable =>
+            rethrow = Some(throwable)
+          case _ =>
+        }
+      case throwable: Throwable =>
+        raisedError = true
+        errorState = Some("SQLEXCEPTION")
+        rethrow = Some(throwable)
+    }
+  }
 }
 
 /**
- * Abstract class for all statements that contain nested statements.
- * Implements recursive iterator logic over all child execution nodes.
- * @param collection
- *   Collection of child execution nodes.
+ * Executable node for CompoundBody.
+ * @param statements
+ *   Executable nodes for nested statements within the CompoundBody.
+ * @param session
+ *   Spark session.
  */
-abstract class CompoundNestedStatementIteratorExec(collection: 
Seq[CompoundStatementExec])
+class CompoundBodyExec(
+      label: Option[String] = None,
+      statements: Seq[CompoundStatementExec],
+      conditionHandlerMap: mutable.HashMap[String, ErrorHandlerExec] = 
mutable.HashMap(),
+      session: SparkSession)
   extends NonLeafStatementExec {
 
-  private var localIterator = collection.iterator
-  private var curr = if (localIterator.hasNext) Some(localIterator.next()) 
else None
+  private def getHandler(condition: String): Option[ErrorHandlerExec] = {
+    conditionHandlerMap.get(condition)
+      .orElse(conditionHandlerMap.get("NOT FOUND") match {
+        case Some(handler) if condition.startsWith("02") => Some(handler)
+        case _ => None
+      })
+      .orElse(conditionHandlerMap.get("SQLEXCEPTION"))
+  }
+
+  /**
+   * Handle error raised during the execution of the statement.
+   * @param statement statement that possibly raised the error
+   * @return pass through the statement
+   */
+  private def handleError(statement: LeafStatementExec): LeafStatementExec = {
+    if (statement.raisedError) {
+      getHandler(statement.errorState.get).foreach { handler =>
+        statement.reset() // Clear all flags and result
+        handler.reset()
+        returnHere = curr
+        curr = Some(handler.getHandlerBody)
+      }
+    }
+    statement
+  }
+
+  /**
+   * Check if the leave statement was used; if it was not, stop iterating the surrounding
+   * [[CompoundBodyExec]] and move the iterator forward. If the label of the block matches
+   * the label of the leave statement, mark the leave statement as used.
+   * @param leave the leave statement
+   * @return pass through the leave statement
+   */
+  private def handleLeave(leave: LeaveStatementExec): LeaveStatementExec = {
+    if (!leave.used) {
+      // Hard stop the iteration of the current begin/end block
+      stopIteration = true
+      // If label of the block matches the label of the leave statement,
+      // mark the leave statement as used
+      if (label.getOrElse("").equals(leave.getLabel)) {
+        leave.used = true
+      }
+    }
+    curr = if (localIterator.hasNext) Some(localIterator.next()) else None
+    leave
+  }
+
+  private var localIterator: Iterator[CompoundStatementExec] = 
statements.iterator
+  private var curr: Option[CompoundStatementExec] =
+    if (localIterator.hasNext) Some(localIterator.next()) else None
+  private var stopIteration: Boolean = false  // hard stop iteration flag

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to