This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 604bb34b26dc [SPARK-52034][SQL] Add common methods in SparkOperation trait for thriftserver operations
604bb34b26dc is described below

commit 604bb34b26dcd3462c888782620a3d5c6a31884f
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu May 8 13:36:38 2025 +0800

    [SPARK-52034][SQL] Add common methods in SparkOperation trait for thriftserver operations

    ### What changes were proposed in this pull request?

    Add common methods in the SparkOperation trait for ThriftServer operations:

    - sessionState
    - conf
    - catalog         // called by every metadata operation
    - sparkContext    // used by listeners, cancellation, etc.
    - withClassLoader // for switching to the Hive state's jar class loader

    ### Why are the changes needed?

    Remove duplicated code and also help solve the IDE type-inference issue.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Passing CI.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50824 from yaooqinn/SPARK-52034.

    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
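To make the new trait members concrete, here is a minimal sketch of what a metadata operation's `runInternal` body looks like on top of them. It is condensed from the `SparkGetSchemasOperation` hunk in the diff below; `rowSet`, `schemaName`, `convertSchemaPattern`, and `DEFAULT_HIVE_CATALOG` are assumed to come from the enclosing operation class and the Hive service layer, not from this commit:

```scala
// Sketch only, not code from this commit: a metadata operation after the
// change, condensed from the SparkGetSchemasOperation hunk below.
override def runInternal(): Unit = withClassLoader { _ =>
  // withClassLoader installs executionHive's jar class loader on the current
  // thread before running the body, replacing the boilerplate each
  // operation used to repeat.
  setState(OperationState.RUNNING)
  // `catalog` is the new trait accessor for session.sessionState.catalog;
  // `conf` and `sparkContext` shorten their call sites the same way.
  catalog.listDatabases(convertSchemaPattern(schemaName)).foreach { dbName =>
    rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
  }
  setState(OperationState.FINISHED)
}
```

`SparkExecuteStatementOperation` is the one caller that uses the class loader argument: its `execute()` receives it as `classLoader` and, in synchronous mode, also sets it on the Hive session state's conf.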
---
 .../SparkExecuteStatementOperation.scala           | 30 ++++++++++------------
 .../thriftserver/SparkGetCatalogsOperation.scala   |  6 +----
 .../thriftserver/SparkGetColumnsOperation.scala    |  9 +------
 .../thriftserver/SparkGetFunctionsOperation.scala  |  7 +----
 .../thriftserver/SparkGetSchemasOperation.scala    | 10 +++-----
 .../thriftserver/SparkGetTableTypesOperation.scala |  5 +---
 .../thriftserver/SparkGetTablesOperation.scala     |  6 +----
 .../thriftserver/SparkGetTypeInfoOperation.scala   |  5 +---
 .../sql/hive/thriftserver/SparkOperation.scala     | 25 ++++++++++++++----
 9 files changed, 42 insertions(+), 61 deletions(-)

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index ce857ce8f866..2ba4d72c9e42 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -54,7 +54,7 @@ private[hive] class SparkExecuteStatementOperation(
   // a global timeout value, we use the user-specified value.
   // This code follows the Hive timeout behaviour (See #29933 for details).
   private val timeout = {
-    val globalTimeout = session.sessionState.conf.getConf(SQLConf.THRIFTSERVER_QUERY_TIMEOUT)
+    val globalTimeout = conf.getConf(SQLConf.THRIFTSERVER_QUERY_TIMEOUT)
     if (globalTimeout > 0 && (queryTimeout <= 0 || globalTimeout < queryTimeout)) {
       globalTimeout
     } else {
@@ -63,13 +63,13 @@
   }
 
   private var timeoutExecutor: ScheduledExecutorService = _
-  private val forceCancel = session.sessionState.conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL)
+  private val forceCancel = conf.getConf(SQLConf.THRIFTSERVER_FORCE_CANCEL)
 
   private val redactedStatement = {
-    val substitutorStatement = SQLConf.withExistingConf(session.sessionState.conf) {
+    val substitutorStatement = SQLConf.withExistingConf(conf) {
       new VariableSubstitution().substitute(statement)
     }
-    SparkUtils.redact(session.sessionState.conf.stringRedactionPattern, substitutorStatement)
+    SparkUtils.redact(conf.stringRedactionPattern, substitutorStatement)
   }
 
   private var result: DataFrame = _
@@ -89,10 +89,10 @@ private[hive] class SparkExecuteStatementOperation(
 
   def getNextRowSet(order: FetchOrientation, maxRowsL: Long): TRowSet = withLocalProperties {
     try {
-      session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
+      sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
       getNextRowSetInternal(order, maxRowsL)
     } finally {
-      session.sparkContext.clearJobGroup()
+      sparkContext.clearJobGroup()
     }
   }
 
@@ -209,7 +209,7 @@ private[hive] class SparkExecuteStatementOperation(
     }
   }
 
-  private def execute(): Unit = {
+  private def execute(): Unit = withClassLoader { classLoader =>
     try {
       synchronized {
         if (getStatus.getState.isTerminal) {
@@ -222,16 +222,12 @@
           setState(OperationState.RUNNING)
         }
       }
-      // Always use the latest class loader provided by executionHive's state.
-      val executionHiveClassLoader = session.sharedState.jarClassLoader
-      Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
       // Always set the session state classloader to `executionHiveClassLoader` even for sync mode
       if (!runInBackground) {
-        parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader)
+        parentSession.getSessionState.getConf.setClassLoader(classLoader)
       }
 
-      session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
+      sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
       result = session.sql(statement)
       logDebug(result.queryExecution.toString())
       HiveThriftServer2.eventManager.onStatementParsed(statementId,
@@ -253,7 +249,7 @@
          // task interrupted, it may have started some spark job, so we need to cancel again to
          // make sure job was cancelled when background thread was interrupted
          if (statementId != null) {
-            session.sparkContext.cancelJobGroup(statementId,
+            sparkContext.cancelJobGroup(statementId,
              "The corresponding Thriftserver query has failed.")
          }
          val currentState = getStatus().getState()
@@ -271,7 +267,7 @@
        e match {
          case _: HiveSQLException => throw e
          case _ => throw HiveThriftServerErrors.runningQueryError(
-            e, session.sessionState.conf.errorMessageFormat)
+            e, conf.errorMessageFormat)
        }
      }
    } finally {
@@ -281,7 +277,7 @@
        HiveThriftServer2.eventManager.onStatementFinish(statementId)
      }
    }
-    session.sparkContext.clearJobGroup()
+    sparkContext.clearJobGroup()
  }
 }
 
@@ -318,7 +314,7 @@
    }
    // RDDs will be cleaned automatically upon garbage collection.
    if (statementId != null) {
-      session.sparkContext.cancelJobGroup(statementId)
+      sparkContext.cancelJobGroup(statementId)
    }
    // Shutdown the timeout thread if any, while cleaning up this operation
    if (timeoutExecutor != null &&
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
index e4bb91d466ff..8dfe551892fa 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetCatalogsOperation.scala
@@ -39,14 +39,10 @@ private[hive] class SparkGetCatalogsOperation(
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     val logMsg = "Listing catalogs"
     logInfo(log"Listing catalogs with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     HiveThriftServer2.eventManager.onStatementStart(
       statementId,
       parentSession.getSessionHandle.getSessionId.toString,
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
index c0552f715442..4560856cb063 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetColumnsOperation.scala
@@ -31,7 +31,6 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.types._
 
 /**
@@ -55,9 +54,7 @@ private[hive] class SparkGetColumnsOperation(
   with SparkOperation
   with Logging {
 
-  val catalog: SessionCatalog = session.sessionState.catalog
-
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName, tablePattern : $tableName"
     val logMsg = s"Listing columns '$cmdStr, columnName : $columnName'"
@@ -71,10 +68,6 @@ private[hive] class SparkGetColumnsOperation(
       log"with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
 
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     HiveThriftServer2.eventManager.onStatementStart(
       statementId,
       parentSession.getSessionHandle.getSessionId.toString,
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
index 515e64f5f529..c59875b90c44 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetFunctionsOperation.scala
@@ -49,7 +49,7 @@
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdMDC = log"catalog : ${MDC(LogKeys.CATALOG_NAME, catalogName)}, " +
       log"schemaPattern : ${MDC(LogKeys.DATABASE_NAME, schemaName)}"
@@ -57,11 +57,6 @@
       log", functionName : ${MDC(LogKeys.FUNCTION_NAME, functionName)}'"
     logInfo(logMDC + log" with ${MDC(LogKeys.STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
-    val catalog = session.sessionState.catalog
     // get databases for schema pattern
     val schemaPattern = convertSchemaPattern(schemaName)
     val matchingDbs = catalog.listDatabases(schemaPattern)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
index 0e2c35b5ef55..1db286a7a7f2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetSchemasOperation.scala
@@ -46,7 +46,7 @@
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
     val logMsg = s"Listing databases '$cmdStr'"
@@ -58,10 +58,6 @@
       log"with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
 
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
-
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TABLES, null, cmdStr)
     }
@@ -75,11 +71,11 @@
 
     try {
       val schemaPattern = convertSchemaPattern(schemaName)
-      session.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
+      catalog.listDatabases(schemaPattern).foreach { dbName =>
         rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
       }
 
-      val globalTempViewDb = session.sessionState.catalog.globalTempDatabase
+      val globalTempViewDb = catalog.globalTempDatabase
       val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
       if (schemaName == null || schemaName.isEmpty
         || databasePattern.matcher(globalTempViewDb).matches()) {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
index 9709739a64a4..f8ed09857f1c 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTableTypesOperation.scala
@@ -42,14 +42,11 @@
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     statementId = UUID.randomUUID().toString
     val logMsg = "Listing table types"
     logInfo(log"Listing table types with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TABLETYPES, null)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
index e1dd6e8dd95b..d57c590156d0 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
@@ -53,7 +53,7 @@
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     // Do not change cmdStr. It's used for Hive auditing and authorization.
     val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
     val tableTypesStr = if (tableTypes == null) "null" else tableTypes.asScala.mkString(",")
@@ -67,11 +67,7 @@
       log"tableName: ${MDC(TABLE_NAME, tableName)}' " +
       log"with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
-    val catalog = session.sessionState.catalog
     val schemaPattern = convertSchemaPattern(schemaName)
     val tablePattern = convertIdentifierPattern(tableName, true)
     val matchingDbs = catalog.listDatabases(schemaPattern)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
index 456ec44678c5..c982eaaef639 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
@@ -43,14 +43,11 @@
   with SparkOperation
   with Logging {
 
-  override def runInternal(): Unit = {
+  override def runInternal(): Unit = withClassLoader { _ =>
     statementId = UUID.randomUUID().toString
     val logMsg = "Listing type info"
     logInfo(log"Listing type info with ${MDC(STATEMENT_ID, statementId)}")
     setState(OperationState.RUNNING)
-    // Always use the latest class loader provided by executionHive's state.
-    val executionHiveClassLoader = session.sharedState.jarClassLoader
-    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
 
     if (isAuthV2Enabled) {
       authorizeMetaGets(HiveOperationType.GET_TYPEINFO, null)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
index b56888f49c1b..f653e899ebf4 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkOperation.scala
@@ -25,9 +25,9 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{HIVE_OPERATION_TYPE, STATEMENT_ID}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
-import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, SessionCatalog}
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType.{EXTERNAL, MANAGED, VIEW}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
 import org.apache.spark.util.Utils
 
 /**
@@ -47,6 +47,21 @@ private[hive] trait SparkOperation extends Operation with Logging {
     }
   }
 
+  private def sessionState: SessionState = session.sessionState
+
+  final protected def catalog: SessionCatalog = sessionState.catalog
+
+  final protected def conf: SQLConf = sessionState.conf
+
+  final protected def sparkContext: SparkContext = session.sparkContext
+
+  final protected def withClassLoader(f: ClassLoader => Unit): Unit = {
+    val sharedState: SharedState = session.sharedState
+    val executionHiveClassLoader = sharedState.jarClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+    f(executionHiveClassLoader)
+  }
+
   abstract override def close(): Unit = {
     super.close()
     cleanup()
@@ -62,7 +77,7 @@
   // - set appropriate SparkSession
   // - set scheduler pool for the operation
   def withLocalProperties[T](f: => T): T = {
-    val originalProps = Utils.cloneProperties(session.sparkContext.getLocalProperties)
+    val originalProps = Utils.cloneProperties(sparkContext.getLocalProperties)
     val originalSession = SparkSession.getActiveSession
 
     try {
@@ -72,7 +87,7 @@
       // Set scheduler pool
       session.conf.getOption(SQLConf.THRIFTSERVER_POOL.key) match {
        case Some(pool) =>
-          session.sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool)
+          sparkContext.setLocalProperty(SparkContext.SPARK_SCHEDULER_POOL, pool)
        case None =>
      }
      CURRENT_USER.set(getParentSession.getUserName)
@@ -81,7 +96,7 @@
    } finally {
      CURRENT_USER.remove()
      // reset local properties, will also reset SPARK_SCHEDULER_POOL
-      session.sparkContext.setLocalProperties(originalProps)
+      sparkContext.setLocalProperties(originalProps)
      originalSession match {
        case Some(session) => SparkSession.setActiveSession(session)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org