This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 031ce8124591 [SPARK-52221][SQL] Refactor
SqlScriptingLocalVariableManager into more generic context manager
031ce8124591 is described below
commit 031ce81245916b8dd76731ff0ee09669360a194b
Author: David Milicevic <[email protected]>
AuthorDate: Tue May 20 13:15:45 2025 +0800
[SPARK-52221][SQL] Refactor SqlScriptingLocalVariableManager into more
generic context manager
### What changes were proposed in this pull request?
Replacing `SqlScriptingLocalVariableManager` Thread Local variable with
more generic `SqlScriptingContextManager`.
Newly introduced `SqlScriptingContextManager` encapsulates local variable
manager, but also includes `SqlScriptingExecutionContext` information. This
information will be required in the future, for example, to implement Stored
Procedure calls - since access to call-stack-related operations will be
needed.
### Why are the changes needed?
These are refactor changes. The refactor is needed to better support future
implementations, as explained in JIRA ticket.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
This is a refactor PR. Already existing test cases ensure that nothing is
broken.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50938 from davidm-db/scripting_context_manager.
Authored-by: David Milicevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
...Manager.scala => SqlScriptingContextManager.scala} | 7 ++++---
.../catalyst/analysis/ColumnResolutionHelper.scala | 4 ++--
.../spark/sql/catalyst/analysis/ResolveCatalogs.scala | 7 ++++---
.../catalyst/analysis/resolver/ResolverGuard.scala | 4 ++--
.../SqlScriptingContextManager.scala} | 18 +++++++++++++-----
.../SqlScriptingExecutionContextExtension.scala} | 12 +++++-------
.../sql/execution/command/v2/CreateVariableExec.scala | 4 ++--
.../sql/execution/command/v2/SetVariableExec.scala | 4 ++--
.../scripting/SqlScriptingContextManagerImpl.scala} | 19 ++++++++++++++-----
.../spark/sql/scripting/SqlScriptingExecution.scala | 14 +++++++-------
.../sql/scripting/SqlScriptingExecutionContext.scala | 4 ++--
.../sql/scripting/SqlScriptingExecutionSuite.scala | 2 +-
.../sql/scripting/SqlScriptingInterpreterSuite.scala | 4 ++--
13 files changed, 60 insertions(+), 43 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingContextManager.scala
similarity index 76%
copy from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
copy to
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingContextManager.scala
index a3eef28d372a..ffcb81c32340 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingContextManager.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst
-import org.apache.spark.sql.catalyst.catalog.VariableManager
+import org.apache.spark.sql.catalyst.catalog.SqlScriptingContextManager
import org.apache.spark.util.LexicalThreadLocal
-object SqlScriptingLocalVariableManager extends
LexicalThreadLocal[VariableManager] {
- def create(variableManager: VariableManager): Handle =
createHandle(Option(variableManager))
+object SqlScriptingContextManager extends
LexicalThreadLocal[SqlScriptingContextManager] {
+ def create(contextManager: SqlScriptingContextManager): Handle =
+ createHandle(Option(contextManager))
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index b2e068fd990b..6823cdbf36ba 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
+import org.apache.spark.sql.catalyst.SqlScriptingContextManager
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference
import
org.apache.spark.sql.catalyst.parser.SqlScriptingLabelContext.isForbiddenLabelName
@@ -260,7 +260,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
nameParts.map(_.toLowerCase(Locale.ROOT))
}
- SqlScriptingLocalVariableManager.get()
+ SqlScriptingContextManager.get().map(_.getVariableManager)
// If we are in EXECUTE IMMEDIATE lookup only session variables.
.filterNot(_ => AnalysisContext.get.isExecuteImmediate)
// If variable name is qualified with session.<varName> treat it as a
session variable.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index 41da66c9ceaa..72d92e5a9445 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
+import org.apache.spark.sql.catalyst.SqlScriptingContextManager
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -57,7 +57,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
Map("varName" -> toSQLId(nameParts)))
}
- SqlScriptingLocalVariableManager.get()
+ SqlScriptingContextManager.get().map(_.getVariableManager)
.getOrElse(throw SparkException.internalError(
"Scripting local variable manager should be present in SQL
script."))
.qualify(nameParts.last)
@@ -137,7 +137,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
}
private def withinSqlScript: Boolean =
- SqlScriptingLocalVariableManager.get().isDefined &&
!AnalysisContext.get.isExecuteImmediate
+ SqlScriptingContextManager.get().map(_.getVariableManager).isDefined &&
+ !AnalysisContext.get.isExecuteImmediate
private def assertValidSessionVariableNameParts(
nameParts: Seq[String],
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
index dc8bcaf5115e..d1a42ee1b080 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.{
FunctionIdentifier,
SQLConfHelper,
- SqlScriptingLocalVariableManager
+ SqlScriptingContextManager
}
import org.apache.spark.sql.catalyst.analysis.{
FunctionRegistry,
@@ -457,7 +457,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends
SQLConfHelper {
catalogManager.tempVariableManager.isEmpty
private def checkScriptingVariables() =
- SqlScriptingLocalVariableManager.get().forall(_.isEmpty)
+
SqlScriptingContextManager.get().map(_.getVariableManager).forall(_.isEmpty)
private def tryThrowUnsupportedSinglePassAnalyzerFeature(operator:
LogicalPlan): Unit = {
tryThrowUnsupportedSinglePassAnalyzerFeature(s"${operator.getClass}
operator resolution")
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingContextManager.scala
similarity index 70%
copy from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
copy to
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingContextManager.scala
index a3eef28d372a..dcf38558afb7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingContextManager.scala
@@ -15,11 +15,19 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.spark.sql.catalyst.catalog
-import org.apache.spark.sql.catalyst.catalog.VariableManager
-import org.apache.spark.util.LexicalThreadLocal
+/**
+ * Trait which provides an interface for SQL scripting context manager.
+ */
+trait SqlScriptingContextManager {
+ /**
+ * Get execution context
+ */
+ def getContext: SqlScriptingExecutionContextExtension
-object SqlScriptingLocalVariableManager extends
LexicalThreadLocal[VariableManager] {
- def create(variableManager: VariableManager): Handle =
createHandle(Option(variableManager))
+ /**
+ * Get variable manager
+ */
+ def getVariableManager: VariableManager
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingExecutionContextExtension.scala
similarity index 70%
copy from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
copy to
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingExecutionContextExtension.scala
index a3eef28d372a..9a896d5e1c49 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SqlScriptingExecutionContextExtension.scala
@@ -15,11 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.spark.sql.catalyst.catalog
-import org.apache.spark.sql.catalyst.catalog.VariableManager
-import org.apache.spark.util.LexicalThreadLocal
-
-object SqlScriptingLocalVariableManager extends
LexicalThreadLocal[VariableManager] {
- def create(variableManager: VariableManager): Handle =
createHandle(Option(variableManager))
-}
+/**
+ * Trait which provides an interface extension for SQL scripting execution
context.
+ */
+trait SqlScriptingExecutionContextExtension {}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala
index f7ad62c2e1e3..1b9c1711853c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2
import java.util.Locale
-import org.apache.spark.sql.catalyst.{InternalRow,
SqlScriptingLocalVariableManager}
+import org.apache.spark.sql.catalyst.{InternalRow, SqlScriptingContextManager}
import org.apache.spark.sql.catalyst.analysis.{FakeLocalCatalog,
ResolvedIdentifier}
import org.apache.spark.sql.catalyst.catalog.VariableDefinition
import org.apache.spark.sql.catalyst.expressions.{Attribute,
ExpressionsEvaluator, Literal}
@@ -36,7 +36,7 @@ case class CreateVariableExec(
replace: Boolean) extends LeafV2CommandExec with ExpressionsEvaluator {
override protected def run(): Seq[InternalRow] = {
- val scriptingVariableManager = SqlScriptingLocalVariableManager.get()
+ val scriptingVariableManager =
SqlScriptingContextManager.get().map(_.getVariableManager)
val tempVariableManager =
session.sessionState.catalogManager.tempVariableManager
val exprs = prepareExpressions(Seq(defaultExpr.child),
subExprEliminationEnabled = false)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala
index 022b8c5869e7..a5b64736a618 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command.v2
import java.util.Locale
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.{InternalRow,
SqlScriptingLocalVariableManager}
+import org.apache.spark.sql.catalyst.{InternalRow, SqlScriptingContextManager}
import org.apache.spark.sql.catalyst.analysis.{FakeLocalCatalog,
FakeSystemCatalog}
import org.apache.spark.sql.catalyst.catalog.VariableDefinition
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal,
VariableReference}
@@ -66,7 +66,7 @@ case class SetVariableExec(variables: Seq[VariableReference],
query: SparkPlan)
}
val tempVariableManager =
session.sessionState.catalogManager.tempVariableManager
- val scriptingVariableManager = SqlScriptingLocalVariableManager.get()
+ val scriptingVariableManager =
SqlScriptingContextManager.get().map(_.getVariableManager)
val variableManager = variable.catalog match {
case FakeLocalCatalog if scriptingVariableManager.isEmpty =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingContextManagerImpl.scala
similarity index 60%
rename from
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
rename to
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingContextManagerImpl.scala
index a3eef28d372a..884269d85f99 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlScriptingLocalVariableManager.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingContextManagerImpl.scala
@@ -15,11 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.spark.sql.scripting
-import org.apache.spark.sql.catalyst.catalog.VariableManager
-import org.apache.spark.util.LexicalThreadLocal
+import org.apache.spark.sql.catalyst.catalog.{
+ SqlScriptingContextManager,
+ SqlScriptingExecutionContextExtension,
+ VariableManager
+}
+
+class SqlScriptingContextManagerImpl(context: SqlScriptingExecutionContext)
+ extends SqlScriptingContextManager {
+
+ private val variableManager = new SqlScriptingLocalVariableManager(context)
+
+ override def getContext: SqlScriptingExecutionContextExtension = context
-object SqlScriptingLocalVariableManager extends
LexicalThreadLocal[VariableManager] {
- def create(variableManager: VariableManager): Handle =
createHandle(Option(variableManager))
+ override def getVariableManager: VariableManager = variableManager
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index ee72e6c358bb..362f3e51d3df 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.SqlScriptingLocalVariableManager
+import org.apache.spark.sql.catalyst.SqlScriptingContextManager
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult,
CompoundBody, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.StructType
* SQL scripting executor - executes script and returns result statements.
* This supports returning multiple result statements from a single script.
* The caller of the SqlScriptingExecution API must wrap the interpretation
and execution of
- * statements with the [[withLocalVariableManager]] method, and adhere to the
contract of executing
+ * statements with the [[withContextManager]] method, and adhere to the
contract of executing
* the returned statement before continuing iteration. Executing the statement
needs to be done
* inside withErrorHandling block.
*
@@ -60,15 +60,15 @@ class SqlScriptingExecution(
ctx
}
- private val variableManager = new SqlScriptingLocalVariableManager(context)
- private val variableManagerHandle =
SqlScriptingLocalVariableManager.create(variableManager)
+ private val contextManager = new SqlScriptingContextManagerImpl(context)
+ private val contextManagerHandle =
SqlScriptingContextManager.create(contextManager)
/**
* Handles scripting context creation/access/deletion. Calls to execution
API must be wrapped
* with this method.
*/
- def withLocalVariableManager[R](f: => R): R = {
- variableManagerHandle.runWith(f)
+ def withContextManager[R](f: => R): R = {
+ contextManagerHandle.runWith(f)
}
/** Helper method to iterate get next statements from the first available
frame. */
@@ -197,7 +197,7 @@ object SqlScriptingExecution {
script: CompoundBody,
args: Map[String, Expression] = Map.empty): LogicalPlan = {
val sse = new SqlScriptingExecution(script, session, args)
- sse.withLocalVariableManager {
+ sse.withContextManager {
var result: Option[Seq[Row]] = None
// We must execute returned df before calling sse.getNextResult again
because sse.hasNext
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index 2682656cb7ad..bfd5a4b43711 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -23,13 +23,13 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.catalog.VariableDefinition
+import
org.apache.spark.sql.catalyst.catalog.{SqlScriptingExecutionContextExtension,
VariableDefinition}
import
org.apache.spark.sql.scripting.SqlScriptingFrameType.SqlScriptingFrameType
/**
* SQL scripting execution context - keeps track of the current execution
state.
*/
-class SqlScriptingExecutionContext {
+class SqlScriptingExecutionContext extends
SqlScriptingExecutionContextExtension {
// List of frames that are currently active.
private[scripting] val frames: ListBuffer[SqlScriptingExecutionFrame] =
ListBuffer.empty
private[scripting] var firstHandlerScopeLabel: Option[String] = None
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 3fe040b28d78..68c6d9607d32 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -50,7 +50,7 @@ class SqlScriptingExecutionSuite extends QueryTest with
SharedSparkSession {
val compoundBody =
spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
val sse = new SqlScriptingExecution(compoundBody, spark, args)
- sse.withLocalVariableManager {
+ sse.withContextManager {
val result: ListBuffer[Array[Row]] = ListBuffer.empty
var df = sse.getNextResult
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 30116f7a0bf8..1c0f0ced5d49 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.{SparkConf, SparkException, SparkNumberFormatException}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
-import org.apache.spark.sql.catalyst.{QueryPlanningTracker,
SqlScriptingLocalVariableManager}
+import org.apache.spark.sql.catalyst.{QueryPlanningTracker,
SqlScriptingContextManager}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.classic.{DataFrame, Dataset}
@@ -55,7 +55,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
executionPlan.enterScope()
val handle =
- SqlScriptingLocalVariableManager.create(new
SqlScriptingLocalVariableManager(context))
+ SqlScriptingContextManager.create(new
SqlScriptingContextManagerImpl(context))
handle.runWith {
executionPlan.getTreeIterator.flatMap {
case statement: SingleStatementExec =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]