This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 604bb34b26dc [SPARK-52034][SQL] Add common methods in SparkOperation trait for thriftserver operations
604bb34b26dc is described below

commit 604bb34b26dcd3462c888782620a3d5c6a31884f
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu May 8 13:36:38 2025 +0800

    [SPARK-52034][SQL] Add common methods in SparkOperation trait for thriftserver operations
    
    ### What changes were proposed in this pull request?
    
    Add common methods in the SparkOperation trait for ThriftServer operations
    (sketched below):
    - sessionState
    - conf
    - catalog // called by every metadata operation
    - sparkContext // used by listeners, cancellation, etc.
    - withClassLoader // for switching to the Hive state's jar class loader
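
    For illustration, here is a minimal, self-contained sketch of the new helpers,
    condensed from the patch below (the `session` member already exists on the
    trait; the usage comment at the end mirrors how the metadata operations adopt
    `withClassLoader` in this change):

    ```scala
    import org.apache.hive.service.cli.operation.Operation

    import org.apache.spark.SparkContext
    import org.apache.spark.internal.Logging
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.catalog.SessionCatalog
    import org.apache.spark.sql.internal.{SessionState, SQLConf}

    private[hive] trait SparkOperation extends Operation with Logging {
      // Pre-existing member: the SparkSession bound to this operation.
      protected def session: SparkSession

      private def sessionState: SessionState = session.sessionState

      final protected def catalog: SessionCatalog = sessionState.catalog

      final protected def conf: SQLConf = sessionState.conf

      final protected def sparkContext: SparkContext = session.sparkContext

      // Always use the latest class loader provided by executionHive's state,
      // then hand it to the body so callers can reuse it (e.g. for the Hive
      // session state in sync mode).
      final protected def withClassLoader(f: ClassLoader => Unit): Unit = {
        val executionHiveClassLoader = session.sharedState.jarClassLoader
        Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
        f(executionHiveClassLoader)
      }
    }

    // Operations then wrap their bodies, e.g.:
    //   override def runInternal(): Unit = withClassLoader { _ => ... }
    ```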
    
    ### Why are the changes needed?
    
    Remove duplicated code and help solve an IDE type-inference issue.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Passing CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50824 from yaooqinn/SPARK-52034.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../SparkExecuteStatementOperation.scala           | 30 ++++++++++------------
 .../thriftserver/SparkGetCatalogsOperation.scala   |  6 +----
 .../thriftserver/SparkGetColumnsOperation.scala    |  9 +------
 .../thriftserver/SparkGetFunctionsOperation.scala  |  7 +----
 .../thriftserver/SparkGetSchemasOperation.scala    | 10 +++-----
 .../thriftserver/SparkGetTableTypesOperation.scala |  5 +---
 .../thriftserver/SparkGetTablesOperation.scala     |  6 +----
 .../thriftserver/SparkGetTypeInfoOperation.scala   |  5 +---
 .../sql/hive/thriftserver/SparkOperation.scala     | 25 ++++++++++++++----
 9 files changed, 42 insertions(+), 61 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index ce857ce8f866..2ba4d72c9e42 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -54,7 +54,7 @@ private[hive] class SparkExecuteStatementOperation(
   // a global timeout value, we use the user-specified value.
   // This code follows the Hive timeout behaviour (See #29933 for details).
   private val timeout = {
-    val globalTimeout = session.sessionState.conf.getConf(SQLConf.THRIFTSERVER_QUERY_TIMEOUT)
+    val globalTimeout = conf.getConf(SQLConf.THRIFTSERVER_QUERY_TIMEOUT)
     if (globalTimeout > 0 && (queryTimeout <= 0 || globalTimeout < queryTimeout)) {
       globalTimeout
     } else {
@@ -63,13 +63,13 @@ private[hive] class SparkExecuteStatementOperation(
   }
   private var timeoutExecutor: ScheduledExecutorService = _
 
-  private val forceCancel = session.sessionState.conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL)
+  private val forceCancel = conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL)
 
   private val redactedStatement = {
-    val substitutorStatement = SQLConf.withExistingConf(session.sessionState.conf) {
+    val substitutorStatement = SQLConf.withExistingConf(conf) {
       new VariableSubstitution().substitute(statement)
     }
-    SparkUtils.redact(session.sessionState.conf.stringRedactionPattern, substitutorStatement)
+    SparkUtils.redact(conf.stringRedactionPattern, substitutorStatement)
   }
 
   private var result: DataFrame = _
@@ -89,10 +89,10 @@ private[hive] class SparkExecuteStatementOperation(
 
   def getNextRowSet(order: FetchOrientation, maxRowsL: Long): TRowSet = withLocalProperties {
     try {
-      session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
+      sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
       getNextRowSetInternal(order, maxRowsL)
     } finally {
-      session.sparkContext.clearJobGroup()
+      sparkContext.clearJobGroup()
     }
   }
 
@@ -209,7 +209,7 @@ private[hive] class SparkExecuteStatementOperation(
     }
   }
 
-  private def execute(): Unit = {
+  private def execute(): Unit = withClassLoader { classLoader =>
     try {
       synchronized {
         if (getStatus.getState.isTerminal) {
@@ -222,16 +222,12 @@ private[hive] class SparkExecuteStatementOperation(
           setState(OperationState.RUNNING)
         }
       }
-      // Always use the latest class loader provided by executionHive's state.
-      val executionHiveClassLoader = session.sharedState.jarClassLoader
-      Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
       // Always set the session state classloader to `executionHiveClassLoader` even for sync mode
       if (!runInBackground) {
-        parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader)
+        parentSession.getSessionState.getConf.setClassLoader(classLoader)
       }
 
-      session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
+      sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
       result = session.sql(statement)
       logDebug(result.queryExecution.toString())
       HiveThriftServer2.eventManager.onStatementParsed(statementId,
@@ -253,7 +249,7 @@ private[hive] class SparkExecuteStatementOperation(
         // task interrupted, it may have started some spark job, so we need to cancel again to
         // make sure job was cancelled when background thread was interrupted
         if (statementId != null) {
-          session.sparkContext.cancelJobGroup(statementId,
+          sparkContext.cancelJobGroup(statementId,
             "The corresponding Thriftserver query has failed.")
         }
         val currentState = getStatus().getState()
@@ -271,7 +267,7 @@ private[hive] class SparkExecuteStatementOperation(
           e match {
             case _: HiveSQLException => throw e
             case _ => throw HiveThriftServerErrors.runningQueryError(
-              e, session.sessionState.conf.errorMessageFormat)
+              e, conf.errorMessageFormat)
           }
         }
     } finally {
@@ -281,7 +277,7 @@ private[hive] class SparkExecuteStatementOperation(
           HiveThriftServer2.eventManager.onStatementFinish(statementId)
         }
       }
-      session.sparkContext.clearJobGroup()
+      sparkContext.clearJobGroup()
     }
   }
 
@@ -318,7 +314,7 @@ private[hive] class SparkExecuteStatementOperation(
     }
     // RDDs will be cleaned automatically upon garbage collection.
     if (statementId != null) {
-      session.sparkContext.cancelJobGroup(statementId)
+      sparkContext.cancelJobGroup(statementId)
     }
     // Shutdown the timeout thread if any, while cleaning up this operation
     if (timeoutExecutor != null &&
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
index e4bb91d466ff..8dfe551892fa 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
@@ -39,14 +39,10 @@ private[hive] class SparkGetCatalogsOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     val logMsg = "Listing catalogs"
     logInfo(log"Listing catalogs with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     HiveThriftServer2.eventManager.onStatementStart(
       statementId,
       parentSession.getSessionHandle.getSessionId.toString,
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
index c0552f715442..4560856cb063 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
@@ -31,7 +31,6 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.types._
 
 /**
@@ -55,9 +54,7 @@ private[hive] class SparkGetColumnsOperation(
   with SparkOperation
   with Logging {
 
-  val catalog: SessionCatalog = session.sessionState.catalog
-
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName, 
tablePattern : $tableName"
     val logMsg = s"Listing columns '$cmdStr, columnName : $columnName'"
@@ -71,10 +68,6 @@ private[hive] class SparkGetColumnsOperation(
       log"with ${MDC(STATEMENT_ID, statementId)}")
 
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     HiveThriftServer2.eventManager.onStatementStart(
       statementId,
       parentSession.getSessionHandle.getSessionId.toString,
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
index 515e64f5f529..c59875b90c44 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
@@ -49,7 +49,7 @@ private[hive] class SparkGetFunctionsOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdMDC = log"catalog : ${MDC(LogKeys.CATALOG_NAME, catalogName)}, " +
       log"schemaPattern : ${MDC(LogKeys.DATABASE_NAME, schemaName)}"
@@ -57,11 +57,6 @@ private[hive] class SparkGetFunctionsOperation(
       log", functionName : ${MDC(LogKeys.FUNCTION_NAME, functionName)}'"
     logInfo(logMDC + log" with ${MDC(LogKeys.STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
-    val catalog = session.sessionState.catalog
     // get databases for schema pattern
     val schemaPattern = convertSchemaPattern(schemaName)
     val matchingDbs = catalog.listDatabases(schemaPattern)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
index 0e2c35b5ef55..1db286a7a7f2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -46,7 +46,7 @@ private[hive] class SparkGetSchemasOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
     val logMsg = s"Listing databases '$cmdStr'"
@@ -58,10 +58,6 @@ private[hive] class SparkGetSchemasOperation(
       log"with ${MDC(STATEMENT_ID, statementId)}")
 
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr)
     }
@@ -75,11 +71,11 @@ private[hive] class SparkGetSchemasOperation(
 
     try {
       val schemaPattern = convertSchemaPattern(schemaName)
-      session.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
+      catalog.listDatabases(schemaPattern).foreach { dbName =>
         rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
       }
 
-      val globalTempViewDb = session.sessionState.catalog.globalTempDatabase
+      val globalTempViewDb = catalog.globalTempDatabase
       val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
       if (schemaName == null || schemaName.isEmpty ||
           databasePattern.matcher(globalTempViewDb).matches()) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
index 9709739a64a4..f8ed09857f1c 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
@@ -42,14 +42,11 @@ private[hive] class SparkGetTableTypesOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     statementId = UUID.randomUUID().toString
     val logMsg = "Listing table types"
     logInfo(log"Listing table types with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TABLETYPES, null)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
index e1dd6e8dd95b..d57c590156d0 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
@@ -53,7 +53,7 @@ private[hive] class SparkGetTablesOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
     val tableTypesStr = if (tableTypes == null) "null" else tableTypes.asScala.mkString(",")
@@ -67,11 +67,7 @@ private[hive] class SparkGetTablesOperation(
       log"tableName: ${MDC(TABLE_NAME, tableName)}' " +
       log"with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
-    val catalog = session.sessionState.catalog
     val schemaPattern = convertSchemaPattern(schemaName)
     val tablePattern = convertIdentifierPattern(tableName, true)
     val matchingDbs = catalog.listDatabases(schemaPattern)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
index 456ec44678c5..c982eaaef639 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
@@ -43,14 +43,11 @@ private[hive] class SparkGetTypeInfoOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     statementId = UUID.randomUUID().toString
     val logMsg = "Listing type info"
     logInfo(log"Listing type info with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TYPEINFO, null)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
index b56888f49c1b..f653e899ebf4 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
@@ -25,9 +25,9 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{HIVE_OPERATION_TYPE, STATEMENT_ID}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
-import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, SessionCatalog}
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType.{EXTERNAL, MANAGED, VIEW}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
 import org.apache.spark.util.Utils
 
 /**
@@ -47,6 +47,21 @@ private[hive] trait SparkOperation extends Operation with Logging {
     }
   }
 
+  private def sessionState: SessionState = session.sessionState
+
+  final protected def catalog: SessionCatalog = sessionState.catalog
+
+  final protected def conf: SQLConf = sessionState.conf
+
+  final protected def sparkContext: SparkContext = session.sparkContext
+
+  final protected def withClassLoader(f: ClassLoader => Unit): Unit = {
+    val sharedState: SharedState = session.sharedState
+    val executionHiveClassLoader = sharedState.jarClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+    f(executionHiveClassLoader)
+  }
+
   abstract override def close(): Unit = {
     super.close()
     cleanup()
@@ -62,7 +77,7 @@ private[hive] trait SparkOperation extends Operation with Logging {
   // - set appropriate SparkSession
   // - set scheduler pool for the operation
   def withLocalProperties[T](f: => T): T = {
-    val originalProps = Utils.cloneProperties(session.sparkContext.getLocalProperties)
+    val originalProps = Utils.cloneProperties(sparkContext.getLocalProperties)
     val originalSession = SparkSession.getActiveSession
 
     try {
@@ -72,7 +87,7 @@ private[hive] trait SparkOperation extends Operation with Logging {
       // Set scheduler pool
       session.conf.getOption(SQLConf.THRIFTSERVER_POOL.key) match {
         case Some(pool) =>
-          session.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool)
+          sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool)
         case None =>
       }
       CURRENT_USER.set(getParentSession.getUserName)
@@ -81,7 +96,7 @@ private[hive] trait SparkOperation extends Operation with Logging {
     } finally {
       CURRENT_USER.remove()
       // reset local properties, will also reset SPARK_SCHEDULER_POOL
-      session.sparkContext.setLocalProperties(originalProps)
+      sparkContext.setLocalProperties(originalProps)
 
       originalSession match {
         case Some(session) => SparkSession.setActiveSession(session)

