This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 4d3b9775652c [SPARK-52129] Improve ParserUtils to support more
Scripting constructs
4d3b9775652c is described below
commit 4d3b9775652c843a2cca54ea81bbd15459f6cab4
Author: David Milicevic <[email protected]>
AuthorDate: Wed May 14 21:44:01 2025 +0200
[SPARK-52129] Improve ParserUtils to support more Scripting constructs
### What changes were proposed in this pull request?
#### Label casing fix
In `enterLabeledScope()` the original name was used instead of the
lowercased one. This caused issues in some cases when matching labels.
#### CompoundBodyParsingContext
Renamed previous `SqlScriptingParsingContext` to a more specific
`CompoundBodyParsingContext` to free up the name for a more generic context.
#### Introduced new SqlScriptingParsingContext
It consists of the previously existing `SqlScriptingLabelContext` and the newly
introduced `SqlScriptingConditionContext`, which tracks the declared
condition names. This is required to detect duplicate
condition names and to map specific conditions to their assigned SQL states.
#### Other changes
Propagating the changed structures through AstBuilder.
Renaming error condition and related functions.
### Why are the changes needed?
These changes introduce support for new correctness checks at parse time
for SQL scripts.
The checks and changes are explained in the previous section.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing and new unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50893 from davidm-db/parsing_contexts.
Lead-authored-by: David Milicevic <[email protected]>
Co-authored-by: David Milicevic
<[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit b5edc11f79930ae79f8d17b93c6e40b840820446)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 4 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 128 ++++++++---------
.../spark/sql/catalyst/parser/ParserUtils.scala | 35 ++++-
.../spark/sql/errors/SqlScriptingErrors.scala | 6 +-
.../catalyst/parser/SqlScriptingParserSuite.scala | 154 ++++++++++++++++++++-
5 files changed, 250 insertions(+), 77 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 334b95122be6..48aa084f1cb5 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2457,9 +2457,9 @@
"Invalid condition declaration."
],
"subClass" : {
- "ONLY_AT_BEGINNING" : {
+ "NOT_AT_START_OF_COMPOUND_STATEMENT" : {
"message" : [
- "Condition <conditionName> can only be declared at the beginning of
the compound."
+ "Condition <conditionName> can only be declared at the start of a
BEGIN END compound statement."
]
},
"QUALIFIED_CONDITION_NAME" : {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 544b62a33d60..d237f0d732b5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -157,19 +157,19 @@ class AstBuilder extends DataTypeAstBuilder
}
override def visitSingleCompoundStatement(ctx:
SingleCompoundStatementContext): CompoundBody = {
- val labelCtx = new SqlScriptingLabelContext()
- val labelText = labelCtx.enterLabeledScope(None, None)
+ val parsingCtx = new SqlScriptingParsingContext
+
+ val labelText = parsingCtx.labelContext.enterLabeledScope(None, None)
val script = Option(ctx.compoundBody())
.map(visitCompoundBodyImpl(
_,
Some(labelText),
- allowVarDeclare = true,
- labelCtx,
+ parsingCtx,
isScope = true
)).getOrElse(CompoundBody(Seq.empty, Some(labelText), isScope = true))
- labelCtx.exitLabeledScope(None)
+ parsingCtx.labelContext.exitLabeledScope(None)
script
}
@@ -277,7 +277,7 @@ class AstBuilder extends DataTypeAstBuilder
private def visitDeclareHandlerStatementImpl(
ctx: DeclareHandlerStatementContext,
- labelCtx: SqlScriptingLabelContext): ExceptionHandler = {
+ parsingCtx: SqlScriptingParsingContext): ExceptionHandler = {
val exceptionHandlerTriggers =
visitConditionValuesImpl(ctx.conditionValues())
if (Option(ctx.CONTINUE()).isDefined) {
@@ -288,7 +288,7 @@ class AstBuilder extends DataTypeAstBuilder
val body = if (Option(ctx.beginEndCompoundBlock()).isDefined) {
visitBeginEndCompoundBlockImpl(
ctx.beginEndCompoundBlock(),
- labelCtx)
+ parsingCtx)
} else {
// If there is no compound body, then there must be a statement or set
statement.
val statement = Option(ctx.statement().asInstanceOf[ParserRuleContext])
@@ -305,27 +305,26 @@ class AstBuilder extends DataTypeAstBuilder
private def visitCompoundBodyImpl(
ctx: CompoundBodyContext,
label: Option[String],
- allowVarDeclare: Boolean,
- labelCtx: SqlScriptingLabelContext,
+ parsingCtx: SqlScriptingParsingContext,
isScope: Boolean): CompoundBody = {
val buff = ListBuffer[CompoundPlanStatement]()
val handlers = ListBuffer[ExceptionHandler]()
- val conditions = HashMap[String, String]()
+ val currentConditions = HashMap[String, String]()
- val scriptingParserContext = new SqlScriptingParsingContext()
+ val compoundBodyParserContext = new CompoundBodyParsingContext()
ctx.compoundStatements.forEach(compoundStatement => {
- val stmt = visitCompoundStatementImpl(compoundStatement, labelCtx)
+ val stmt = visitCompoundStatementImpl(compoundStatement, parsingCtx)
stmt match {
case handler: ExceptionHandler =>
- scriptingParserContext.handler()
+ compoundBodyParserContext.handler()
// All conditions are already visited when we encounter a handler.
handler.exceptionHandlerTriggers.conditions.foreach(conditionName =>
{
// Everything is stored in upper case so we can make
case-insensitive comparisons.
// If condition is not spark-defined error condition, check if
user defined it.
if (!SparkThrowableHelper.isValidErrorClass(conditionName)) {
- if (!conditions.contains(conditionName)) {
+ if (!parsingCtx.conditionContext.contains(conditionName)) {
throw SqlScriptingErrors
.conditionNotFound(CurrentOrigin.get, conditionName)
}
@@ -333,49 +332,53 @@ class AstBuilder extends DataTypeAstBuilder
})
handlers += handler
+
case condition: ErrorCondition =>
- scriptingParserContext.condition(condition)
+ compoundBodyParserContext.condition(condition, isScope)
// Check for duplicate condition names in each scope.
// When conditions are visited, everything is converted to upper-case
// for case-insensitive comparisons.
- if (conditions.contains(condition.conditionName)) {
+ if (parsingCtx.conditionContext.contains(condition.conditionName)) {
throw SqlScriptingErrors
.duplicateConditionInScope(CurrentOrigin.get,
condition.conditionName)
}
- conditions += condition.conditionName -> condition.sqlState
+ currentConditions += condition.conditionName -> condition.sqlState
+ parsingCtx.conditionContext.add(condition)
+
case statement =>
statement match {
case SingleStatement(createVariable: CreateVariable) =>
- scriptingParserContext.variable(createVariable, allowVarDeclare)
- case _ => scriptingParserContext.statement()
+ compoundBodyParserContext.variable(createVariable, isScope)
+ case _ => compoundBodyParserContext.statement()
}
buff += statement
}
})
- CompoundBody(buff.toSeq, label, isScope, handlers.toSeq, conditions)
+ parsingCtx.conditionContext.remove(currentConditions.keys)
+
+ CompoundBody(buff.toSeq, label, isScope, handlers.toSeq, currentConditions)
}
private def visitBeginEndCompoundBlockImpl(
ctx: BeginEndCompoundBlockContext,
- labelCtx: SqlScriptingLabelContext): CompoundBody = {
+ parsingCtx: SqlScriptingParsingContext): CompoundBody = {
val labelText =
- labelCtx.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
val body = Option(ctx.compoundBody())
.map(visitCompoundBodyImpl(
_,
Some(labelText),
- allowVarDeclare = true,
- labelCtx,
+ parsingCtx,
isScope = true
)).getOrElse(CompoundBody(Seq.empty, Some(labelText), isScope = true))
- labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+ parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
body
}
private def visitCompoundStatementImpl(
ctx: CompoundStatementContext,
- labelCtx: SqlScriptingLabelContext): CompoundPlanStatement =
+ parsingCtx: SqlScriptingParsingContext): CompoundPlanStatement =
withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementInsideSqlScript().asInstanceOf[ParserRuleContext]))
@@ -385,23 +388,23 @@ class AstBuilder extends DataTypeAstBuilder
if (ctx.getChildCount == 1) {
ctx.getChild(0) match {
case compoundBodyContext: BeginEndCompoundBlockContext =>
- visitBeginEndCompoundBlockImpl(compoundBodyContext, labelCtx)
+ visitBeginEndCompoundBlockImpl(compoundBodyContext, parsingCtx)
case whileStmtContext: WhileStatementContext =>
- visitWhileStatementImpl(whileStmtContext, labelCtx)
+ visitWhileStatementImpl(whileStmtContext, parsingCtx)
case repeatStmtContext: RepeatStatementContext =>
- visitRepeatStatementImpl(repeatStmtContext, labelCtx)
+ visitRepeatStatementImpl(repeatStmtContext, parsingCtx)
case loopStatementContext: LoopStatementContext =>
- visitLoopStatementImpl(loopStatementContext, labelCtx)
+ visitLoopStatementImpl(loopStatementContext, parsingCtx)
case ifElseStmtContext: IfElseStatementContext =>
- visitIfElseStatementImpl(ifElseStmtContext, labelCtx)
+ visitIfElseStatementImpl(ifElseStmtContext, parsingCtx)
case searchedCaseContext: SearchedCaseStatementContext =>
- visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx)
+ visitSearchedCaseStatementImpl(searchedCaseContext, parsingCtx)
case simpleCaseContext: SimpleCaseStatementContext =>
- visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx)
+ visitSimpleCaseStatementImpl(simpleCaseContext, parsingCtx)
case forStatementContext: ForStatementContext =>
- visitForStatementImpl(forStatementContext, labelCtx)
+ visitForStatementImpl(forStatementContext, parsingCtx)
case declareHandlerContext: DeclareHandlerStatementContext =>
- visitDeclareHandlerStatementImpl(declareHandlerContext,
labelCtx)
+ visitDeclareHandlerStatementImpl(declareHandlerContext,
parsingCtx)
case declareConditionContext: DeclareConditionStatementContext =>
visitDeclareConditionStatementImpl(declareConditionContext)
case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement]
@@ -414,7 +417,7 @@ class AstBuilder extends DataTypeAstBuilder
private def visitIfElseStatementImpl(
ctx: IfElseStatementContext,
- labelCtx: SqlScriptingLabelContext): IfElseStatement = {
+ parsingCtx: SqlScriptingParsingContext): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr =>
withOrigin(boolExpr) {
SingleStatement(
@@ -424,20 +427,20 @@ class AstBuilder extends DataTypeAstBuilder
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
),
elseBody = Option(ctx.elseBody).map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
)
)
}
private def visitWhileStatementImpl(
ctx: WhileStatementContext,
- labelCtx: SqlScriptingLabelContext): WhileStatement = {
+ parsingCtx: SqlScriptingParsingContext): WhileStatement = {
val labelText =
- labelCtx.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()
val condition = withOrigin(boolExpr) {
@@ -448,18 +451,17 @@ class AstBuilder extends DataTypeAstBuilder
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
- allowVarDeclare = false,
- labelCtx,
+ parsingCtx,
isScope = false
)
- labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+ parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
WhileStatement(condition, body, Some(labelText))
}
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
- labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
+ parsingCtx: SqlScriptingParsingContext): SearchedCaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr =>
withOrigin(boolExpr) {
SingleStatement(
Project(
@@ -469,7 +471,7 @@ class AstBuilder extends DataTypeAstBuilder
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
)
if (conditions.length != conditionalBodies.length) {
@@ -483,13 +485,13 @@ class AstBuilder extends DataTypeAstBuilder
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
))
}
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
- labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
+ parsingCtx: SqlScriptingParsingContext): SimpleCaseStatement = {
val caseVariableExpr = withOrigin(ctx.caseVariable) {
expression(ctx.caseVariable)
}
@@ -501,7 +503,7 @@ class AstBuilder extends DataTypeAstBuilder
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
)
if (conditionExpressions.length != conditionalBodies.length) {
@@ -516,15 +518,15 @@ class AstBuilder extends DataTypeAstBuilder
conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body =>
- visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx,
isScope = false)
+ visitCompoundBodyImpl(body, None, parsingCtx, isScope = false)
))
}
private def visitRepeatStatementImpl(
ctx: RepeatStatementContext,
- labelCtx: SqlScriptingLabelContext): RepeatStatement = {
+ parsingCtx: SqlScriptingParsingContext): RepeatStatement = {
val labelText =
- labelCtx.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
val boolExpr = ctx.booleanExpression()
val condition = withOrigin(boolExpr) {
@@ -535,19 +537,19 @@ class AstBuilder extends DataTypeAstBuilder
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
- allowVarDeclare = false,
- labelCtx,
+ parsingCtx,
isScope = false
)
- labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+ parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
RepeatStatement(condition, body, Some(labelText))
}
private def visitForStatementImpl(
ctx: ForStatementContext,
- labelCtx: SqlScriptingLabelContext): ForStatement = {
- val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ parsingCtx: SqlScriptingParsingContext): ForStatement = {
+ val labelText = parsingCtx.labelContext.enterLabeledScope(
+ Option(ctx.beginLabel()), Option(ctx.endLabel()))
val queryCtx = ctx.query()
val query = withOrigin(queryCtx) {
@@ -557,11 +559,10 @@ class AstBuilder extends DataTypeAstBuilder
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
- allowVarDeclare = false,
- labelCtx,
+ parsingCtx,
isScope = false
)
- labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+ parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
ForStatement(query, varName, body, Some(labelText))
}
@@ -630,17 +631,16 @@ class AstBuilder extends DataTypeAstBuilder
private def visitLoopStatementImpl(
ctx: LoopStatementContext,
- labelCtx: SqlScriptingLabelContext): LoopStatement = {
+ parsingCtx: SqlScriptingParsingContext): LoopStatement = {
val labelText =
- labelCtx.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
+ parsingCtx.labelContext.enterLabeledScope(Option(ctx.beginLabel()),
Option(ctx.endLabel()))
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
- allowVarDeclare = false,
- labelCtx,
+ parsingCtx,
isScope = false
)
- labelCtx.exitLabeledScope(Option(ctx.beginLabel()))
+ parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
LoopStatement(body, Some(labelText))
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index fe5bdcc00d30..38e92cf9aebd 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.parser
import java.util
import java.util.Locale
-import scala.collection.immutable
-import scala.collection.mutable
+import scala.collection.{immutable, mutable}
import scala.util.matching.Regex
import org.antlr.v4.runtime.{ParserRuleContext, Token}
@@ -139,7 +138,7 @@ object ParserUtils extends SparkParserUtils {
}
}
-class SqlScriptingParsingContext {
+class CompoundBodyParsingContext {
object State extends Enumeration {
type State = Value
@@ -158,7 +157,12 @@ class SqlScriptingParsingContext {
}
/** Transition to CONDITION state. */
- def condition(errorCondition: ErrorCondition): Unit = {
+ def condition(errorCondition: ErrorCondition, allowConditionDeclare:
Boolean): Unit = {
+ if (!allowConditionDeclare) {
+ throw SqlScriptingErrors.conditionDeclarationNotAtStartOfCompound(
+ errorCondition.origin, errorCondition.conditionName
+ )
+ }
transitionTo(State.CONDITION, None, errorCondition = Some(errorCondition))
}
@@ -234,7 +238,7 @@ class SqlScriptingParsingContext {
// Invalid transitions to CONDITION state.
case (State.STATEMENT, State.CONDITION) =>
- throw SqlScriptingErrors.conditionDeclarationOnlyAtBeginning(
+ throw SqlScriptingErrors.conditionDeclarationNotAtStartOfCompound(
CurrentOrigin.get,
errorCondition.get.conditionName)
@@ -255,6 +259,11 @@ class SqlScriptingParsingContext {
}
}
+class SqlScriptingParsingContext {
+ val labelContext: SqlScriptingLabelContext = new SqlScriptingLabelContext()
+ val conditionContext: SqlScriptingConditionContext = new
SqlScriptingConditionContext()
+}
+
class SqlScriptingLabelContext {
/** Set to keep track of labels seen so far */
private val seenLabels = mutable.Set[String]()
@@ -327,7 +336,7 @@ class SqlScriptingLabelContext {
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
}
}
- seenLabels.add(beginLabelCtx.get.multipartIdentifier().getText)
+ seenLabels.add(txt)
txt
} else {
// Do not add the label to the seenLabels set if it is not defined.
@@ -360,3 +369,17 @@ object SqlScriptingLabelContext {
forbiddenLabelNames.exists(_.matches(labelName.toLowerCase(Locale.ROOT)))
}
}
+
+class SqlScriptingConditionContext {
+ private val conditionNameToSqlStateMap = mutable.HashMap[String, String]()
+
+ def contains(conditionName: String): Boolean =
conditionNameToSqlStateMap.contains(conditionName)
+
+ def getSqlStateForCondition(conditionName: String): Option[String] =
+ conditionNameToSqlStateMap.get(conditionName)
+
+ def add(condition: ErrorCondition): Unit =
+ conditionNameToSqlStateMap += condition.conditionName -> condition.sqlState
+
+ def remove(toRemove: Iterable[String]): Unit = conditionNameToSqlStateMap
--= toRemove
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index 7e866d261485..ce0ed1f36a75 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -163,12 +163,12 @@ private[sql] object SqlScriptingErrors {
messageParameters = Map("conditionName" -> toSQLStmt(conditionName)))
}
- def conditionDeclarationOnlyAtBeginning(
+ def conditionDeclarationNotAtStartOfCompound(
origin: Origin,
conditionName: String): Throwable = {
new SqlScriptingException(
origin = origin,
- errorClass = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING",
+ errorClass =
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
cause = null,
messageParameters = Map("conditionName" -> toSQLId(conditionName)))
}
@@ -188,7 +188,7 @@ private[sql] object SqlScriptingErrors {
origin = origin,
errorClass = "DUPLICATE_CONDITION_IN_SCOPE",
cause = null,
- messageParameters = Map("condition" -> condition))
+ messageParameters = Map("condition" -> toSQLId(condition)))
}
def handlerDeclarationInWrongPlace(origin: Origin): Throwable = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index ec4e558fc467..7f37dee64fb9 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -269,6 +269,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(tree.label.contains("lbl"))
}
+ test("compound: beginLabel + endLabel - case sensitivity check") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: BEGIN
+ | SELECT 1;
+ | SELECT 2;
+ | INSERT INTO A VALUES (a, b, 3);
+ | SELECT a, b, c FROM T;
+ | SELECT * FROM T;
+ | END LbL;
+ |END""".stripMargin
+ parsePlan(sqlScriptText)
+ }
+
test("compound: beginLabel + endLabel with different values") {
val sqlScriptText =
"""
@@ -518,7 +533,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(exception.origin.line.contains(4))
}
- test("declare in wrong scope") {
+ test("declare variable in wrong scope") {
val sqlScriptText =
"""
|BEGIN
@@ -2001,6 +2016,44 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
parameters = Map("label" -> toSQLId("lbl")))
}
+ test("unique label names: nested begin-end blocks - case sensitivity check
1") {
+ val sqlScriptText =
+ """BEGIN
+ |LbL: BEGIN
+ | lbl: BEGIN
+ | SELECT 1;
+ | END;
+ |END;
+ |END
+ """.stripMargin
+ val exception = intercept[SqlScriptingException] {
+ parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+ }
+ checkError(
+ exception = exception,
+ condition = "LABEL_ALREADY_EXISTS",
+ parameters = Map("label" -> toSQLId("lbl")))
+ }
+
+ test("unique label names: nested begin-end blocks - case sensitivity check
2") {
+ val sqlScriptText =
+ """BEGIN
+ |lbl: BEGIN
+ | LbL: BEGIN
+ | SELECT 1;
+ | END;
+ |END;
+ |END
+ """.stripMargin
+ val exception = intercept[SqlScriptingException] {
+ parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+ }
+ checkError(
+ exception = exception,
+ condition = "LABEL_ALREADY_EXISTS",
+ parameters = Map("label" -> toSQLId("lbl")))
+ }
+
test("unique label names: nested begin-end blocks with same prefix") {
val sqlScriptText =
"""BEGIN
@@ -2580,11 +2633,29 @@ class SqlScriptingParserSuite extends SparkFunSuite
with SQLHelper {
}
checkError(
exception = exception,
- condition = "INVALID_ERROR_CONDITION_DECLARATION.ONLY_AT_BEGINNING",
+ condition =
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
parameters = Map("conditionName" -> "`TEST_CONDITION`"))
assert(exception.origin.line.contains(2))
}
+ test("declare condition in wrong scope") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | IF 1=1 THEN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+ | END IF;
+ |END""".stripMargin
+ val exception = intercept[SqlScriptingException] {
+ parsePlan(sqlScriptText)
+ }
+ checkError(
+ exception = exception,
+ condition =
"INVALID_ERROR_CONDITION_DECLARATION.NOT_AT_START_OF_COMPOUND_STATEMENT",
+ parameters = Map("conditionName" -> toSQLId("TEST_CONDITION")))
+ assert(exception.origin.line.contains(4))
+ }
+
test("declare qualified condition") {
val sqlScriptText =
"""
@@ -2617,6 +2688,48 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(exception.origin.line.contains(3))
}
+ test("declare duplicate condition") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12000';
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '13000';
+ | SELECT 1;
+ |END""".stripMargin
+ val exception = intercept[SqlScriptingException] {
+ parsePlan(sqlScriptText)
+ }
+ checkError(
+ exception = exception,
+ condition = "DUPLICATE_CONDITION_IN_SCOPE",
+ parameters = Map("condition" -> toSQLId("TEST_CONDITION")))
+ assert(exception.origin.line.contains(2))
+ }
+
+ test("declare duplicate condition nested") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12000';
+ | BEGIN
+ | IF (1 = 1) THEN
+ | BEGIN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '13000';
+ | END;
+ | END IF;
+ | END;
+ | SELECT 1;
+ |END""".stripMargin
+ val exception = intercept[SqlScriptingException] {
+ parsePlan(sqlScriptText)
+ }
+ checkError(
+ exception = exception,
+ condition = "DUPLICATE_CONDITION_IN_SCOPE",
+ parameters = Map("condition" -> toSQLId("TEST_CONDITION")))
+ assert(exception.origin.line.contains(6))
+ }
+
test("continue handler not supported") {
val sqlScript =
"""
@@ -2901,6 +3014,43 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
assert(tree.handlers.head.body.collection.size == 1)
}
+ test("declare handler for condition in parent scope") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+ | BEGIN
+ | DECLARE EXIT HANDLER FOR TEST_CONDITION SET test_var = 1;
+ | END;
+ |END""".stripMargin
+ val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+ val handlerBody = tree.collection.head.asInstanceOf[CompoundBody]
+ assert(handlerBody.handlers.length == 1)
+ assert(handlerBody.handlers.head.isInstanceOf[ExceptionHandler])
+ assert(handlerBody.handlers.head.exceptionHandlerTriggers.conditions.size
== 1)
+
assert(handlerBody.handlers.head.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION"))
+ assert(handlerBody.handlers.head.body.collection.size == 1)
+ }
+
+ test("declare nested handler for condition in parent scope of parent
handler") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE TEST_CONDITION CONDITION FOR SQLSTATE '12345';
+ | BEGIN
+ | DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
+ | BEGIN
+ | DECLARE EXIT HANDLER FOR TEST_CONDITION SET test_var = 1;
+ | END;
+ | END;
+ |END""".stripMargin
+ val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
+ val handlerBody = tree
+ .collection.head.asInstanceOf[CompoundBody]
+ .handlers.head.body.asInstanceOf[CompoundBody]
+ .handlers.head
+
assert(handlerBody.exceptionHandlerTriggers.conditions.contains("TEST_CONDITION"))
+ }
// Helper methods
def cleanupStatementString(statementStr: String): String = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]