This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 12dc89e2206e [SPARK-55909][SQL][TESTS] Introduce trait
`SparkSessionProvider`
12dc89e2206e is described below
commit 12dc89e2206ee7371f5ad53003b8d3c24a366a60
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Mar 11 16:43:14 2026 +0800
[SPARK-55909][SQL][TESTS] Introduce trait `SparkSessionProvider`
### What changes were proposed in this pull request?
Introduce trait `SparkSessionProvider`.
### Why are the changes needed?
So that we can switch the underlying session implementation in the future.
### Does this PR introduce _any_ user-facing change?
No, this is a test-only change.
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54711 from zhengruifeng/add_session_provider.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../scala/org/apache/spark/sql/QueryTest.scala | 4 +--
.../apache/spark/sql/SparkSessionProvider.scala | 29 ++++++++++++++++++++++
.../apache/spark/sql/execution/SparkPlanTest.scala | 6 ++---
.../org/apache/spark/sql/test/SQLTestData.scala | 10 +++++---
.../org/apache/spark/sql/test/SQLTestUtils.scala | 4 ++-
.../apache/spark/sql/test/SharedSparkSession.scala | 8 +++---
.../spark/sql/hive/test/TestHiveSingleton.scala | 5 ++--
7 files changed, 49 insertions(+), 17 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 515a60228405..ae1b4abe31a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -35,9 +35,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
-abstract class QueryTest extends PlanTest {
-
- protected def spark: SparkSession
+abstract class QueryTest extends PlanTest with SparkSessionProvider {
/**
* Runs the plan and makes sure the answer contains all of the keywords.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala
new file mode 100644
index 000000000000..67ff122efec8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * A common trait for test suites that require a [[SparkSession]]. It
abstracts over the
+ * different session types supported by Spark tests: Spark classic sessions,
Hive sessions
+ * (backed by [[org.apache.spark.sql.hive.test.TestHiveContext]]), and Spark
Connect sessions.
+ * Concrete implementations are responsible for managing the session lifecycle
(creation,
+ * configuration, and teardown).
+ */
+trait SparkSessionProvider {
+ protected def spark: SparkSession
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 95d7c4cd3caf..e9c8b3161108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.{classic, DataFrame, Row, SparkSessionProvider,
SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.classic.ClassicConversions._
@@ -30,8 +30,8 @@ import org.apache.spark.sql.test.SQLTestUtils
* Base class for writing tests for individual physical operators. For an
example of how this
* class's test helper methods can be used, see [[SortSuite]].
*/
-private[sql] abstract class SparkPlanTest extends SparkFunSuite {
- protected def spark: SparkSession
+private[sql] abstract class SparkPlanTest extends SparkFunSuite with
SparkSessionProvider {
+ override protected def spark: classic.SparkSession
/**
* Runs the plan and makes sure the answer matches the expected result.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 477da731b81b..bd5dd038f5d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -21,7 +21,9 @@ import java.nio.charset.StandardCharsets
import java.time.{Duration, Period}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.classic.{DataFrame, SparkSession, SQLImplicits}
+import org.apache.spark.sql.SparkSessionProvider
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.classic.{DataFrame, SQLImplicits}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE,
SECOND}
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
@@ -30,12 +32,12 @@ import org.apache.spark.unsafe.types.CalendarInterval
/**
* A collection of sample data used in SQL tests.
*/
-private[sql] trait SQLTestData { self =>
- protected def spark: SparkSession
+private[sql] trait SQLTestData extends SparkSessionProvider { self =>
+ override protected def spark: classic.SparkSession
// Helper object to import SQL implicits without a concrete SparkSession
private object internalImplicits extends SQLImplicits {
- override protected def session: SparkSession = self.spark
+ override protected def session: classic.SparkSession = self.spark
}
import internalImplicits._
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index f0f3f94b811f..e330db6b1292 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
import org.scalatest.concurrent.Eventually
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{classic, AnalysisException, Row}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
@@ -228,6 +228,8 @@ private[sql] trait SQLTestUtilsBase
with SQLTestData
with PlanTestBase { self: Suite =>
+ override protected def spark: classic.SparkSession
+
protected def sparkContext = spark.sparkContext
// Shorthand for running a query using our SparkSession
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 720b13b812e0..456e7ed3478a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -24,10 +24,9 @@ import org.scalatest.concurrent.Eventually
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{classic, SparkSession, SparkSessionProvider,
SQLContext}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
@@ -84,6 +83,7 @@ trait SharedSparkSession extends SQLTestUtils with
SharedSparkSessionBase {
*/
trait SharedSparkSessionBase
extends SQLTestUtilsBase
+ with SparkSessionProvider
with BeforeAndAfterEach
with Eventually { self: Suite =>
@@ -121,7 +121,7 @@ trait SharedSparkSessionBase
/**
* The [[TestSparkSession]] to use for all tests in this suite.
*/
- protected implicit def spark: SparkSession = _spark
+ protected override def spark: classic.SparkSession = _spark
/**
* The [[TestSQLContext]] to use for all tests in this suite.
@@ -129,7 +129,7 @@ trait SharedSparkSessionBase
protected implicit def sqlContext: SQLContext = _spark.sqlContext
protected def createSparkSession: TestSparkSession = {
- SparkSession.cleanupAnyExistingSession()
+ classic.SparkSession.cleanupAnyExistingSession()
new TestSparkSession(sparkConf)
}
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
index d1bc6bd92fff..47cc9853f754 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -18,14 +18,15 @@
package org.apache.spark.sql.hive.test
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SparkSessionProvider
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.client.HiveClient
-trait TestHiveSingleton extends SparkFunSuite {
+trait TestHiveSingleton extends SparkFunSuite with SparkSessionProvider {
override protected val enableAutoThreadAudit = false
- protected val spark: SparkSession = TestHive.sparkSession
+ override protected val spark: SparkSession = TestHive.sparkSession
protected val hiveContext: TestHiveContext = TestHive
protected val hiveClient: HiveClient =
spark.sharedState.externalCatalog.unwrapped.asInstanceOf[HiveExternalCatalog].client
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]