This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 12dc89e2206e [SPARK-55909][SQL][TESTS] Introduce trait `SparkSessionProvider`
12dc89e2206e is described below

commit 12dc89e2206ee7371f5ad53003b8d3c24a366a60
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Mar 11 16:43:14 2026 +0800

    [SPARK-55909][SQL][TESTS] Introduce trait `SparkSessionProvider`
    
    ### What changes were proposed in this pull request?
    Introduce trait SparkSessionProvider
    
    ### Why are the changes needed?
    so that we can switch the underlying session in the future
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #54711 from zhengruifeng/add_session_provider.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../scala/org/apache/spark/sql/QueryTest.scala     |  4 +--
 .../apache/spark/sql/SparkSessionProvider.scala    | 29 ++++++++++++++++++++++
 .../apache/spark/sql/execution/SparkPlanTest.scala |  6 ++---
 .../org/apache/spark/sql/test/SQLTestData.scala    | 10 +++++---
 .../org/apache/spark/sql/test/SQLTestUtils.scala   |  4 ++-
 .../apache/spark/sql/test/SharedSparkSession.scala |  8 +++---
 .../spark/sql/hive/test/TestHiveSingleton.scala    |  5 ++--
 7 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 515a60228405..ae1b4abe31a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -35,9 +35,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
 
 
-abstract class QueryTest extends PlanTest {
-
-  protected def spark: SparkSession
+abstract class QueryTest extends PlanTest with SparkSessionProvider {
 
   /**
    * Runs the plan and makes sure the answer contains all of the keywords.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala
new file mode 100644
index 000000000000..67ff122efec8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * A common trait for test suites that require a [[SparkSession]]. It abstracts over the
+ * different session types supported by Spark tests: Spark classic sessions, Hive sessions
+ * (backed by [[org.apache.spark.sql.hive.test.TestHiveContext]]), and Spark Connect sessions.
+ * Concrete implementations are responsible for managing the session lifecycle (creation,
+ * configuration, and teardown).
+ */
+trait SparkSessionProvider {
+  protected def spark: SparkSession
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 95d7c4cd3caf..e9c8b3161108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{classic, DataFrame, Row, SparkSessionProvider, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.classic.ClassicConversions._
@@ -30,8 +30,8 @@ import org.apache.spark.sql.test.SQLTestUtils
  * Base class for writing tests for individual physical operators. For an example of how this
  * class's test helper methods can be used, see [[SortSuite]].
  */
-private[sql] abstract class SparkPlanTest extends SparkFunSuite {
-  protected def spark: SparkSession
+private[sql] abstract class SparkPlanTest extends SparkFunSuite with SparkSessionProvider {
+  override protected def spark: classic.SparkSession
 
   /**
    * Runs the plan and makes sure the answer matches the expected result.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 477da731b81b..bd5dd038f5d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -21,7 +21,9 @@ import java.nio.charset.StandardCharsets
 import java.time.{Duration, Period}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.classic.{DataFrame, SparkSession, SQLImplicits}
+import org.apache.spark.sql.SparkSessionProvider
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.classic.{DataFrame, SQLImplicits}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
 import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
@@ -30,12 +32,12 @@ import org.apache.spark.unsafe.types.CalendarInterval
 /**
  * A collection of sample data used in SQL tests.
  */
-private[sql] trait SQLTestData { self =>
-  protected def spark: SparkSession
+private[sql] trait SQLTestData extends SparkSessionProvider { self =>
+  override protected def spark: classic.SparkSession
 
   // Helper object to import SQL implicits without a concrete SparkSession
   private object internalImplicits extends SQLImplicits {
-    override protected def session: SparkSession = self.spark
+    override protected def session: classic.SparkSession = self.spark
   }
 
   import internalImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index f0f3f94b811f..e330db6b1292 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{classic, AnalysisException, Row}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
@@ -228,6 +228,8 @@ private[sql] trait SQLTestUtilsBase
   with SQLTestData
   with PlanTestBase { self: Suite =>
 
+  override protected def spark: classic.SparkSession
+
   protected def sparkContext = spark.sparkContext
 
   // Shorthand for running a query using our SparkSession
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 720b13b812e0..456e7ed3478a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -24,10 +24,9 @@ import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.{DebugFilesystem, SparkConf}
 import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{classic, SparkSession, SparkSessionProvider, SQLContext}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 
 trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
@@ -84,6 +83,7 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
  */
 trait SharedSparkSessionBase
   extends SQLTestUtilsBase
+  with SparkSessionProvider
   with BeforeAndAfterEach
   with Eventually { self: Suite =>
 
@@ -121,7 +121,7 @@ trait SharedSparkSessionBase
   /**
    * The [[TestSparkSession]] to use for all tests in this suite.
    */
-  protected implicit def spark: SparkSession = _spark
+  protected override def spark: classic.SparkSession = _spark
 
   /**
    * The [[TestSQLContext]] to use for all tests in this suite.
@@ -129,7 +129,7 @@ trait SharedSparkSessionBase
   protected implicit def sqlContext: SQLContext = _spark.sqlContext
 
   protected def createSparkSession: TestSparkSession = {
-    SparkSession.cleanupAnyExistingSession()
+    classic.SparkSession.cleanupAnyExistingSession()
     new TestSparkSession(sparkConf)
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
index d1bc6bd92fff..47cc9853f754 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -18,14 +18,15 @@
 package org.apache.spark.sql.hive.test
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSessionProvider
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.hive.HiveExternalCatalog
 import org.apache.spark.sql.hive.client.HiveClient
 
 
-trait TestHiveSingleton extends SparkFunSuite {
+trait TestHiveSingleton extends SparkFunSuite with SparkSessionProvider {
   override protected val enableAutoThreadAudit = false
-  protected val spark: SparkSession = TestHive.sparkSession
+  override protected val spark: SparkSession = TestHive.sparkSession
   protected val hiveContext: TestHiveContext = TestHive
   protected val hiveClient: HiveClient =
     spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to