This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-4.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push: new 6549ec807945 [SPARK-52908][CORE] Prevent for iterator variable name clashing with names of labels in the path to the root of AST 6549ec807945 is described below commit 6549ec807945ecc8c642ab8ab96540b5e0cb2beb Author: Teodor Djelic <130703036+teodordje...@users.noreply.github.com> AuthorDate: Fri Jul 25 08:09:36 2025 +0800 [SPARK-52908][CORE] Prevent for iterator variable name clashing with names of labels in the path to the root of AST ### What changes were proposed in this pull request? The proposed change is to explicitly prohibit a FOR iterator variable from hiding scoped variables when the label of the scope and the iterator variable name are the same. ### Why are the changes needed? A FOR iterator variable hides scoped variables if the label of the scope and the iterator variable name are the same. This interaction leads to undesirable behavior: - A column of the iterator variable and a variable in scope having the same name will result in the column of the iterator variable hiding the variable in scope; - Trying to access a variable in scope that does not clash with a column of the iterator variable will result in the compiler not being able to resolve the variable in scope. ### Does this PR introduce _any_ user-facing change? Yes, it does. Changes are: - Error LABEL_ALREADY_EXISTS was renamed to LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS; - Error LABEL_NAME_FORBIDDEN was renamed to LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN. Old behavior:  <img width="1198" height="214" alt="image" src="https://github.com/user-attachments/assets/b2586697-49e9-4cbc-b57e-53b6c91700bc" /> New behavior: <img width="1618" height="162" alt="image" src="https://github.com/user-attachments/assets/d023715c-08a1-47a2-9db1-3a19758140d6" /> <img width="1393" height="110" alt="image" src="https://github.com/user-attachments/assets/855d80c1-3fbd-42a6-ab1e-f664c4b4b47e" /> ### How was this patch tested? 
New tests in SqlScriptingParserSuite and existing tests. Instead of a variable resolution exception, the exception now raised states that such interactions are prohibited. Old behavior: <img width="1335" height="380" alt="467960247-895da398-3ace-4334-b597-1be4a400acf4" src="https://github.com/user-attachments/assets/2d9412e1-5896-4286-b230-b728315a0fc6" /> New behavior: <img width="1651" height="263" alt="467961070-92a0cbb7-bb93-410a-8266-3ef2591350f8" src="https://github.com/user-attachments/assets/5a44b505-6be2-4c1a-9f05-7fa563eab441" /> ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51595 from TeodorDjelic/prevent-for-iterator-variable-name-clashing-with-names-of-labels-in-the-path-to-the-root-of-ast. Authored-by: Teodor Djelic <130703036+teodordje...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit eb63949298b374486db9013338a9f0a22e05a972) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../src/main/resources/error/error-conditions.json | 8 +- .../catalyst/analysis/ColumnResolutionHelper.scala | 5 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 2 + .../spark/sql/catalyst/parser/ParserUtils.scala | 58 +++++++- .../spark/sql/errors/SqlScriptingErrors.scala | 6 +- .../catalyst/parser/SqlScriptingParserSuite.scala | 156 ++++++++++++++++++--- 6 files changed, 206 insertions(+), 29 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 69397c550a7a..b3f4c7a12832 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3744,15 +3744,15 @@ ], "sqlState" : "42K0L" }, - "LABEL_ALREADY_EXISTS" : { + "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS" : { "message" : [ - "The label <label> already exists. Choose another name or rename the existing label." 
+ "The label or FOR variable <label> already exists. Choose another name or rename the existing one." ], "sqlState" : "42K0L" }, - "LABEL_NAME_FORBIDDEN" : { + "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN" : { "message" : [ - "The label name <label> is forbidden." + "The label or FOR variable name <label> is forbidden." ], "sqlState" : "42K0L" }, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index bfd6a3613ac9..ebf328a38a6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference -import org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelName +import org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelOrForVariableName import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -287,7 +287,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { .filterNot(_ => AnalysisContext.get.isExecuteImmediate) // If variable name is qualified with session.<varName> treat it as a session variable. 
.filterNot(_ => - nameParts.length > 2 || (nameParts.length == 2 && isForbiddenLabelName(nameParts.head))) + nameParts.length > 2 + || (nameParts.length == 2 && isForbiddenLabelOrForVariableName(nameParts.head))) .flatMap(_.get(namePartsCaseAdjusted)) .map { varDef => VariableReference( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d237f0d732b5..03f8249977f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -555,6 +555,7 @@ class AstBuilder extends DataTypeAstBuilder val query = withOrigin(queryCtx) { SingleStatement(visitQuery(queryCtx)) } + parsingCtx.labelContext.enterForScope(Option(ctx.multipartIdentifier())) val varName = Option(ctx.multipartIdentifier()).map(_.getText) val body = visitCompoundBodyImpl( ctx.compoundBody(), @@ -562,6 +563,7 @@ class AstBuilder extends DataTypeAstBuilder parsingCtx, isScope = false ) + parsingCtx.labelContext.exitForScope(Option(ctx.multipartIdentifier())) parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel())) ForStatement(query, varName, body, Some(labelText)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 38e92cf9aebd..f48bdde8ebf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -28,7 +28,7 @@ import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl} import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext} +import 
org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext, MultipartIdentifierContext} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, ErrorCondition} import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.catalyst.util.SparkParserUtils @@ -316,6 +316,23 @@ class SqlScriptingLabelContext { beginLabelCtx.map(_.multipartIdentifier().getText).isDefined } + /** + * Assert the identifier is not contained within seenLabels. + * If the identifier is contained within seenLabels, raise an exception. + */ + private def assertIdentifierNotInSeenLabels( + identifierCtx: Option[MultipartIdentifierContext]): Unit = { + identifierCtx.foreach { ctx => + val identifierName = ctx.getText + if (seenLabels.contains(identifierName.toLowerCase(Locale.ROOT))) { + withOrigin(ctx) { + throw SqlScriptingErrors + .duplicateLabels(CurrentOrigin.get, identifierName.toLowerCase(Locale.ROOT)) + } + } + } + } + /** * Enter a labeled scope and return the label text. * If the label is defined, it will be returned and added to seenLabels. @@ -342,9 +359,9 @@ class SqlScriptingLabelContext { // Do not add the label to the seenLabels set if it is not defined. java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT) } - if (SqlScriptingLabelContext.isForbiddenLabelName(labelText)) { + if (SqlScriptingLabelContext.isForbiddenLabelOrForVariableName(labelText)) { withOrigin(beginLabelCtx.get) { - throw SqlScriptingErrors.labelNameForbidden(CurrentOrigin.get, labelText) + throw SqlScriptingErrors.labelOrForVariableNameForbidden(CurrentOrigin.get, labelText) } } labelText @@ -359,13 +376,46 @@ class SqlScriptingLabelContext { seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)) } } + + /** + * Enter a for loop scope. + * If the for loop variable is defined, it will be asserted to not be inside seenLabels; + * Then, if the for loop variable is defined, it will be added to seenLabels. 
+ */ + def enterForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = { + identifierCtx.foreach { ctx => + val identifierName = ctx.getText + assertIdentifierNotInSeenLabels(identifierCtx) + seenLabels.add(identifierName.toLowerCase(Locale.ROOT)) + + if (SqlScriptingLabelContext.isForbiddenLabelOrForVariableName(identifierName)) { + withOrigin(ctx) { + throw SqlScriptingErrors.labelOrForVariableNameForbidden( + CurrentOrigin.get, + identifierName.toLowerCase(Locale.ROOT)) + } + } + } + } + + /** + * Exit a for loop scope. + * If the for loop variable is defined, it will be removed from seenLabels. + */ + def exitForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = { + identifierCtx.foreach { ctx => + val identifierName = ctx.getText + seenLabels.remove(identifierName.toLowerCase(Locale.ROOT)) + } + } + } object SqlScriptingLabelContext { private val forbiddenLabelNames: immutable.Set[Regex] = immutable.Set("builtin".r, "session".r, "sys.*".r) - def isForbiddenLabelName(labelName: String): Boolean = { + def isForbiddenLabelOrForVariableName(labelName: String): Boolean = { forbiddenLabelNames.exists(_.matches(labelName.toLowerCase(Locale.ROOT))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 23b863f24bc8..0b7b60cbe8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -33,7 +33,7 @@ private[sql] object SqlScriptingErrors { def duplicateLabels(origin: Origin, label: String): Throwable = { new SqlScriptingException( origin = origin, - errorClass = "LABEL_ALREADY_EXISTS", + errorClass = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", cause = null, messageParameters = Map("label" -> toSQLId(label))) } @@ -54,10 +54,10 @@ private[sql] object SqlScriptingErrors { messageParameters = 
Map("endLabel" -> toSQLId(endLabel))) } - def labelNameForbidden(origin: Origin, label: String): Throwable = { + def labelOrForVariableNameForbidden(origin: Origin, label: String): Throwable = { new SqlScriptingException( origin = origin, - errorClass = "LABEL_NAME_FORBIDDEN", + errorClass = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", cause = null, messageParameters = Map("label" -> toSQLId(label)) ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 7f37dee64fb9..abcea96f0831 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -323,7 +323,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", parameters = Map("label" -> toSQLId("system"))) assert(exception.origin.line.contains(3)) } @@ -345,7 +345,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", parameters = Map("label" -> toSQLId("sysxyz"))) assert(exception.origin.line.contains(3)) } @@ -367,7 +367,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", parameters = Map("label" -> toSQLId("session"))) assert(exception.origin.line.contains(3)) } @@ -389,7 +389,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", 
parameters = Map("label" -> toSQLId("builtin"))) assert(exception.origin.line.contains(3)) } @@ -411,7 +411,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", parameters = Map("label" -> toSQLId("system"))) assert(exception.origin.line.contains(3)) } @@ -433,7 +433,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_NAME_FORBIDDEN", + condition = "LABEL_OR_FOR_VARIABLE_NAME_FORBIDDEN", parameters = Map("label" -> toSQLId("session"))) assert(exception.origin.line.contains(3)) } @@ -2012,7 +2012,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl"))) } @@ -2031,7 +2031,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl"))) } @@ -2050,7 +2050,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl"))) } @@ -2092,7 +2092,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl_1"))) } @@ -2111,7 +2111,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = 
"LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl"))) } @@ -2130,7 +2130,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("lbl"))) } @@ -2149,7 +2149,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("w_loop"))) } @@ -2170,7 +2170,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("r_loop"))) } @@ -2189,7 +2189,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("l_loop"))) } @@ -2208,7 +2208,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } checkError( exception = exception, - condition = "LABEL_ALREADY_EXISTS", + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", parameters = Map("label" -> toSQLId("f_loop"))) } @@ -2336,6 +2336,130 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStatement.label.get == "lbl_4") } + test("for variable not the same as labels in scope") { + val sqlScriptText = + """ + |BEGIN + | L1: BEGIN + | L2: BEGIN + | L3: FOR L4 AS SELECT 1 DO + | SELECT 1; + | FOR L5 AS SELECT 3 DO + | BEGIN + | SELECT L4; + | END; + | SELECT 4; + | END FOR; + | END FOR L3; + | END L2; + | L4: BEGIN + | SELECT 3; + | END L4; + | END L1; + |END""".stripMargin + + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + 
assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CompoundBody]) + } + + test("for variable name is the same as a label in scope - should fail") { + val sqlScriptText = + """ + |BEGIN + | L1: BEGIN + | L2: BEGIN + | L3: FOR L2 AS SELECT 1 DO + | SELECT 1; + | SELECT 2; + | END FOR L3; + | END L2; + | L4: BEGIN + | SELECT 3; + | END L4; + | END L1; + |END""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", + parameters = Map("label" -> "`l2`")) + } + + test("for variable name is the same as the label of the for loop - should fail") { + val sqlScriptText = + """ + |BEGIN + | L1: FOR L1 AS SELECT 1 DO + | SELECT 2; + | END FOR L1; + |END""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", + parameters = Map("label" -> "`l1`")) + } + + test("label name is the same as the for loop variable name - should fail") { + val sqlScriptText = + """ + |BEGIN + | FOR L1 AS SELECT 1 DO + | L1: BEGIN + | SELECT 2; + | END L1; + | END FOR; + |END""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", + parameters = Map("label" -> "`l1`")) + } + + test("nested for loop variable names are the same - should fail") { + val sqlScriptText = + """ + |BEGIN + | FOR L1 AS SELECT 1 DO + | FOR L1 AS SELECT 2 DO + | SELECT 3; + | END FOR; + | END FOR; + |END""".stripMargin + + checkError( + exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + }, + condition = "LABEL_OR_FOR_VARIABLE_ALREADY_EXISTS", + parameters = Map("label" -> "`l1`")) + } + + test("for loop variable names are the same but for loops are not nested") { + val sqlScriptText = + """ + |BEGIN + | FOR L1 AS SELECT 1 DO + | SELECT 2; + | END FOR; + | 
FOR L1 AS SELECT 3 DO + | SELECT 4; + | END FOR; + |END""".stripMargin + + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 2) + assert(tree.collection.forall(_.isInstanceOf[ForStatement])) + } + test("for statement") { val sqlScriptText = """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org