This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fb17856a22be [SPARK-48530][SQL] Support for local variables in SQL Scripting
fb17856a22be is described below

commit fb17856a22be6968b2ed55ccbd7cf72111920bea
Author: Dušan Tišma <dusan.ti...@databricks.com>
AuthorDate: Thu Feb 20 21:22:16 2025 +0800

[SPARK-48530][SQL] Support for local variables in SQL Scripting

### What changes were proposed in this pull request?
This pull request introduces support for local variables in SQL scripting.

#### Behavior:
Local variables are declared in the headers of compound bodies and are bound to that scope. Variables of the same name are allowed in nested scopes, where the innermost variable is resolved. Optionally, a local variable can be qualified with the label of the compound body in which it was declared, which allows accessing variables that are not the innermost in the current scope. Local variables take resolution priority over session variables; session variable resolution is attempted only after local variable resolution fails. The exception to this is fully qualified session variables, in the format `system.session.<varName>` or `session.<varName>`. `system` and `session` are forbidden for use as compound body labels.

Local variables must not be qualified on declaration, can be set using `SET VAR`, and cannot be `DROPPED`. They should also not be allowed to be declared with `DECLARE OR REPLACE`; however, this is not implemented in this PR because the `FOR` statement relies on that behavior. The `FOR` statement must be updated in a separate PR to use proper local variables, as the current implementation simulates them using session variables.

#### Implementation notes:
As core depends on catalyst, it is impossible to import code from core (where most of the SQL scripting implementation is located) into catalyst. To solve this, a trait `VariableManager` is introduced, which is implemented in core and injected into catalyst. This `VariableManager` is essentially a wrapper around `SqlScriptingExecutionContext` and provides methods for getting/setting/creating variables.

This injection is tricky because we want to have one `ScriptingVariableManager` **per script**. Options considered to achieve this are:
- Pass the manager/context to the analyzer using function calls. If possible, this solution would be ideal because it would allow every run of the analyzer to have its own scripting context which is automatically cleaned up (AnalysisContext). This would also allow more control over variable resolution, i.e. for `EXECUTE IMMEDIATE` we could simply not pass in the script context and it would behave as if outside of a script. This is the intended behavior for `EXECUTE IMMEDIATE`. Th [...]
- Store the context in `CatalogManager`. `CatalogManager`'s lifetime is tied to the session, so to allow multiple scripts to execute at the same time we would need to e.g. have a map `scriptUUID -> VariableManager`, and have the `scriptUUID` as a `ThreadLocal` variable in the `CatalogManager`. The drawback of this approach is that the script has to clean up its resources after execution, and also that it is more complicated to e.g. forbid `EXECUTE IMMEDIATE` from accessing loca [...]

Currently the second option seems better to me; however, I am open to suggestions on how to approach this.

EDIT: An option similar to the second one was chosen, except that a ThreadLocal singleton instance of the context is used instead of storing it in `CatalogManager`; a minimal sketch of that pattern follows below.
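To make the chosen approach concrete, here is a minimal, self-contained sketch of the lexically scoped thread-local pattern provided by the new `LexicalThreadLocal` trait. The `ScriptState` and `CurrentScript` names are illustrative stand-ins only; in the actual change the object is `SqlScriptingLocalVariableManager` and the stored value is a `VariableManager` wrapping `SqlScriptingExecutionContext`.

```scala
import org.apache.spark.util.LexicalThreadLocal

// Illustrative stand-in for the per-script state carried by the thread local.
case class ScriptState(scriptId: String)

// Mirrors SqlScriptingLocalVariableManager: the underlying thread local is private,
// and a value can only be installed through a Handle with a bounded lifetime.
object CurrentScript extends LexicalThreadLocal[ScriptState] {
  def create(state: ScriptState): Handle = createHandle(Some(state))
}

val handle = CurrentScript.create(ScriptState("script-1"))
assert(CurrentScript.get().isEmpty) // not visible outside the scope
handle.runWith {
  // Code running inside the scope (e.g. analysis of a statement in the script)
  // can read the per-script value without any core -> catalyst dependency.
  assert(CurrentScript.get().contains(ScriptState("script-1")))
}
assert(CurrentScript.get().isEmpty) // previous value restored on exit
```

Because the handle restores the previous value when `runWith` exits, each script execution gets its own manager and cleanup is automatic, which is what makes the per-script requirement workable without tying the state to `CatalogManager`.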
EDIT: Execute Immediate needs to be reworked in order to work properly with local variables. The generated query should not be able to access local variables, which means EXECUTE IMMEDIATE needs to somehow sandbox that query. This is done by analyzing its entire subtree in SubstituteExecuteImmediate, with context so we know we are in EXECUTE IMMEDIATE. PR for this refactor: https://github.com/apache/spark/pull/49993

### Why are the changes needed?
Currently, local variables are simulated using session variables in SQL scripting, which is a temporary solution with many drawbacks.

### Does this PR introduce _any_ user-facing change?
Yes, this change introduces multiple new types of errors.

### How was this patch tested?
Tests were added to SqlScriptingExecutionSuite and SqlScriptingParserSuite.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49445 from dusantism-db/scripting-local-variables.

Authored-by: Dušan Tišma <dusan.ti...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json | 21 +
 .../org/apache/spark/util/LexicalThreadLocal.scala | 66 ++
 .../SqlScriptingLocalVariableManager.scala | 25 +
 .../catalyst/analysis/ColumnResolutionHelper.scala | 53 +-
 .../sql/catalyst/analysis/ResolveCatalogs.scala | 88 +-
 .../sql/catalyst/analysis/ResolveSetVariable.scala | 5 +-
 .../catalyst/analysis/resolver/ResolverGuard.scala | 6 +-
 .../sql/catalyst/analysis/v2ResolutionPlans.scala | 8 +
 .../sql/catalyst/catalog/TempVariableManager.scala | 72 --
 .../sql/catalyst/catalog/VariableManager.scala | 152 ++++
 .../spark/sql/catalyst/parser/ParserUtils.scala | 20 +-
 .../spark/sql/errors/SqlScriptingErrors.scala | 9 +
 .../sql/catalyst/analysis/AnalysisSuite.scala | 16 +-
 .../spark/sql/catalyst/analysis/AnalysisTest.scala | 13 +-
 .../catalyst/parser/SqlScriptingParserSuite.scala | 132 +++
 .../apache/spark/sql/classic/SparkSession.scala | 40 +-
 .../execution/command/v2/CreateVariableExec.scala | 39 +-
 .../execution/command/v2/DropVariableExec.scala | 2 +-
 .../sql/execution/command/v2/SetVariableExec.scala | 53 +-
 .../execution/command/v2/V2CommandStrategy.scala | 2 +-
 .../sql/scripting/SqlScriptingExecution.scala | 14 +-
 .../scripting/SqlScriptingExecutionContext.scala | 8 +
 .../sql/scripting/SqlScriptingExecutionNode.scala | 13 +-
 .../sql/scripting/SqlScriptingInterpreter.scala | 31 +-
 .../SqlScriptingLocalVariableManager.scala | 127 +++
 .../spark/sql/scripting/SqlScriptingE2eSuite.scala | 1 -
 .../scripting/SqlScriptingExecutionNodeSuite.scala | 45 +-
 .../sql/scripting/SqlScriptingExecutionSuite.scala | 922 ++++++++++++++++++++-
 .../scripting/SqlScriptingInterpreterSuite.scala | 121 +--
 .../org/apache/spark/sql/test/SQLTestUtils.scala | 11 +
 30 files changed, 1788 insertions(+), 327 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 616a22b54ba7..f1012edd2de2 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3587,6 +3587,16 @@ "message" : [ "Variable <varName> can only be declared at the beginning of the compound." ] + }, + "QUALIFIED_LOCAL_VARIABLE" : { + "message" : [ + "The variable <varName> must be declared without a qualifier, as qualifiers are not allowed for local variable declarations." 
+ ] + }, + "REPLACE_LOCAL_VARIABLE" : { + "message" : [ + "The variable <varName> does not support DECLARE OR REPLACE, as local variables cannot be replaced." + ] } }, "sqlState" : "42K0M" @@ -3733,6 +3743,12 @@ ], "sqlState" : "42K0L" }, + "LABEL_NAME_FORBIDDEN" : { + "message" : [ + "The label name <label> is forbidden." + ], + "sqlState" : "42K0L" + }, "LOAD_DATA_PATH_NOT_EXISTS" : { "message" : [ "LOAD DATA input path does not exist: <path>." @@ -5807,6 +5823,11 @@ "SQL Scripting is under development and not all features are supported. SQL Scripting enables users to write procedural SQL including control flow and error handling. To enable existing features set <sqlScriptingEnabled> to `true`." ] }, + "SQL_SCRIPTING_DROP_TEMPORARY_VARIABLE" : { + "message" : [ + "DROP TEMPORARY VARIABLE is not supported within SQL scripts. To bypass this, use `EXECUTE IMMEDIATE 'DROP TEMPORARY VARIABLE ...'` ." + ] + }, "SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS" : { "message" : [ "Positional parameters are not supported with SQL Scripting." diff --git a/common/utils/src/main/scala/org/apache/spark/util/LexicalThreadLocal.scala b/common/utils/src/main/scala/org/apache/spark/util/LexicalThreadLocal.scala new file mode 100644 index 000000000000..a63b76cb476a --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/LexicalThreadLocal.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Helper trait for defining thread locals with lexical scoping. With this helper, the thread local + * is private and can only be set by the [[Handle]]. The [[Handle]] only exposes the thread local + * value to functions passed into its runWith method. This pattern allows for + * the lifetime of the thread local value to be strictly controlled. + * + * Rather than calling `tl.set(...)` and `tl.remove()` you would get a handle and execute your code + * in `handle.runWith { ... }`. + * + * Example: + * {{{ + * object Credentials extends LexicalThreadLocal[Int] { + * def create(creds: Map[String, String]) = new Handle(Some(creds)) + * } + * ... + * val handle = Credentials.create(Map("key" -> "value")) + * assert(Credentials.get() == None) + * handle.runWith { + * assert(Credentials.get() == Some(Map("key" -> "value"))) + * } + * }}} + */ +trait LexicalThreadLocal[T] { + private val tl = new ThreadLocal[T] + + private def set(opt: Option[T]): Unit = { + opt match { + case Some(x) => tl.set(x) + case None => tl.remove() + } + } + + protected def createHandle(opt: Option[T]): Handle = new Handle(opt) + + def get(): Option[T] = Option(tl.get) + + /** Final class representing a handle to a thread local value. 
*/ + final class Handle private[LexicalThreadLocal] (private val opt: Option[T]) { + def runWith[R](f: => R): R = { + val old = get() + set(opt) + try f finally { + set(old) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala new file mode 100644 index 000000000000..a3eef28d372a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.catalog.VariableManager +import org.apache.spark.util.LexicalThreadLocal + +object SqlScriptingLocalVariableManager extends LexicalThreadLocal[VariableManager] { + def create(variableManager: VariableManager): Handle = createHandle(Option(variableManager)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index ae80f64243fa..e778342d0837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -23,8 +23,10 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference +import org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelName import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -229,6 +231,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } + /** + * Look up variable by nameParts. + * If in SQL Script, first check local variables, unless in EXECUTE IMMEDIATE + * (EXECUTE IMMEDIATE generated query cannot access local variables). + * if not found fall back to session variables. + * @param nameParts NameParts of the variable. + * @return Reference to the variable. + */ def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = { // The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not. 
def maybeTempVariableName(nameParts: Seq[String]): Boolean = { @@ -244,22 +254,41 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } - if (maybeTempVariableName(nameParts)) { - val variableName = if (conf.caseSensitiveAnalysis) { - nameParts.last - } else { - nameParts.last.toLowerCase(Locale.ROOT) - } - catalogManager.tempVariableManager.get(variableName).map { varDef => + val namePartsCaseAdjusted = if (conf.caseSensitiveAnalysis) { + nameParts + } else { + nameParts.map(_.toLowerCase(Locale.ROOT)) + } + + SqlScriptingLocalVariableManager.get() + // If we are in EXECUTE IMMEDIATE lookup only session variables. + .filterNot(_ => AnalysisContext.get.isExecuteImmediate) + // If variable name is qualified with session.<varName> treat it as a session variable. + .filterNot(_ => + nameParts.length > 2 || (nameParts.length == 2 && isForbiddenLabelName(nameParts.head))) + .flatMap(_.get(namePartsCaseAdjusted)) + .map { varDef => VariableReference( nameParts, - FakeSystemCatalog, - Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), + FakeLocalCatalog, + Identifier.of(Array(varDef.identifier.namespace().last), namePartsCaseAdjusted.last), varDef) } - } else { - None - } + .orElse( + if (maybeTempVariableName(nameParts)) { + catalogManager.tempVariableManager + .get(namePartsCaseAdjusted) + .map { varDef => + VariableReference( + nameParts, + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), namePartsCaseAdjusted.last), + varDef + )} + } else { + None + } + ) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 664b68008080..ab9487aa6664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -19,9 +19,13 @@ package org.apache.spark.sql.catalyst.analysis import scala.jdk.CollectionConverters._ +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.util.ArrayImplicits._ @@ -35,10 +39,42 @@ class ResolveCatalogs(val catalogManager: CatalogManager) // We only support temp variables for now and the system catalog is not properly implemented // yet. We need to resolve `UnresolvedIdentifier` for variable commands specially. case c @ CreateVariable(UnresolvedIdentifier(nameParts, _), _, _) => - val resolved = resolveVariableName(nameParts) + // From scripts we can only create local variables, which must be unqualified, + // and must not be DECLARE OR REPLACE. + val resolved = if (withinSqlScript) { + // TODO [SPARK-50785]: Uncomment this when For Statement starts properly using local vars. 
+// if (c.replace) { +// throw new AnalysisException( +// "INVALID_VARIABLE_DECLARATION.REPLACE_LOCAL_VARIABLE", +// Map("varName" -> toSQLId(nameParts)) +// ) +// } + + if (nameParts.length != 1) { + throw new AnalysisException( + "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", + Map("varName" -> toSQLId(nameParts))) + } + + SqlScriptingLocalVariableManager.get() + .getOrElse(throw SparkException.internalError( + "Scripting local variable manager should be present in SQL script.")) + .qualify(nameParts.last) + } else { + val resolvedIdentifier = catalogManager.tempVariableManager.qualify(nameParts.last) + + assertValidSessionVariableNameParts(nameParts, resolvedIdentifier) + resolvedIdentifier + } + c.copy(name = resolved) case d @ DropVariable(UnresolvedIdentifier(nameParts, _), _) => - val resolved = resolveVariableName(nameParts) + if (withinSqlScript) { + throw new AnalysisException( + "UNSUPPORTED_FEATURE.SQL_SCRIPTING_DROP_TEMPORARY_VARIABLE", Map.empty) + } + val resolved = catalogManager.tempVariableManager.qualify(nameParts.last) + assertValidSessionVariableNameParts(nameParts, resolved) d.copy(name = resolved) case UnresolvedIdentifier(nameParts, allowTemp) => @@ -73,28 +109,34 @@ class ResolveCatalogs(val catalogManager: CatalogManager) } } - private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { - def ident: Identifier = Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), nameParts.last) - if (nameParts.length == 1) { - ResolvedIdentifier(FakeSystemCatalog, ident) - } else if (nameParts.length == 2) { - if (nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { - ResolvedIdentifier(FakeSystemCatalog, ident) - } else { - throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) - } - } else if (nameParts.length == 3) { - if (nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && - nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { - ResolvedIdentifier(FakeSystemCatalog, ident) - } else { - throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) - } - } else { + private def withinSqlScript: Boolean = + SqlScriptingLocalVariableManager.get().isDefined && !AnalysisContext.get.isExecuteImmediate + + private def assertValidSessionVariableNameParts( + nameParts: Seq[String], + resolvedIdentifier: ResolvedIdentifier): Unit = { + if (!validSessionVariableName(nameParts)) { throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) + nameParts, + Seq( + resolvedIdentifier.catalog.name(), + resolvedIdentifier.identifier.namespace().head) + ) + } + + def validSessionVariableName(nameParts: Seq[String]): Boolean = nameParts.length match { + case 1 => true + + // On declare variable, local variables support only unqualified names. + // On drop variable, local variables are not supported at all. + case 2 if nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => true + + // When there are 3 nameParts the variable must be a fully qualified session variable + // i.e. 
"system.session.<varName>" + case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => true + + case _ => false } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala index bd0204ba06fd..24b6b04de514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala @@ -53,11 +53,12 @@ class ResolveSetVariable(val catalogManager: CatalogManager) extends Rule[Logica // Names are normalized when the variables are created. // No need for case insensitive comparison here. // TODO: we need to group by the qualified variable name once other catalogs support it. - val dups = resolvedVars.groupBy(_.identifier.name).filter(kv => kv._2.length > 1) + val dups = resolvedVars.groupBy(_.identifier).filter(kv => kv._2.length > 1) if (dups.nonEmpty) { throw new AnalysisException( errorClass = "DUPLICATE_ASSIGNMENTS", - messageParameters = Map("nameList" -> dups.keys.map(toSQLId).mkString(", "))) + messageParameters = Map("nameList" -> + dups.keys.map(key => toSQLId(key.name())).mkString(", "))) } setVariable.copy(targetVariables = resolvedVars) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index a6a5abdc45a1..833f50a5203d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.Locale -import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, SqlScriptingLocalVariableManager} import org.apache.spark.sql.catalyst.analysis.{ FunctionRegistry, GetViewColumnByNameAndOrdinal, @@ -266,7 +266,9 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) == LegacyBehaviorPolicy.CORRECTED - private def checkVariables() = catalogManager.tempVariableManager.isEmpty + private def checkVariables() = + catalogManager.tempVariableManager.isEmpty && + SqlScriptingLocalVariableManager.get().forall(_.isEmpty) } object ResolverGuard { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index dee78b8f03af..6f657e931a49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -261,3 +261,11 @@ object FakeSystemCatalog extends CatalogPlugin { override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} override def name(): String = "system" } + +/** + * A fake v2 catalog to hold local variables for SQL scripting. 
+ */ +object FakeLocalCatalog extends CatalogPlugin { + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} + override def name(): String = "local" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala deleted file mode 100644 index 2c262da1f444..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.catalog - -import javax.annotation.concurrent.GuardedBy - -import scala.collection.mutable - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} -import org.apache.spark.sql.errors.DataTypeErrorsBase - -/** - * A thread-safe manager for temporary SQL variables (that live in the schema `SYSTEM.SESSION`), - * providing atomic operations to manage them, e.g. create, get, remove, etc. - * - * Note that, the variable name is always case-sensitive here, callers are responsible to format the - * variable name w.r.t. case-sensitive config. 
- */ -class TempVariableManager extends DataTypeErrorsBase { - - @GuardedBy("this") - private val variables = new mutable.HashMap[String, VariableDefinition] - - def create( - name: String, - defaultValueSQL: String, - initValue: Literal, - overrideIfExists: Boolean): Unit = synchronized { - if (!overrideIfExists && variables.contains(name)) { - throw new AnalysisException( - errorClass = "VARIABLE_ALREADY_EXISTS", - messageParameters = Map( - "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, SESSION_NAMESPACE, name)))) - } - variables.put(name, VariableDefinition(defaultValueSQL, initValue)) - } - - def get(name: String): Option[VariableDefinition] = synchronized { - variables.get(name) - } - - def remove(name: String): Boolean = synchronized { - variables.remove(name).isDefined - } - - def clear(): Unit = synchronized { - variables.clear() - } - - def isEmpty: Boolean = synchronized { - variables.isEmpty - } -} - -case class VariableDefinition(defaultValueSQL: String, currentValue: Literal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala new file mode 100644 index 000000000000..ae313f66c9f3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{FakeSystemCatalog, ResolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} +import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} +import org.apache.spark.sql.errors.DataTypeErrorsBase +import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError + +/** + * Trait which provides an interface for variable managers. Methods are case sensitive regarding + * the variable name/nameParts/identifier, callers are responsible to + * format them w.r.t. case-sensitive config. + */ +trait VariableManager { + /** + * Create a variable. + * @param nameParts NameParts of the variable. + * @param varDef The VariableDefinition of the variable. + * @param overrideIfExists If true, the new variable will replace an existing one + * with the same identifier, if it exists. + */ + def create(nameParts: Seq[String], varDef: VariableDefinition, overrideIfExists: Boolean): Unit + + /** + * Set an existing variable to a new value. 
+ * + * @param nameParts Name parts of the variable. + * @param varDef The new VariableDefinition of the variable. + */ + def set(nameParts: Seq[String], varDef: VariableDefinition): Unit + +/** + * Get an existing variable. + * + * @param nameParts Name parts of the variable. + */ + def get(nameParts: Seq[String]): Option[VariableDefinition] + + /** + * Delete an existing variable. + * + * @param nameParts Name parts of the variable. + */ + def remove(nameParts: Seq[String]): Boolean + + /** + * Create an identifier for the provided variable name. Could be context dependent. + * @param name Name for which an identifier is created. + */ + def qualify(name: String): ResolvedIdentifier + + /** + * Delete all variables. + */ + def clear(): Unit + + /** + * @return true if at least one variable exists, false otherwise. + */ + def isEmpty: Boolean +} + +/** + * @param identifier Identifier of the variable. + * @param defaultValueSQL SQL text of the variable's DEFAULT expression. + * @param currentValue Current value of the variable. + */ +case class VariableDefinition( + identifier: Identifier, + defaultValueSQL: String, + currentValue: Literal) + +/** + * A thread-safe manager for temporary SQL variables (that live in the schema `SYSTEM.SESSION`), + * providing atomic operations to manage them, e.g. create, get, remove, etc. + * + * Note that, the variable name is always case-sensitive here, callers are responsible to format the + * variable name w.r.t. case-sensitive config. + */ +class TempVariableManager extends VariableManager with DataTypeErrorsBase { + + @GuardedBy("this") + private val variables = new mutable.HashMap[String, VariableDefinition] + + override def create( + nameParts: Seq[String], + varDef: VariableDefinition, + overrideIfExists: Boolean): Unit = synchronized { + val name = nameParts.last + if (!overrideIfExists && variables.contains(name)) { + throw new AnalysisException( + errorClass = "VARIABLE_ALREADY_EXISTS", + messageParameters = Map( + "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, SESSION_NAMESPACE, name)))) + } + variables.put(name, varDef) + } + + override def set(nameParts: Seq[String], varDef: VariableDefinition): Unit = synchronized { + val name = nameParts.last + // Sanity check as this is already checked in ResolveSetVariable. 
+ if (!variables.contains(name)) { + throw unresolvedVariableError(nameParts, Seq("SYSTEM", "SESSION")) + } + variables.put(name, varDef) + } + + override def get(nameParts: Seq[String]): Option[VariableDefinition] = synchronized { + variables.get(nameParts.last) + } + + override def remove(nameParts: Seq[String]): Boolean = synchronized { + variables.remove(nameParts.last).isDefined + } + + override def qualify(name: String): ResolvedIdentifier = + ResolvedIdentifier( + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), name) + ) + + override def clear(): Unit = synchronized { + variables.clear() + } + + override def isEmpty: Boolean = synchronized { + variables.isEmpty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 235f2ae70c0d..4377f6b5bc0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.parser import java.util import java.util.Locale -import scala.collection.mutable.Set +import scala.collection.immutable +import scala.collection.mutable +import scala.util.matching.Regex import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval @@ -261,7 +263,7 @@ class SqlScriptingParsingContext { class SqlScriptingLabelContext { /** Set to keep track of labels seen so far */ - private val seenLabels = Set[String]() + private val seenLabels = mutable.Set[String]() /** * Check if the beginLabelCtx and endLabelCtx match. @@ -337,6 +339,11 @@ class SqlScriptingLabelContext { // Do not add the label to the seenLabels set if it is not defined. 
java.util.UUID.randomUUID.toString.toLowerCase(Locale.ROOT) } + if (SqlScriptingLabelContext.isForbiddenLabelName(labelText)) { + withOrigin(beginLabelCtx.get) { + throw SqlScriptingErrors.labelNameForbidden(CurrentOrigin.get, labelText) + } + } labelText } @@ -350,3 +357,12 @@ class SqlScriptingLabelContext { } } } + +object SqlScriptingLabelContext { + private val forbiddenLabelNames: immutable.Set[Regex] = + immutable.Set("builtin".r, "session".r, "sys.*".r) + + def isForbiddenLabelName(labelName: String): Boolean = { + forbiddenLabelNames.exists(_.matches(labelName.toLowerCase(Locale.ROOT))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala index 993efa1c8f6b..7e866d261485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala @@ -54,6 +54,15 @@ private[sql] object SqlScriptingErrors { messageParameters = Map("endLabel" -> toSQLId(endLabel))) } + def labelNameForbidden(origin: Origin, label: String): Throwable = { + new SqlScriptingException( + origin = origin, + errorClass = "LABEL_NAME_FORBIDDEN", + cause = null, + messageParameters = Map("label" -> toSQLId(label)) + ) + } + def variableDeclarationNotAllowedInScope( origin: Origin, varName: Seq[String]): Throwable = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 2ffe6de974c7..eab4ddc666be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkException import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, QueryPlanningTracker, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog, VariableDefinition} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.connector.catalog.InMemoryTable +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -1500,10 +1500,13 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("Execute Immediate plan transformation") { try { + val varDef1 = VariableDefinition(Identifier.of(Array("res"), "res"), "1", Literal(1)) SimpleAnalyzer.catalogManager.tempVariableManager.create( - "res", "1", Literal(1), overrideIfExists = true) + Seq("res", "res"), varDef1, overrideIfExists = true) + + val varDef2 = VariableDefinition(Identifier.of(Array("res2"), "res2"), "1", Literal(1)) 
SimpleAnalyzer.catalogManager.tempVariableManager.create( - "res2", "1", Literal(1), overrideIfExists = true) + Seq("res2", "res2"), varDef2, overrideIfExists = true) val actual1 = parsePlan("EXECUTE IMMEDIATE 'SELECT 42 WHERE ? = 1' USING 2").analyze val expected1 = parsePlan("SELECT 42 where 2 = 1").analyze comparePlans(actual1, expected1) @@ -1514,11 +1517,12 @@ class AnalysisSuite extends AnalysisTest with Matchers { // Test that plan is transformed to SET operation val actual3 = parsePlan( "EXECUTE IMMEDIATE 'SELECT 17, 7 WHERE ? = 1' INTO res, res2 USING 2").analyze + // Normalize to make the plan equivalent to the below set statement. val expected3 = parsePlan("SET var (res, res2) = (SELECT 17, 7 where 2 = 1)").analyze comparePlans(actual3, expected3) } finally { - SimpleAnalyzer.catalogManager.tempVariableManager.remove("res") - SimpleAnalyzer.catalogManager.tempVariableManager.remove("res2") + SimpleAnalyzer.catalogManager.tempVariableManager.remove(Seq("res")) + SimpleAnalyzer.catalogManager.tempVariableManager.remove(Seq("res2")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 58e6cd7fe169..f65523c844f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -23,7 +23,7 @@ import java.util.Locale import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{QueryPlanningTracker, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog, TemporaryViewRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog, TemporaryViewRelation, VariableDefinition} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.optimizer.InlineCTE @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{StringType, StructType} @@ -87,9 +88,15 @@ trait AnalysisTest extends PlanTest { overrideIfExists = true) new Analyzer(catalog) { catalogManager.tempVariableManager.create( - "testVarA", "1", Literal(1), overrideIfExists = true) + Seq("testA", "testVarA"), + VariableDefinition(Identifier.of(Array("testA"), "testVarA"), "1", Literal(1)), + overrideIfExists = true) + catalogManager.tempVariableManager.create( - "testVarNull", null, Literal(null, StringType), overrideIfExists = true) + Seq("testVarNull", "testVarNull"), + VariableDefinition( + Identifier.of(Array("testVarNull"), "testVarNull"), null, Literal(null, StringType)), + overrideIfExists = true) override val extendedResolutionRules = extendedAnalysisRules } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index c3d114836b67..9de5d09feb76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -269,6 +269,138 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(exception.origin.line.contains(3)) } + test("compound: forbidden label - system") { + val sqlScriptText = + """ + |BEGIN + | system: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("system"))) + assert(exception.origin.line.contains(3)) + } + + test("compound: forbidden label - starting with sys") { + val sqlScriptText = + """ + |BEGIN + | sysXYZ: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("sysxyz"))) + assert(exception.origin.line.contains(3)) + } + + test("compound: forbidden label - session") { + val sqlScriptText = + """ + |BEGIN + | session: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("session"))) + assert(exception.origin.line.contains(3)) + } + + test("compound: forbidden label - builtin") { + val sqlScriptText = + """ + |BEGIN + | builtin: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("builtin"))) + assert(exception.origin.line.contains(3)) + } + + test("compound: forbidden label - system - case insensitive") { + val sqlScriptText = + """ + |BEGIN + | SySTeM: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("system"))) + assert(exception.origin.line.contains(3)) + } + + test("compound: forbidden label - session - case insensitive") { + val sqlScriptText = + """ + |BEGIN + | SEsSiON: BEGIN + | SELECT 1; + | SELECT 2; + | INSERT INTO A VALUES (a, b, 3); + | SELECT a, b, c FROM T; + | SELECT * FROM T; + | END; + |END""".stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText) + } + checkError( + exception = exception, + condition = "LABEL_NAME_FORBIDDEN", + parameters = Map("label" -> toSQLId("session"))) + 
assert(exception.origin.line.contains(3)) + } + test("compound: endLabel") { val sqlScriptText = """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index bb2000eccc73..205abf365091 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -445,27 +445,29 @@ class SparkSession private( script: CompoundBody, args: Map[String, Expression] = Map.empty): DataFrame = { val sse = new SqlScriptingExecution(script, this, args) - var result: Option[Seq[Row]] = None - - // We must execute returned df before calling sse.getNextResult again because sse.hasNext - // advances the script execution and executes all statements until the next result. We must - // collect results immediately to maintain execution order. - // This ensures we respect the contract of SqlScriptingExecution API. - var df: Option[DataFrame] = sse.getNextResult - while (df.isDefined) { - sse.withErrorHandling { - // Collect results from the current DataFrame. - result = Some(df.get.collect().toSeq) + sse.withLocalVariableManager { + var result: Option[Seq[Row]] = None + + // We must execute returned df before calling sse.getNextResult again because sse.hasNext + // advances the script execution and executes all statements until the next result. We must + // collect results immediately to maintain execution order. + // This ensures we respect the contract of SqlScriptingExecution API. + var df: Option[DataFrame] = sse.getNextResult + while (df.isDefined) { + sse.withErrorHandling { + // Collect results from the current DataFrame. + result = Some(df.get.collect().toSeq) + } + df = sse.getNextResult } - df = sse.getNextResult - } - if (result.isEmpty) { - emptyDataFrame - } else { - val attributes = DataTypeUtils.toAttributes(result.get.head.schema) - Dataset.ofRows( - self, LocalRelation.fromExternalRows(attributes, result.get)) + if (result.isEmpty) { + emptyDataFrame + } else { + val attributes = DataTypeUtils.toAttributes(result.get.head.schema) + Dataset.ofRows( + self, LocalRelation.fromExternalRows(attributes, result.get)) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index 0ed1c104edb9..f7ad62c2e1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -19,29 +19,50 @@ package org.apache.spark.sql.execution.command.v2 import java.util.Locale -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, SqlScriptingLocalVariableManager} +import org.apache.spark.sql.catalyst.analysis.{FakeLocalCatalog, ResolvedIdentifier} +import org.apache.spark.sql.catalyst.catalog.VariableDefinition import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionsEvaluator, Literal} import org.apache.spark.sql.catalyst.plans.logical.DefaultValueExpression +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.execution.datasources.v2.LeafV2CommandExec /** * Physical plan node for creating a variable. 
*/ -case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, replace: Boolean) - extends LeafV2CommandExec with ExpressionsEvaluator { +case class CreateVariableExec( + resolvedIdentifier: ResolvedIdentifier, + defaultExpr: DefaultValueExpression, + replace: Boolean) extends LeafV2CommandExec with ExpressionsEvaluator { override protected def run(): Seq[InternalRow] = { - val variableManager = session.sessionState.catalogManager.tempVariableManager + val scriptingVariableManager = SqlScriptingLocalVariableManager.get() + val tempVariableManager = session.sessionState.catalogManager.tempVariableManager + val exprs = prepareExpressions(Seq(defaultExpr.child), subExprEliminationEnabled = false) initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) - val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { - name + + val normalizedIdentifier = if (session.sessionState.conf.caseSensitiveAnalysis) { + resolvedIdentifier.identifier } else { - name.toLowerCase(Locale.ROOT) + Identifier.of( + resolvedIdentifier.identifier.namespace().map(_.toLowerCase(Locale.ROOT)), + resolvedIdentifier.identifier.name().toLowerCase(Locale.ROOT)) } - variableManager.create( - normalizedName, defaultExpr.originalSQL, initValue, replace) + val varDef = VariableDefinition(normalizedIdentifier, defaultExpr.originalSQL, initValue) + + // create local variable if we are in a script, otherwise create session variable + scriptingVariableManager + .filter(_ => resolvedIdentifier.catalog == FakeLocalCatalog) + // If resolvedIdentifier.catalog is FakeLocalCatalog, scriptingVariableManager + // will always be present. + .getOrElse(tempVariableManager) + .create( + normalizedIdentifier.namespace().toSeq :+ normalizedIdentifier.name(), + varDef, + replace) + Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala index 22538076879f..b3062c7afe1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala @@ -39,7 +39,7 @@ case class DropVariableExec(name: String, ifExists: Boolean) extends LeafV2Comma } else { name.toLowerCase(Locale.ROOT) } - if (!variableManager.remove(normalizedName)) { + if (!variableManager.remove(Seq(normalizedName))) { // The variable does not exist if (!ifExists) { throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala index a5d90b4d154c..022b8c5869e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution.command.v2 +import java.util.Locale + import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.TempVariableManager +import org.apache.spark.sql.catalyst.{InternalRow, SqlScriptingLocalVariableManager} +import org.apache.spark.sql.catalyst.analysis.{FakeLocalCatalog, FakeSystemCatalog} +import org.apache.spark.sql.catalyst.catalog.VariableDefinition import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, 
VariableReference} import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.v2.V2CommandExec @@ -32,11 +36,10 @@ case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) extends V2CommandExec with UnaryLike[SparkPlan] { override protected def run(): Seq[InternalRow] = { - val variableManager = session.sessionState.catalogManager.tempVariableManager val values = query.executeCollect() if (values.length == 0) { variables.foreach { v => - createVariable(variableManager, v, null) + setVariable(v, null) } } else if (values.length > 1) { throw new SparkException( @@ -47,21 +50,47 @@ case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) val row = values(0) variables.zipWithIndex.foreach { case (v, index) => val value = row.get(index, v.dataType) - createVariable(variableManager, v, value) + setVariable(v, value) } } Seq.empty } - private def createVariable( - variableManager: TempVariableManager, + private def setVariable( variable: VariableReference, value: Any): Unit = { - variableManager.create( - variable.identifier.name, - variable.varDef.defaultValueSQL, - Literal(value, variable.dataType), - overrideIfExists = true) + val namePartsCaseAdjusted = if (session.sessionState.conf.caseSensitiveAnalysis) { + variable.originalNameParts + } else { + variable.originalNameParts.map(_.toLowerCase(Locale.ROOT)) + } + + val tempVariableManager = session.sessionState.catalogManager.tempVariableManager + val scriptingVariableManager = SqlScriptingLocalVariableManager.get() + + val variableManager = variable.catalog match { + case FakeLocalCatalog if scriptingVariableManager.isEmpty => + throw SparkException.internalError("SetVariableExec: Variable has FakeLocalCatalog, " + + "but ScriptingVariableManager is None.") + + case FakeLocalCatalog if scriptingVariableManager.get.get(namePartsCaseAdjusted).isEmpty => + throw SparkException.internalError("Local variable should be present in SetVariableExec" + + "because ResolveSetVariable has already determined it exists.") + + case FakeLocalCatalog => scriptingVariableManager.get + + case FakeSystemCatalog if tempVariableManager.get(namePartsCaseAdjusted).isEmpty => + throw unresolvedVariableError(namePartsCaseAdjusted, Seq("SYSTEM", "SESSION")) + + case FakeSystemCatalog => tempVariableManager + + case c => throw SparkException.internalError("Unexpected catalog in SetVariableExec: " + c) + } + + val varDef = VariableDefinition( + variable.identifier, variable.varDef.defaultValueSQL, Literal(value, variable.dataType)) + + variableManager.set(namePartsCaseAdjusted, varDef) } override def output: Seq[Attribute] = Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala index 704502d118b8..3e073202d4c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala @@ -28,7 +28,7 @@ object V2CommandStrategy extends Strategy { // TODO: move v2 commands to here which are not data source v2 related. 
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateVariable(ident: ResolvedIdentifier, defaultExpr, replace) => - CreateVariableExec(ident.identifier.name, defaultExpr, replace) :: Nil + CreateVariableExec(ident, defaultExpr, replace) :: Nil case DropVariable(ident: ResolvedIdentifier, ifExists) => DropVariableExec(ident.identifier.name, ifExists) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 5eccf1bcee2f..68a5a60079e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkThrowable +import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody} import org.apache.spark.sql.classic.{DataFrame, SparkSession} @@ -25,7 +26,8 @@ import org.apache.spark.sql.classic.{DataFrame, SparkSession} /** * SQL scripting executor - executes script and returns result statements. * This supports returning multiple result statements from a single script. - * The caller of the SqlScriptingExecution API must adhere to the contract of executing + * The caller of the SqlScriptingExecution API must wrap the interpretation and execution of + * statements with the [[withLocalVariableManager]] method, and adhere to the contract of executing * the returned statement before continuing iteration. Executing the statement needs to be done * inside withErrorHandling block. * @@ -55,6 +57,16 @@ class SqlScriptingExecution( ctx } + private val variableManager = new SqlScriptingLocalVariableManager(context) + private val variableManagerHandle = SqlScriptingLocalVariableManager.create(variableManager) + + /** + * Handles scripting context creation/access/deletion. Calls to execution API must be wrapped + * with this method. + */ + def withLocalVariableManager[R](f: => R): R = { + variableManagerHandle.runWith(f) + } /** Helper method to iterate get next statements from the first available frame. 
*/ private def getNextStatement: Option[CompoundStatementExec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala index 2c001a196a8f..2682656cb7ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.scripting import java.util.Locale +import scala.collection.mutable import scala.collection.mutable.ListBuffer import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.catalog.VariableDefinition import org.apache.spark.sql.scripting.SqlScriptingFrameType.SqlScriptingFrameType /** @@ -48,6 +50,9 @@ class SqlScriptingExecutionContext { frames.last.exitScope(label) } + def currentFrame: SqlScriptingExecutionFrame = frames.last + def currentScope: SqlScriptingExecutionScope = currentFrame.currentScope + def findHandler(condition: String, sqlState: String): Option[ExceptionHandlerExec] = { if (frames.isEmpty) { throw SparkException.internalError(s"Cannot find handler: no frames.") @@ -127,6 +132,8 @@ class SqlScriptingExecutionFrame( } } + def currentScope: SqlScriptingExecutionScope = scopes.last + // TODO: Introduce a separate class for different frame types (Script, Stored Procedure, // Error Handler) implementing SqlScriptingExecutionFrame interface. def findHandler( @@ -170,6 +177,7 @@ class SqlScriptingExecutionFrame( class SqlScriptingExecutionScope( val label: String, val triggerToExceptionHandlerMap: TriggerToExceptionHandlerMap) { + val variables = new mutable.HashMap[String, VariableDefinition] /** * Finds the most appropriate error handler for exception based on its condition and SQL state. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 97f68b9ca52d..ce0876e8f629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,9 +22,9 @@ import java.util import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.analysis.{ExecuteImmediateQuery, NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} @@ -38,7 +38,6 @@ sealed trait CompoundStatementExec extends Logging { /** * Whether the statement originates from the SQL script or is created during the interpretation. 
- * Example: DropVariable statements are automatically created at the end of each compound. */ val isInternal: Boolean = false @@ -115,8 +114,7 @@ trait NonLeafStatementExec extends CompoundStatementExec { * A map of parameter names to SQL literal expressions. * @param isInternal * Whether the statement originates from the SQL script or it is created during the - * interpretation. Example: DropVariable statements are automatically created at the end of each - * compound. + * interpretation. * @param context * SqlScriptingExecutionContext keeps the execution state of current script. */ @@ -988,7 +986,10 @@ class ForStatementExec( } private def createDropVarExec(varName: String): SingleStatementExec = { - val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true) + // As DROP TEMPORARY VARIABLE is forbidden within a script, use EXECUTE IMMEDIATE to bypass + // this limitation. This will be removed once FOR is updated to properly use local variables. + val dropVar = ExecuteImmediateQuery( + Seq.empty, Left("DROP TEMPORARY VARIABLE IF EXISTS " + varName), Seq.empty) new SingleStatementExec(dropVar, Origin(), Map.empty, isInternal = true, context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 2a446e4bbb25..919d63bb0c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.scripting import scala.collection.mutable.HashMap import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.errors.SqlScriptingErrors @@ -56,19 +55,6 @@ case class SqlScriptingInterpreter(session: SparkSession) { .asInstanceOf[CompoundBodyExec] } - /** - * Fetch the name of the Create Variable plan. - * @param plan - * Plan to fetch the name from. - * @return - * Name of the variable. - */ - private def getDeclareVarNameFromPlan(plan: LogicalPlan): Option[UnresolvedIdentifier] = - plan match { - case CreateVariable(name: UnresolvedIdentifier, _, _) => Some(name) - case _ => None - } - /** * Transform [[CompoundBody]] into [[CompoundBodyExec]]. * @@ -85,16 +71,6 @@ case class SqlScriptingInterpreter(session: SparkSession) { compoundBody: CompoundBody, args: Map[String, Expression], context: SqlScriptingExecutionContext): CompoundBodyExec = { - // Add drop variables to the end of the body. 
- val variables = compoundBody.collection.flatMap { - case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) - case _ => None - } - val dropVariables = variables - .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context)) - .reverse - // Map of conditions to their respective handlers. val conditionToExceptionHandlerMap: HashMap[String, ExceptionHandlerExec] = HashMap.empty // Map of SqlStates to their respective handlers. @@ -163,7 +139,7 @@ case class SqlScriptingInterpreter(session: SparkSession) { notFoundHandler = notFoundHandler) val statements = compoundBody.collection - .map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables match { + .map(st => transformTreeIntoExecutable(st, args, context)) match { case Nil => Seq(new NoOpStatementExec) case s => s } @@ -194,7 +170,6 @@ case class SqlScriptingInterpreter(session: SparkSession) { context: SqlScriptingExecutionContext): CompoundStatementExec = node match { case body: CompoundBody => - // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. transformBodyIntoExec(body, args, context) case IfElseStatement(conditions, conditionalBodies, elseBody) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala new file mode 100644 index 000000000000..f875f2154a92 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingLocalVariableManager.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{FakeLocalCatalog, ResolvedIdentifier} +import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.errors.DataTypeErrorsBase +import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError + +class SqlScriptingLocalVariableManager(context: SqlScriptingExecutionContext) + extends VariableManager with DataTypeErrorsBase { + + override def create( + nameParts: Seq[String], + varDef: VariableDefinition, + overrideIfExists: Boolean): Unit = { + val name = nameParts.last + + // overrideIfExists should not be supported because local variables don't support + // DECLARE OR REPLACE. However ForStatementExec currently uses this to handle local vars, + // so we support it for now. + // TODO [SPARK-50785]: Refactor ForStatementExec to use local variables properly. 
+ if (!overrideIfExists && context.currentScope.variables.contains(name)) { + throw new AnalysisException( + errorClass = "VARIABLE_ALREADY_EXISTS", + messageParameters = Map( + "variableName" -> toSQLId(Seq(context.currentScope.label, name)))) + } + context.currentScope.variables.put(name, varDef) + } + + override def set(nameParts: Seq[String], varDef: VariableDefinition): Unit = { + val scope = findScopeOfVariable(nameParts) + .getOrElse( + throw unresolvedVariableError(nameParts, varDef.identifier.namespace().toIndexedSeq)) + + if (!scope.variables.contains(nameParts.last)) { + throw unresolvedVariableError(nameParts, varDef.identifier.namespace().toIndexedSeq) + } + + scope.variables.put(nameParts.last, varDef) + } + + override def get(nameParts: Seq[String]): Option[VariableDefinition] = { + findScopeOfVariable(nameParts).flatMap(_.variables.get(nameParts.last)) + } + + private def findScopeOfVariable(nameParts: Seq[String]): Option[SqlScriptingExecutionScope] = { + // TODO: Update logic and comments once stored procedures are introduced. + def isScopeOfVar( + nameParts: Seq[String], + scope: SqlScriptingExecutionScope + ): Boolean = nameParts match { + case Seq(name) => scope.variables.contains(name) + // Qualified case. + case Seq(label, _) => scope.label == label + case _ => + throw SparkException.internalError("ScriptingVariableManager expects 1 or 2 nameParts.") + } + + // First search for variable in entire current frame. + val resCurrentFrame = context.currentFrame.scopes + .findLast(scope => isScopeOfVar(nameParts, scope)) + if (resCurrentFrame.isDefined) { + return resCurrentFrame + } + + // When searching in previous frames, for each frame we have to check only scopes before and + // including the scope where the previously checked exception handler frame is defined. + // Exception handler frames should not have access to variables from scopes + // which are nested below the scope where the handler is defined. + var previousFrameDefinitionLabel = context.currentFrame.scopeLabel + + // dropRight(1) removes the current frame, which we already checked above. + context.frames.dropRight(1).reverseIterator.foreach(frame => { + // Drop scopes until we encounter the scope in which the previously checked + // frame was defined. If it was not defined in this scope candidateScopes will be + // empty. + val candidateScopes = frame.scopes.reverse.dropWhile( + scope => !previousFrameDefinitionLabel.contains(scope.label)) + + val scope = candidateScopes.findLast(scope => isScopeOfVar(nameParts, scope)) + if (scope.isDefined) { + return scope + } + // If candidateScopes is nonEmpty that means that we found the previous frame definition + // in this frame. If we still have not found the variable, we now have to find the definition + // of this new frame, so we reassign the frame definition label to search for. + if (candidateScopes.nonEmpty) { + previousFrameDefinitionLabel = frame.scopeLabel + } + }) + None + } + + override def qualify(name: String): ResolvedIdentifier = + ResolvedIdentifier(FakeLocalCatalog, Identifier.of(Array(context.currentScope.label), name)) + + override def remove(nameParts: Seq[String]): Boolean = { + throw SparkException.internalError( + "ScriptingVariableManager.remove should never be called as local variables cannot be dropped." + ) + } + + override def clear(): Unit = context.frames.clear() + + // Empty if all scopes of all frames in the script context contain no variables. 
+ override def isEmpty: Boolean = context.frames.forall(_.scopes.forall(_.variables.isEmpty)) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala index 8cd071ad126b..b12114041d0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala @@ -134,7 +134,6 @@ class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession { |BEGIN | DECLARE x INT; | SET x = 1; - | DROP TEMPORARY VARIABLE x; |END |""".stripMargin verifySqlScriptResult(sqlScript, Seq.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 1393aff69c43..f83ae87290ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.ExecuteImmediateQuery import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin @@ -159,6 +160,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case forStmt: TestForStatement => forStmt.label.get case dropStmt: SingleStatementExec if dropStmt.parsedPlan.isInstanceOf[DropVariable] => "DropVariable" + case execImm: SingleStatementExec if execImm.parsedPlan.isInstanceOf[ExecuteImmediateQuery] + => "ExecuteImmediate" case _ => fail("Unexpected statement type") } @@ -759,8 +762,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", - "DropVariable", // drop for query var intCol - "DropVariable" // drop for loop var x + "ExecuteImmediate", // drop for query var intCol + "ExecuteImmediate" // drop for loop var x )) } @@ -782,8 +785,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement2", "statement1", "statement2", - "DropVariable", // drop for query var intCol - "DropVariable" // drop for loop var x + "ExecuteImmediate", // drop for query var intCol + "ExecuteImmediate" // drop for loop var x )) } @@ -824,14 +827,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi assert(statements === Seq( "body", "body", - "DropVariable", // drop for query var intCol1 - "DropVariable", // drop for loop var y + "ExecuteImmediate", // drop for query var intCol1 + "ExecuteImmediate", // drop for loop var y "body", "body", - "DropVariable", // drop for query var intCol1 - "DropVariable", // drop for loop var y - "DropVariable", // drop for query var intCol - "DropVariable" // drop for loop var x + "ExecuteImmediate", // drop for query var intCol1 + "ExecuteImmediate", // drop for loop var y + "ExecuteImmediate", // drop for query var intCol + "ExecuteImmediate" // drop for loop var x )) } @@ -848,7 +851,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = 
iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", - "DropVariable" // drop for query var intCol + "ExecuteImmediate" // drop for query var intCol )) } @@ -867,7 +870,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "statement2", "statement1", "statement2", - "DropVariable" // drop for query var intCol + "ExecuteImmediate" // drop for query var intCol )) } @@ -906,10 +909,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", "body", - "DropVariable", // drop for query var intCol1 + "ExecuteImmediate", // drop for query var intCol1 "body", "body", - "DropVariable", // drop for query var intCol1 - "DropVariable" // drop for query var intCol + "ExecuteImmediate", // drop for query var intCol1 + "ExecuteImmediate" // drop for query var intCol )) } @@ -932,8 +935,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "lbl1", "statement1", "lbl1", - "DropVariable", // drop for query var intCol - "DropVariable" // drop for loop var x + "ExecuteImmediate", // drop for query var intCol + "ExecuteImmediate" // drop for loop var x )) } @@ -984,8 +987,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "outer_body", "body1", "lbl1", - "DropVariable", // drop for query var intCol - "DropVariable" // drop for loop var x + "ExecuteImmediate", // drop for query var intCol + "ExecuteImmediate" // drop for loop var x )) } @@ -1030,7 +1033,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "lbl1", "statement1", "lbl1", - "DropVariable" // drop for query var intCol + "ExecuteImmediate" // drop for query var intCol )) } @@ -1076,7 +1079,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1", - "DropVariable" // drop for query var intCol + "ExecuteImmediate" // drop for query var intCol )) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 219094e8d217..3af4a539d2a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -48,18 +48,21 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { sqlText: String, args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = { val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] - val sse = new SqlScriptingExecution(compoundBody, spark, args) - val result: ListBuffer[Array[Row]] = ListBuffer.empty - var df = sse.getNextResult - while (df.isDefined) { - // Collect results from the current DataFrame. 
- sse.withErrorHandling { - result.append(df.get.collect()) + val sse = new SqlScriptingExecution(compoundBody, spark, args) + sse.withLocalVariableManager { + val result: ListBuffer[Array[Row]] = ListBuffer.empty + + var df = sse.getNextResult + while (df.isDefined) { + // Collect results from the current DataFrame. + sse.withErrorHandling { + result.append(df.get.collect()) + } + df = sse.getNextResult } - df = sse.getNextResult + result.toSeq } - result.toSeq } private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { @@ -715,20 +718,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } - test("session vars - drop var statement") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |DROP TEMPORARY VARIABLE var; - |END - |""".stripMargin - val expected = Seq(Seq(Row(2))) - verifySqlScriptResult(sqlScript, expected) - } - test("if") { val commands = """ @@ -1623,4 +1612,891 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("local variable") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - nested compounds") { + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl2: BEGIN + | DECLARE localVar = 2; + | SELECT localVar; + | SELECT lbl1.localVar; + | SELECT lbl2.localVar; + | END; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)), // select lbl1.localVar + Seq(Row(2)) // select lbl2.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - resolved over session variable") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - resolved over session variable nested") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | SELECT localVar; + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - session variable resolved over local if fully qualified") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT system.session.localVar; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select system.session.localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - session variable resolved over local if qualified with session.varname") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE 
localVar = 5; + | SELECT session.localVar; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select system.session.localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - case insensitive name") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LOCALVAR; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select LOCALVAR + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - case sensitive name") { + val e = intercept[AnalysisException] { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LOCALVAR; + | END; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("LOCALVAR")), + context = ExpectedContext( + fragment = "LOCALVAR", + start = 57, + stop = 64) + ) + } + + test("local variable - case insensitive label") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LBL.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select LBL.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - case sensitive label") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LBL.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("LBL.localVar")), + context = ExpectedContext( + fragment = "LBL.localVar", + start = 57, + stop = 68) + ) + } + + test("local variable - leaves scope unqualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("localVar")), + context = ExpectedContext(fragment = "localVar", start = 83, stop = 90) + ) + } + + test("local variable - leaves scope qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | END; + | SELECT lbl.localVar; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("lbl.localVar")), + context = ExpectedContext(fragment = "lbl.localVar", start = 87, stop = 98) + ) + } + + test("local variable - leaves inner scope") { + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | lbl: BEGIN + | DECLARE localVar = 2; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)) // select localVar + ) + 
verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - inner inner scope -> inner scope -> session var") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl: BEGIN + | DECLARE localVar = 2; + | SELECT localVar; + | END; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)), // select localVar + Seq(Row(0)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - declare - qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE lbl.localVar = 1; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", + sqlState = "42K0M", + parameters = Map("varName" -> toSQLId("lbl.localVar")) + ) + } + + // TODO [SPARK-50785]: Unignore this when For Statement starts properly using local vars. + ignore("local variable - declare - declare or replace") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE OR REPLACE localVar = 1; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "INVALID_VARIABLE_DECLARATION.REPLACE_LOCAL_VARIABLE", + sqlState = "42K0M", + parameters = Map("varName" -> toSQLId("localVar")) + ) + } + + test("local variable - declare - duplicate names") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | DECLARE localVar = 2; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "VARIABLE_ALREADY_EXISTS", + sqlState = "42723", + parameters = Map("variableName" -> toSQLId("lbl.localvar")) + ) + } + + // Variables cannot be dropped within SQL scripts. 
+ test("local variable - drop") { + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | DROP TEMPORARY VARIABLE localVar; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_DROP_TEMPORARY_VARIABLE", + parameters = Map.empty + ) + } + + test("drop variable - drop - too many nameparts") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | DROP TEMPORARY VARIABLE a.b.c.d; + | END; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_DROP_TEMPORARY_VARIABLE", + parameters = Map.empty + ) + } + + test("local variable - drop session variable without EXECUTE IMMEDIATE") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | SELECT system.session.localVar; + | DROP TEMPORARY VARIABLE system.session.localVar; + | SELECT system.session.localVar; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_DROP_TEMPORARY_VARIABLE", + parameters = Map.empty + ) + } + } + + test("local variable - drop session variable with EXECUTE IMMEDIATE") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | SELECT system.session.localVar; + | EXECUTE IMMEDIATE 'DROP TEMPORARY VARIABLE localVar'; + | SELECT system.session.localVar; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("system.session.localVar")), + context = ExpectedContext( + fragment = "system.session.localVar", + start = 130, + stop = 152) + ) + } + } + + test("local variable - set - qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET lbl.localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - unqualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | SET localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - set unqualified select qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - nested") { + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl2: BEGIN + | 
DECLARE localVar = 2; + | SELECT localVar; + | SELECT lbl1.localVar; + | SELECT lbl2.localVar; + | SET lbl1.localVar = 5; + | SELECT lbl1.localVar; + | END; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)), // select lbl1.localVar + Seq(Row(2)), // select lbl2.localVar + Seq(Row(5)) // select lbl1.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - case insensitive name") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SET LOCALVAR = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - case sensitive name") { + val e = intercept[AnalysisException] { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SET LOCALVAR = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + } + + checkError( + exception = e, + condition = "UNRESOLVED_VARIABLE", + sqlState = "42883", + parameters = Map( + "variableName" -> toSQLId("LOCALVAR"), + "searchPath" -> toSQLId("SYSTEM.SESSION")) + ) + } + + test("local variable - set - session variable") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0").collect() + + val sqlScript = + """ + |BEGIN + | SELECT localVar; + | SET localVar = 1; + | SELECT localVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select localVar + Seq(Row(1)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - set - duplicate assignment") { + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl2: BEGIN + | SELECT localVar; + | SET (localVar, lbl1.localVar) = (select 1, 2); + | END; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "DUPLICATE_ASSIGNMENTS", + sqlState = "42701", + parameters = Map("nameList" -> toSQLId("localvar")) + ) + } + + test("local variable - set - no duplicate assignment error with session var") { + withSessionVariable("localVar") { + spark.sql("DECLARE localVar = 0") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SET (localVar, session.localVar) = (select 2, 3); + | SELECT localVar; + | SELECT session.localVar; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(3)) // select session.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - set - duplicate assignment of session vars") { + withSessionVariable("sessionVar") { + spark.sql("DECLARE sessionVar = 0") + + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | SET (sessionVar, session.sessionVar) = (select 1, 2); + | END; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "DUPLICATE_ASSIGNMENTS", + sqlState = "42701", + parameters = Map("nameList" -> toSQLId("sessionvar")) + ) + } + } + + test("local variable - execute immediate using local var") { + withSessionVariable("testVar") { + spark.sql("DECLARE testVar = 0") + val sqlScript = + """ + |BEGIN + | DECLARE testVar = 5; + | EXECUTE 
IMMEDIATE 'SELECT ?' USING testVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)) // select testVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - execute immediate into local var") { + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | EXECUTE IMMEDIATE 'SELECT 5' INTO localVar; + | SELECT localVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - execute immediate can't access local var") { + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 5; + | EXECUTE IMMEDIATE 'SELECT localVar'; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("localVar")), + context = ExpectedContext( + fragment = "localVar", + start = 7, + stop = 14) + ) + } + + test("local variable - execute immediate create session var") { + withSessionVariable("sessionVar") { + val sqlScript = + """ + |BEGIN + | EXECUTE IMMEDIATE 'DECLARE sessionVar = 5'; + | SELECT system.session.sessionVar; + | SELECT sessionVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select system.session.sessionVar + Seq(Row(5)) // select sessionVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - execute immediate create qualified session var") { + withSessionVariable("sessionVar") { + val sqlScript = + """ + |BEGIN + | EXECUTE IMMEDIATE 'DECLARE system.session.sessionVar = 5'; + | SELECT system.session.sessionVar; + | SELECT sessionVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), // select system.session.sessionVar + Seq(Row(5)) // select sessionVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - execute immediate set session var") { + withSessionVariable("testVar") { + spark.sql("DECLARE testVar = 0") + val sqlScript = + """ + |BEGIN + | DECLARE testVar = 1; + | SELECT system.session.testVar; + | SELECT testVar; + | EXECUTE IMMEDIATE 'SET VARIABLE testVar = 5'; + | SELECT system.session.testVar; + | SELECT testVar; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(0)), // select system.session.testVar + Seq(Row(1)), // select testVar + Seq(Row(5)), // select system.session.testVar + Seq(Row(1)) // select testVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - handlers - triple chained handlers") { + val sqlScript = + """ + |BEGIN + | DECLARE OR REPLACE VARIABLE varOuter INT = 0; + | l1: BEGIN + | DECLARE OR REPLACE VARIABLE varL1 INT = 1; + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT varOuter; + | SELECT varL1; + | END; + | l2: BEGIN + | DECLARE OR REPLACE VARIABLE varL2 = 2; + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT varOuter; + | SELECT varL1; + | SELECT varL2; + | SELECT 1/0; + | END; + | l3: BEGIN + | DECLARE OR REPLACE VARIABLE varL3 = 3; + | DECLARE EXIT HANDLER FOR SQLEXCEPTION + | BEGIN + | SELECT varOuter; + | SELECT varL1; + | SELECT varL2; + | SELECT varL3; + | SELECT 1/0; + | END; + + | SELECT 5; + | SELECT 1/0; + | SELECT 6; + | END; + | END; + | END; + |END + |""".stripMargin + val expected = Seq( + Seq(Row(5)), + Seq(Row(0)), + Seq(Row(1)), + Seq(Row(2)), + Seq(Row(3)), + Seq(Row(0)), + 
Seq(Row(1)), + Seq(Row(2)), + Seq(Row(0)), + Seq(Row(1)) + ) + verifySqlScriptResult(sqlScript, expected = expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index a52f93d4fc80..30efac0737dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.{SparkConf, SparkException, SparkNumberFormatException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.{QueryPlanningTracker, SqlScriptingLocalVariableManager} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.classic.{DataFrame, Dataset} @@ -54,15 +54,19 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { executionPlan, SqlScriptingFrameType.SQL_SCRIPT)) executionPlan.enterScope() - executionPlan.getTreeIterator.flatMap { - case statement: SingleStatementExec => - if (statement.isExecuted) { - None - } else { - Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) - } - case _ => None - }.toArray + val handle = + SqlScriptingLocalVariableManager.create(new SqlScriptingLocalVariableManager(context)) + handle.runWith { + executionPlan.getTreeIterator.flatMap { + case statement: SingleStatementExec => + if (statement.isExecuted) { + None + } else { + Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } + case _ => None + }.toArray + } } private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { @@ -185,8 +189,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -203,8 +206,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -231,14 +233,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq(Row(1)), // select - Seq.empty[Row], // drop var Seq.empty[Row], // declare var Seq(Row(2)), // select - Seq.empty[Row], // drop var Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(4)), // select - Seq.empty[Row] // drop var + Seq(Row(4)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -270,26 +269,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) } - test("session vars - drop var statement") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |DROP TEMPORARY VARIABLE var; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - explicit - Seq.empty[Row] // drop var - implicit - ) - 
verifySqlScriptResult(sqlScript, expected) - } - test("if") { val commands = """ @@ -1017,8 +996,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select i Seq.empty[Row], // set i Seq(Row(2)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop var + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1036,8 +1014,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i - Seq.empty[Row] // drop i + Seq.empty[Row] // declare i ) verifySqlScriptResult(commands, expected) } @@ -1073,9 +1050,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase j Seq(Row(1, 1)), // select i, j Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // drop j - Seq.empty[Row] // drop i + Seq.empty[Row] // increase i ) verifySqlScriptResult(commands, expected) } @@ -1125,8 +1100,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select i Seq.empty[Row], // set i Seq(Row(2)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop var + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1148,8 +1122,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare i Seq(Row(3)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop i + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1223,9 +1196,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase j Seq(Row(1, 1)), // select i, j Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // drop j - Seq.empty[Row] // drop i + Seq.empty[Row] // increase i ) verifySqlScriptResult(commands, expected) } @@ -1411,8 +1382,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select - Seq.empty[Row] // drop + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1436,8 +1406,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1534,8 +1503,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1568,8 +1536,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 Seq(Row(2)), // select 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1598,8 +1565,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1629,8 +1595,7 @@ 
class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(2)), // select x Seq.empty[Row], // set x = 3 Seq(Row(3)), // select x - Seq(Row(3)), // select x - Seq.empty[Row] // drop + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1672,9 +1637,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase y Seq(Row(1, 1)), // select x, y Seq.empty[Row], // increase y - Seq.empty[Row], // increase x - Seq.empty[Row], // drop y - Seq.empty[Row] // drop x + Seq.empty[Row] // increase x ) verifySqlScriptResult(commands, expected) } @@ -1700,8 +1663,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1750,8 +1712,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 3 - Seq(Row(3)), // select x - Seq.empty[Row] // drop + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1851,8 +1812,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set sumOfCols Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq(Row(10)), // select sumOfCols - Seq.empty[Row] // drop sumOfCols + Seq(Row(10)) // select sumOfCols ) verifySqlScriptResult(sqlScript, expected) } @@ -2174,8 +2134,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select intCol Seq.empty[Row], // insert Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2263,8 +2222,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select intCol Seq.empty[Row], // insert Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2461,8 +2419,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set sumOfCols Seq.empty[Row], // set sumOfCols Seq.empty[Row], // drop local var - Seq(Row(10)), // select sumOfCols - Seq.empty[Row] // drop sumOfCols + Seq(Row(10)) // select sumOfCols ) verifySqlScriptResult(sqlScript, expected) } @@ -2693,8 +2650,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(0)), // select intCol Seq(Row(1)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2767,8 +2723,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(0)), // select intCol Seq(Row(1)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 3ceffc74adc2..f0f3f94b811f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -423,6 +423,17 @@ private[sql] trait SQLTestUtilsBase } } + /** + * Drops temporary variables `variableNames` after calling `f`. + */ + protected def withSessionVariable(variableNames: String*)(f: => Unit): Unit = { + Utils.tryWithSafeFinally(f) { + variableNames.foreach { name => + spark.sql(s"DROP TEMPORARY VARIABLE IF EXISTS $name") + } + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns.
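For reference, the caller contract described in the SqlScriptingExecution scaladoc boils down to the pattern below. This is a minimal sketch adapted from the updated SqlScriptingExecutionSuite helper; `spark` is assumed to be a (classic) SparkSession and `sqlText` a script that parses to a CompoundBody.

import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody

// All interpretation and execution of the script's statements runs inside
// withLocalVariableManager, so local variables resolve against the script's scopes.
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
val sse = new SqlScriptingExecution(compoundBody, spark, Map.empty)
val results = sse.withLocalVariableManager {
  val collected = ListBuffer.empty[Array[Row]]
  var df = sse.getNextResult
  while (df.isDefined) {
    // Per the contract, each returned result statement is executed inside
    // withErrorHandling before the next result is requested.
    sse.withErrorHandling {
      collected.append(df.get.collect())
    }
    df = sse.getNextResult
  }
  collected.toSeq
}

The scoping behavior itself (innermost-scope resolution, label qualification, and precedence over session variables) is exercised end to end by the new tests in SqlScriptingExecutionSuite above.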