This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b5edc11f7993 [SPARK-52129] Improve ParserUtils to support more Scripting constructs
b5edc11f7993 is described below

commit b5edc11f79930ae79f8d17b93c6e40b840820446
Author: David Milicevic <[email protected]>
AuthorDate: Wed May 14 21:44:01 2025 +0200

    [SPARK-52129] Improve ParserUtils to support more Scripting constructs
    
    ### What changes were proposed in this pull request?
    
    #### Label casing fix
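    In `enterLabeledScope()`, the original (case-preserving) label text was added to the set of seen labels instead of the lowercased one. As a result, label-uniqueness checks could miss duplicates that differ only in case (e.g. a block labeled `lbl` nested inside a block labeled `LbL`).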
    
    #### CompoundBodyParsingContext
    Renamed the previous `SqlScriptingParsingContext` to the more specific `CompoundBodyParsingContext`, freeing up the name for a more generic context.
    
    #### Introduced new SqlScriptingParsingContext
    It combines the existing `SqlScriptingLabelContext` with a newly introduced `SqlScriptingConditionContext`, which tracks the condition names declared in the current and enclosing scopes. This is required to detect duplicate condition names and to map specific conditions to their assigned SQL states.
    
    #### Other changes
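    Propagated the new `SqlScriptingParsingContext` through `AstBuilder`, replacing the separate label-context and `allowVarDeclare` parameters.
    Renamed the `INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING` error condition to `NOT_AT_START_OF_COMPOUND_STATEMENT`, along with its error-construction function.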
    
    ### Why are the changes needed?
    
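    These changes add new parse-time correctness checks for SQL scripts, e.g. rejecting condition declarations outside the start of a BEGIN ... END compound statement and rejecting duplicate condition names across nested scopes.
    The checks and changes are described in the previous section.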
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50893 from davidm-db/parsing_contexts.
    
    Lead-authored-by: David Milicevic <[email protected]>
    Co-authored-by: David Milicevic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   4 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     | 128 ++++++++---------
 .../spark/sql/catalyst/parser/ParserUtils.scala    |  35 ++++-
 .../spark/sql/errors/SqlScriptingErrors.scala      |   6 +-
 .../catalyst/parser/SqlScriptingParserSuite.scala  | 154 ++++++++++++++++++++-
 5 files changed, 250 insertions(+), 77 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index cf733b2a20ad..dc856221bead 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2592,9 +2592,9 @@
       "Invalid condition declaration."
     ],
     "subClass" : {
-      "ONLY_AT_BEGINNING" : {
+      "NOT_AT_START_OF_COMPOUND_STATEMENT" : {
         "message" : [
-          "Condition <conditionName> can only be declared at the beginning of 
the compound."
+          "Condition <conditionName> can only be declared at the start of a 
BEGIN END compound statement."
         ]
       },
       "QUALIFIED_CONDITION_NAME" : {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index ef8d04e2e890..8018b36f7282 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -157,19 +157,19 @@ class AstBuilder extends DataTypeAstBuilder
   }
 
   override def visitSingleCompoundStatement(ctx: 
SingleCompoundStatementContext): CompoundBody = {
-    val labelCtx = new SqlScriptingLabelContext()
-    val labelText = labelCtx.enterLabeledScope(None, None)
+    val parsingCtx = new SqlScriptingParsingContext
+
+    val labelText = parsingCtx.labelContext.enterLabeledScope(None, None)
 
     val script = Option(ctx.compoundBody())
       .map(visitCompoundBodyImpl(
         _,
         Some(labelText),
-        allowVarDeclare = true,
-        labelCtx,
+        parsingCtx,
         isScope = true
       )).getOrElse(CompoundBody(Seq.empty, Some(labelText), isScope = true))
 
-    labelCtx.exitLabeledScope(None)
+    parsingCtx.labelContext.exitLabeledScope(None)
     script
   }
 
@@ -277,7 +277,7 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitDeclareHandlerStatementImpl(
       ctx: DeclareHandlerStatementContext,
-      labelCtx: SqlScriptingLabelContext): ExceptionHandler = {
+      parsingCtx: SqlScriptingParsingContext): ExceptionHandler = {
     val exceptionHandlerTriggers = 
visitConditionValuesImpl(ctx.conditionValues())
 
     if (Option(ctx.CONTINUE()).isDefined) {
@@ -288,7 +288,7 @@ class AstBuilder extends DataTypeAstBuilder
     val body = if (Option(ctx.beginEndCompoundBlock()).isDefined) {
       visitBeginEndCompoundBlockImpl(
         ctx.beginEndCompoundBlock(),
-        labelCtx)
+        parsingCtx)
     } else {
       // If there is no compound body, then there must be a statement or set 
statement.
       val statement = Option(ctx.statement().asInstanceOf[ParserRuleContext])
@@ -305,27 +305,26 @@ class AstBuilder extends DataTypeAstBuilder
   private def visitCompoundBodyImpl(
       ctx: CompoundBodyContext,
       label: Option[String],
-      allowVarDeclare: Boolean,
-      labelCtx: SqlScriptingLabelContext,
+      parsingCtx: SqlScriptingParsingContext,
       isScope: Boolean): CompoundBody = {
     val buff = ListBuffer[CompoundPlanStatement]()
 
     val handlers = ListBuffer[ExceptionHandler]()
-    val conditions = HashMap[String, String]()
+    val currentConditions = HashMap[String, String]()
 
-    val scriptingParserContext = new SqlScriptingParsingContext()
+    val compoundBodyParserContext = new CompoundBodyParsingContext()
 
     ctx.compoundStatements.forEach(compoundStatement => {
-      val stmt = visitCompoundStatementImpl(compoundStatement, labelCtx)
+      val stmt = visitCompoundStatementImpl(compoundStatement, parsingCtx)
       stmt match {
         case handler: ExceptionHandler =>
-          scriptingParserContext.handler()
+          compoundBodyParserContext.handler()
           // All conditions are already visited when we encounter a handler.
           handler.exceptionHandlerTriggers.conditions.foreach(conditionName => 
{
             // Everything is stored in upper case so we can make 
case-insensitive comparisons.
             // If condition is not spark-defined error condition, check if 
user defined it.
             if (!SparkThrowableHelper.isValidErrorClass(conditionName)) {
-              if (!conditions.contains(conditionName)) {
+              if (!parsingCtx.conditionContext.contains(conditionName)) {
                 throw SqlScriptingErrors
                   .conditionNotFound(CurrentOrigin.get, conditionName)
               }
@@ -333,49 +332,53 @@ class AstBuilder extends DataTypeAstBuilder
           })
 
           handlers += handler
+
         case condition: ErrorCondition =>
-          scriptingParserContext.condition(condition)
+          compoundBodyParserContext.condition(condition, isScope)
           // Check for duplicate condition names in each scope.
           // When conditions are visited, everything is converted to upper-case
           // for case-insensitive comparisons.
-          if (conditions.contains(condition.conditionName)) {
+          if (parsingCtx.conditionContext.contains(condition.conditionName)) {
             throw SqlScriptingErrors
               .duplicateConditionInScope(CurrentOrigin.get, 
condition.conditionName)
           }
-          conditions += condition.conditionName -> condition.sqlState
+          currentConditions += condition.conditionName -> condition.sqlState
+          parsingCtx.conditionContext.add(condition)
+
         case statement =>
           statement match {
             case SingleStatement(createVariable: CreateVariable) =>
-              scriptingParserContext.variable(createVariable, allowVarDeclare)
-            case _ => scriptingParserContext.statement()
+              compoundBodyParserContext.variable(createVariable, isScope)
+            case _ => compoundBodyParserContext.statement()
           }
           buff += statement
       }
     })
 
-    CompoundBody(buff.toSeq, label, isScope, handlers.toSeq, conditions)
+    parsingCtx.conditionContext.remove(currentConditions.keys)
+
+    CompoundBody(buff.toSeq, label, isScope, handlers.toSeq, currentConditions)
   }
 
   private def visitBeginEndCompoundBlockImpl(
       ctx: BeginEndCompoundBlockContext,
-      labelCtx: SqlScriptingLabelContext): CompoundBody = {
+      parsingCtx: SqlScriptingParsingContext): CompoundBody = {
     val labelText =
-      labelCtx.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+      parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
     val body = Option(ctx.compoundBody())
       .map(visitCompoundBodyImpl(
         _,
         Some(labelText),
-        allowVarDeclare = true,
-        labelCtx,
+        parsingCtx,
         isScope = true
       )).getOrElse(CompoundBody(Seq.empty, Some(labelText), isScope = true))
-    labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+    parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
     body
   }
 
   private def visitCompoundStatementImpl(
       ctx: CompoundStatementContext,
-      labelCtx: SqlScriptingLabelContext): CompoundPlanStatement =
+      parsingCtx: SqlScriptingParsingContext): CompoundPlanStatement =
     withOrigin(ctx) {
       Option(ctx.statement().asInstanceOf[ParserRuleContext])
         
.orElse(Option(ctx.setStatementInsideSqlScript().asInstanceOf[ParserRuleContext]))
@@ -385,23 +388,23 @@ class AstBuilder extends DataTypeAstBuilder
           if (ctx.getChildCount == 1) {
             ctx.getChild(0) match {
               case compoundBodyContext: BeginEndCompoundBlockContext =>
-                visitBeginEndCompoundBlockImpl(compoundBodyContext, labelCtx)
+                visitBeginEndCompoundBlockImpl(compoundBodyContext, parsingCtx)
               case whileStmtContext: WhileStatementContext =>
-                visitWhileStatementImpl(whileStmtContext, labelCtx)
+                visitWhileStatementImpl(whileStmtContext, parsingCtx)
               case repeatStmtContext: RepeatStatementContext =>
-                visitRepeatStatementImpl(repeatStmtContext, labelCtx)
+                visitRepeatStatementImpl(repeatStmtContext, parsingCtx)
               case loopStatementContext: LoopStatementContext =>
-                visitLoopStatementImpl(loopStatementContext, labelCtx)
+                visitLoopStatementImpl(loopStatementContext, parsingCtx)
               case ifElseStmtContext: IfElseStatementContext =>
-                visitIfElseStatementImpl(ifElseStmtContext, labelCtx)
+                visitIfElseStatementImpl(ifElseStmtContext, parsingCtx)
               case searchedCaseContext: SearchedCaseStatementContext =>
-                visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
+                visitSearchedCaseStatementImpl(searchedCaseContext, parsingCtx)
               case simpleCaseContext: SimpleCaseStatementContext =>
-                visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
+                visitSimpleCaseStatementImpl(simpleCaseContext, parsingCtx)
               case forStatementContext: ForStatementContext =>
-                visitForStatementImpl(forStatementContext, labelCtx)
+                visitForStatementImpl(forStatementContext, parsingCtx)
               case declareHandlerContext: DeclareHandlerStatementContext =>
-                visitDeclareHandlerStatementImpl(declareHandlerContext, 
labelCtx)
+                visitDeclareHandlerStatementImpl(declareHandlerContext, 
parsingCtx)
               case declareConditionContext: DeclareConditionStatementContext =>
                 visitDeclareConditionStatementImpl(declareConditionContext)
               case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
@@ -414,7 +417,7 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitIfElseStatementImpl(
       ctx: IfElseStatementContext,
-      labelCtx: SqlScriptingLabelContext): IfElseStatement = {
+      parsingCtx: SqlScriptingParsingContext): IfElseStatement = {
     IfElseStatement(
       conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => 
withOrigin(boolExpr) {
         SingleStatement(
@@ -424,20 +427,20 @@ class AstBuilder extends DataTypeAstBuilder
       }),
       conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       ),
       elseBody = Option(ctx.elseBody).map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       )
     )
   }
 
   private def visitWhileStatementImpl(
       ctx: WhileStatementContext,
-      labelCtx: SqlScriptingLabelContext): WhileStatement = {
+      parsingCtx: SqlScriptingParsingContext): WhileStatement = {
     val labelText =
-      labelCtx.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+      parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
     val boolExpr = ctx.booleanExpression()
 
     val condition = withOrigin(boolExpr) {
@@ -448,18 +451,17 @@ class AstBuilder extends DataTypeAstBuilder
     val body = visitCompoundBodyImpl(
       ctx.compoundBody(),
       None,
-      allowVarDeclare = false,
-      labelCtx,
+      parsingCtx,
       isScope = false
     )
-    labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+    parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
 
     WhileStatement(condition, body, Some(labelText))
   }
 
   private def visitSearchedCaseStatementImpl(
       ctx: SearchedCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
+      parsingCtx: SqlScriptingParsingContext): SearchedCaseStatement = {
     val conditions = ctx.conditions.asScala.toList.map(boolExpr => 
withOrigin(boolExpr) {
       SingleStatement(
         Project(
@@ -469,7 +471,7 @@ class AstBuilder extends DataTypeAstBuilder
     val conditionalBodies =
       ctx.conditionalBodies.asScala.toList.map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       )
 
     if (conditions.length != conditionalBodies.length) {
@@ -483,13 +485,13 @@ class AstBuilder extends DataTypeAstBuilder
       conditionalBodies = conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       ))
   }
 
   private def visitSimpleCaseStatementImpl(
       ctx: SimpleCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
+      parsingCtx: SqlScriptingParsingContext): SimpleCaseStatement = {
     val caseVariableExpr = withOrigin(ctx.caseVariable) {
       expression(ctx.caseVariable)
     }
@@ -501,7 +503,7 @@ class AstBuilder extends DataTypeAstBuilder
     val conditionalBodies =
       ctx.conditionalBodies.asScala.toList.map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       )
 
     if (conditionExpressions.length != conditionalBodies.length) {
@@ -516,15 +518,15 @@ class AstBuilder extends DataTypeAstBuilder
       conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
         body =>
-          visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, 
isScope = false)
+          visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
       ))
   }
 
   private def visitRepeatStatementImpl(
       ctx: RepeatStatementContext,
-      labelCtx: SqlScriptingLabelContext): RepeatStatement = {
+      parsingCtx: SqlScriptingParsingContext): RepeatStatement = {
     val labelText =
-      labelCtx.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+      parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
     val boolExpr = ctx.booleanExpression()
 
     val condition = withOrigin(boolExpr) {
@@ -535,19 +537,19 @@ class AstBuilder extends DataTypeAstBuilder
     val body = visitCompoundBodyImpl(
       ctx.compoundBody(),
       None,
-      allowVarDeclare = false,
-      labelCtx,
+      parsingCtx,
       isScope = false
     )
-    labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+    parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
 
     RepeatStatement(condition, body, Some(labelText))
   }
 
   private def visitForStatementImpl(
       ctx: ForStatementContext,
-      labelCtx: SqlScriptingLabelContext): ForStatement = {
-    val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+      parsingCtx: SqlScriptingParsingContext): ForStatement = {
+    val labelText = parsingCtx.labelContext.enterLabeledScope(
+      Option(ctx.beginLabel()), Option(ctx.endLabel()))
 
     val queryCtx = ctx.query()
     val query = withOrigin(queryCtx) {
@@ -557,11 +559,10 @@ class AstBuilder extends DataTypeAstBuilder
     val body = visitCompoundBodyImpl(
       ctx.compoundBody(),
       None,
-      allowVarDeclare = false,
-      labelCtx,
+      parsingCtx,
       isScope = false
     )
-    labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+    parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
 
     ForStatement(query, varName, body, Some(labelText))
   }
@@ -630,17 +631,16 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitLoopStatementImpl(
       ctx: LoopStatementContext,
-      labelCtx: SqlScriptingLabelContext): LoopStatement = {
+      parsingCtx: SqlScriptingParsingContext): LoopStatement = {
     val labelText =
-      labelCtx.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
+      parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()), 
Option(ctx.endLabel()))
     val body = visitCompoundBodyImpl(
       ctx.compoundBody(),
       None,
-      allowVarDeclare = false,
-      labelCtx,
+      parsingCtx,
       isScope = false
     )
-    labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+    parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
 
     LoopStatement(body, Some(labelText))
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index fe5bdcc00d30..38e92cf9aebd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 import java.util
 import java.util.Locale
 
-import scala.collection.immutable
-import scala.collection.mutable
+import scala.collection.{immutable, mutable}
 import scala.util.matching.Regex
 
 import org.antlr.v4.runtime.{ParserRuleContext, Token}
@@ -139,7 +138,7 @@ object ParserUtils extends SparkParserUtils {
   }
 }
 
-class SqlScriptingParsingContext {
+class CompoundBodyParsingContext {
 
   object State extends Enumeration {
     type State = Value
@@ -158,7 +157,12 @@ class SqlScriptingParsingContext {
   }
 
   /** Transition to CONDITION state. */
-  def condition(errorCondition: ErrorCondition): Unit = {
+  def condition(errorCondition: ErrorCondition, allowConditionDeclare: 
Boolean): Unit = {
+    if (!allowConditionDeclare) {
+      throw SqlScriptingErrors.conditionDeclarationNotAtStartOfCompound(
+        errorCondition.origin, errorCondition.conditionName
+      )
+    }
     transitionTo(State.CONDITION, None, errorCondition = Some(errorCondition))
   }
 
@@ -234,7 +238,7 @@ class SqlScriptingParsingContext {
 
       // Invalid transitions to CONDITION state.
       case (State.STATEMENT, State.CONDITION) =>
-        throw SqlScriptingErrors.conditionDeclarationOnlyAtBeginning(
+        throw SqlScriptingErrors.conditionDeclarationNotAtStartOfCompound(
           CurrentOrigin.get,
           errorCondition.get.conditionName)
 
@@ -255,6 +259,11 @@ class SqlScriptingParsingContext {
   }
 }
 
+class SqlScriptingParsingContext {
+  val labelContext: SqlScriptingLabelContext = new SqlScriptingLabelContext()
+  val conditionContext: SqlScriptingConditionContext = new 
SqlScriptingConditionContext()
+}
+
 class SqlScriptingLabelContext {
   /** Set to keep track of labels seen so far */
   private val seenLabels = mutable.Set[String]()
@@ -327,7 +336,7 @@ class SqlScriptingLabelContext {
           throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
         }
       }
-      seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
+      seenLabels.add(txt)
       txt
     } else {
       // Do not add the label to the seenLabels set if it is not defined.
@@ -360,3 +369,17 @@ object SqlScriptingLabelContext {
     forbiddenLabelNames.exists(_.matches(labelName.toLowerCase(Locale.ROOT)))
   }
 }
+
+class SqlScriptingConditionContext {
+  private val conditionNameToSqlStateMap = mutable.HashMap[String, String]()
+
+  def contains(conditionName: String): Boolean = 
conditionNameToSqlStateMap.contains(conditionName)
+
+  def getSqlStateForCondition(conditionName: String): Option[String] =
+    conditionNameToSqlStateMap.get(conditionName)
+
+  def add(condition: ErrorCondition): Unit =
+    conditionNameToSqlStateMap += condition.conditionName -> condition.sqlState
+
+  def remove(toRemove: Iterable[String]): Unit = conditionNameToSqlStateMap 
--= toRemove
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index 7e866d261485..ce0ed1f36a75 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -163,12 +163,12 @@ private[sql] object SqlScriptingErrors {
       messageParameters = Map("conditionName" -> toSQLStmt(conditionName)))
   }
 
-  def conditionDeclarationOnlyAtBeginning(
+  def conditionDeclarationNotAtStartOfCompound(
       origin: Origin,
       conditionName: String): Throwable = {
     new SqlScriptingException(
       origin = origin,
-      errorClass = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING",
+      errorClass = 
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
       cause = null,
       messageParameters = Map("conditionName" -> toSQLId(conditionName)))
   }
@@ -188,7 +188,7 @@ private[sql] object SqlScriptingErrors {
       origin = origin,
       errorClass = "DUPLICATE_CONDITION_IN_SCOPE",
       cause = null,
-      messageParameters = Map("condition" -> condition))
+      messageParameters = Map("condition" -> toSQLId(condition)))
   }
 
   def handlerDeclarationInWrongPlace(origin: Origin): Throwable = {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index ec4e558fc467..7f37dee64fb9 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -269,6 +269,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
     assert(tree.label.contains("lbl"))
   }
 
+  test("compound: beginLabel + endLabel - case sensitivity check") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  lbl: BEGIN
+        |    SELECT 1;
+        |    SELECT 2;
+        |    INSERT INTO A VALUES (a, b, 3);
+        |    SELECT a, b, c FROM T;
+        |    SELECT * FROM T;
+        |  END LbL;
+        |END""".stripMargin
+    parsePlan(sqlScriptText)
+  }
+
   test("compound: beginLabel + endLabel with different values") {
     val sqlScriptText =
       """
@@ -518,7 +533,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
     assert(exception.origin.line.contains(4))
   }
 
-  test("declare in wrong scope") {
+  test("declare variable in wrong scope") {
     val sqlScriptText =
       """
         |BEGIN
@@ -2001,6 +2016,44 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
       parameters = Map("label" -> toSQLId("lbl")))
   }
 
+  test("unique label names: nested begin-end blocks - case sensitivity check 
1") {
+    val sqlScriptText =
+      """BEGIN
+        |LbL: BEGIN
+        |  lbl: BEGIN
+        |    SELECT 1;
+        |  END;
+        |END;
+        |END
+      """.stripMargin
+    val exception = intercept[SqlScriptingException] {
+      parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    }
+    checkError(
+      exception = exception,
+      condition = "LABEL_ALREADY_EXISTS",
+      parameters = Map("label" -> toSQLId("lbl")))
+  }
+
+  test("unique label names: nested begin-end blocks - case sensitivity check 
2") {
+    val sqlScriptText =
+      """BEGIN
+        |lbl: BEGIN
+        |  LbL: BEGIN
+        |    SELECT 1;
+        |  END;
+        |END;
+        |END
+      """.stripMargin
+    val exception = intercept[SqlScriptingException] {
+      parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    }
+    checkError(
+      exception = exception,
+      condition = "LABEL_ALREADY_EXISTS",
+      parameters = Map("label" -> toSQLId("lbl")))
+  }
+
   test("unique label names: nested begin-end blocks with same prefix") {
     val sqlScriptText =
       """BEGIN
@@ -2580,11 +2633,29 @@ class SqlScriptingParserSuite extends SparkFunSuite 
with SQLHelper {
     }
     checkError(
       exception = exception,
-      condition = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING",
+      condition = 
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
       parameters = Map("conditionName" -> "`TEST_CONDITION`"))
     assert(exception.origin.line.contains(2))
   }
 
+  test("declare condition in wrong scope") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        | IF 1=1 THEN
+        |   DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+        | END IF;
+        |END""".stripMargin
+    val exception = intercept[SqlScriptingException] {
+      parsePlan(sqlScriptText)
+    }
+    checkError(
+      exception = exception,
+      condition = 
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
+      parameters = Map("conditionName" -> toSQLId("TEST_CONDITION")))
+    assert(exception.origin.line.contains(4))
+  }
+
   test("declare qualified condition") {
     val sqlScriptText =
       """
@@ -2617,6 +2688,48 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
     assert(exception.origin.line.contains(3))
   }
 
+  test("declare duplicate condition") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12000';
+        |  DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '13000';
+        |  SELECT 1;
+        |END""".stripMargin
+    val exception = intercept[SqlScriptingException] {
+      parsePlan(sqlScriptText)
+    }
+    checkError(
+      exception = exception,
+      condition = "DUPLICATE_CONDITION_IN_SCOPE",
+      parameters = Map("condition" -> toSQLId("TEST_CONDITION")))
+    assert(exception.origin.line.contains(2))
+  }
+
+  test("declare duplicate condition nested") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12000';
+        |  BEGIN
+        |    IF (1 = 1) THEN
+        |      BEGIN
+        |        DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '13000';
+        |      END;
+        |    END IF;
+        |  END;
+        |  SELECT 1;
+        |END""".stripMargin
+    val exception = intercept[SqlScriptingException] {
+      parsePlan(sqlScriptText)
+    }
+    checkError(
+      exception = exception,
+      condition = "DUPLICATE_CONDITION_IN_SCOPE",
+      parameters = Map("condition" -> toSQLId("TEST_CONDITION")))
+    assert(exception.origin.line.contains(6))
+  }
+
   test("continue handler not supported") {
     val sqlScript =
       """
@@ -2901,6 +3014,43 @@ class SqlScriptingParserSuite extends SparkFunSuite with 
SQLHelper {
     assert(tree.handlers.head.body.collection.size == 1)
   }
 
+  test("declare handler for condition in parent scope") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+        |  BEGIN
+        |    DECLARE EXIT HANDLER FOR TEST_CONDITION SET test_var = 1;
+        |  END;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    val handlerBody = tree.collection.head.asInstanceOf[CompoundBody]
+    assert(handlerBody.handlers.length == 1)
+    assert(handlerBody.handlers.head.isInstanceOf[ExceptionHandler])
+    assert(handlerBody.handlers.head.exceptionHandlerTriggers.conditions.size 
== 1)
+    
assert(handlerBody.handlers.head.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION"))
+    assert(handlerBody.handlers.head.body.collection.size == 1)
+  }
+
+  test("declare nested handler for condition in parent scope of parent 
handler") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+        |  BEGIN
+        |    DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
+        |      BEGIN
+        |        DECLARE EXIT HANDLER FOR TEST_CONDITION SET test_var = 1;
+        |      END;
+        |  END;
+        |END""".stripMargin
+    val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+    val handlerBody = tree
+      .collection.head.asInstanceOf[CompoundBody]
+      .handlers.head.body.asInstanceOf[CompoundBody]
+      .handlers.head
+    
assert(handlerBody.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION"))
+  }
 
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to