This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new a579c87f8220 [SPARK-55337][SS] Fix MemoryStream backward compatibility
a579c87f8220 is described below

commit a579c87f8220d192c81a6b0086b7b7487bab9e8d
Author: Wenchen Fan <[email protected]>
AuthorDate: Sun Feb 8 21:09:30 2026 +0500

    [SPARK-55337][SS] Fix MemoryStream backward compatibility
    
    This is a followup to #52402 that addresses backward compatibility concerns:

    1. Keep the original `implicit SQLContext` factory methods for full backward compatibility
    2. Add new overloads with an explicit `SparkSession` parameter for new code
    3. Fix `TestGraphRegistrationContext` to provide implicit `spark` and `sqlContext`, avoiding name-shadowing issues in nested classes (sketched below)
    4. Remove redundant `implicit val sparkSession` declarations from pipeline tests that are no longer needed with this fix
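
    The shadowing issue behind item 3 is easiest to see in a stripped-down sketch; the class name below is a hypothetical stand-in, while the real change lives in `TestGraphRegistrationContext`. The context's constructor parameter used to be called `spark`, which shadowed the test suite's own `spark` inside anonymous subclasses, so it is renamed to `_spark` and re-exposed as implicits:

    ```scala
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.classic.SparkSession

    // Hypothetical, simplified stand-in for TestGraphRegistrationContext.
    class TestContextSketch(val _spark: SparkSession) {
      // Re-expose the session as implicits so nested anonymous subclasses can
      // write `MemoryStream[Int]` without declaring their own implicit value.
      implicit def spark: SparkSession = _spark
      implicit def sqlContext: SQLContext = _spark.sqlContext
    }
    ```

    A test can then write `new TestContextSketch(spark) { val mem = MemoryStream[Int] }` and the stream resolves through the inherited implicit `sqlContext` instead of tripping over a shadowed outer name.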
    
    PR #52402 changed the MemoryStream API to use an `implicit SparkSession`, which broke backward compatibility for code that only has an `implicit SQLContext` available. This followup ensures:
    
    - Old code continues to work without modification
    - New code can use SparkSession with explicit parameters
    - Internal implementation uses SparkSession (modernization from #52402)
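
    A minimal caller-side sketch of the two styles (assuming a test-style `spark: SparkSession` is in scope, as in the suites touched below):

    ```scala
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
    import spark.implicits._  // supplies the Encoder[Int]

    // Old code keeps compiling: the factory still accepts an implicit SQLContext.
    implicit val sqlContext: SQLContext = spark.sqlContext
    val oldStyle = MemoryStream[Int]

    // New code can pass the SparkSession (and a partition count) explicitly.
    val newStyle = MemoryStream[Int](spark, numPartitions = 3)
    ```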
    
    There is no user-facing change: this maintains full backward compatibility while adding new API options.
    
    Existing tests pass. The API changes are additive.
    
    Authored with generative AI tooling: made with [Cursor](https://cursor.com).
    
    Closes #54108 from cloud-fan/memory-stream-compat.
    
    Lead-authored-by: Wenchen Fan <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit db28b99ded98dd2258d2e3a6d13f9b366cc0ad3d)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/execution/streaming/runtime/memory.scala   | 57 ++++++++++------
 .../streaming/sources/ContinuousMemoryStream.scala | 51 ++++++--------
 .../streaming/sources/LowLatencyMemoryStream.scala | 56 ++++++----------
 .../streaming/PythonStreamingDataSourceSuite.scala |  4 +-
 .../sql/execution/streaming/MemorySinkSuite.scala  | 78 ----------------------
 .../state/StateStoreCoordinatorSuite.scala         | 34 +++++-----
 .../streaming/state/StateStoreSuite.scala          |  4 +-
 .../streaming/FlatMapGroupsWithStateSuite.scala    |  2 +-
 .../sql/streaming/StreamingAggregationSuite.scala  |  2 +-
 .../streaming/StreamingDeduplicationSuite.scala    |  2 +-
 .../spark/sql/streaming/StreamingJoinSuite.scala   |  2 +-
 .../spark/sql/hive/execution/HiveDDLSuite.scala    |  2 +-
 .../graph/ConnectInvalidPipelineSuite.scala        |  5 +-
 .../graph/ConnectValidPipelineSuite.scala          |  6 --
 .../pipelines/graph/MaterializeTablesSuite.scala   |  8 +--
 .../sql/pipelines/graph/SystemMetadataSuite.scala  |  5 --
 .../graph/TriggeredGraphExecutionSuite.scala       |  6 +-
 .../utils/TestGraphRegistrationContext.scala       | 13 ++--
 18 files changed, 117 insertions(+), 220 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
index bf67ed670ec8..c7556ed47859 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
@@ -43,36 +43,51 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-object MemoryStream extends LowPriorityMemoryStreamImplicits {
+object MemoryStream {
   protected val currentBlockId = new AtomicInteger(0)
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
-
-  def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions))
-}
-
-/**
- * Provides lower-priority implicits for MemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityMemoryStreamImplicits {
-  this: MemoryStream.type =>
-
-  // Deprecated: Used when an implicit SQLContext is in scope
-  @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of 
SQLContext", "4.1.0")
-  def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] =
+  /**
+   * Creates a MemoryStream with an implicit SQLContext (backward compatible).
+   * Usage: `MemoryStream[Int]`
+   */
+  def apply[A: Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
     new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
 
-  @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of 
SQLContext", "4.1.0")
-  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): 
MemoryStream[A] =
+  /**
+   * Creates a MemoryStream with specified partitions using implicit SQLContext.
+   * Usage: `MemoryStream[Int](numPartitions)`
+   */
+  def apply[A: Encoder](numPartitions: Int)(
+      implicit sqlContext: SQLContext): MemoryStream[A] =
     new MemoryStream[A](
       memoryStreamId.getAndIncrement(),
       sqlContext.sparkSession,
       Some(numPartitions))
+
+  /**
+   * Creates a MemoryStream with explicit SparkSession.
+   * Usage: `MemoryStream[Int](spark)`
+   */
+  def apply[A: Encoder](sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
+
+  /**
+   * Creates a MemoryStream with specified partitions using explicit SparkSession.
+   * Usage: `MemoryStream[Int](spark, numPartitions)`
+   */
+  def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): MemoryStream[A] =
+    new MemoryStream[A](
+      memoryStreamId.getAndIncrement(),
+      sparkSession,
+      Some(numPartitions))
+
+  /**
+   * Creates a MemoryStream with explicit encoder and SparkSession.
+   * Usage: `MemoryStream(Encoders.scalaInt, spark)`
+   */
+  def apply[A](encoder: Encoder[A], sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)(encoder)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index 8042cacf1374..885f9ada22c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -112,47 +112,36 @@ class ContinuousMemoryStream[A : Encoder](
   override def commit(end: Offset): Unit = {}
 }
 
-object ContinuousMemoryStream extends LowPriorityContinuousMemoryStreamImplicits {
+object ContinuousMemoryStream {
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
-
-  def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession):
-  ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions)
-
-  def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
-}
-
-/**
- * Provides lower-priority implicits for ContinuousMemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityContinuousMemoryStreamImplicits {
-  this: ContinuousMemoryStream.type =>
-
-  // Deprecated: Used when an implicit SQLContext is in scope
-  @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def apply[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+  /** Creates a ContinuousMemoryStream with an implicit SQLContext (backward compatible). */
+  def apply[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
     new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
 
-  @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
-  ContinuousMemoryStream[A] =
+  /** Creates a ContinuousMemoryStream with specified partitions (SQLContext). */
+  def apply[A: Encoder](numPartitions: Int)(
+      implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
     new ContinuousMemoryStream[A](
       memoryStreamId.getAndIncrement(),
       sqlContext.sparkSession,
       numPartitions)
 
-  @deprecated("Use ContinuousMemoryStream.singlePartition with an implicit 
SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): 
ContinuousMemoryStream[A] =
+  /** Creates a ContinuousMemoryStream with explicit SparkSession. */
+  def apply[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] 
=
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), 
sparkSession)
+
+  /** Creates a ContinuousMemoryStream with specified partitions (SparkSession). */
+  def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions)
+
+  /** Creates a single partition ContinuousMemoryStream (SQLContext). */
+  def singlePartition[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
     new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1)
+
+  /** Creates a single partition ContinuousMemoryStream (SparkSession). */
+  def singlePartition[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
index d04f4b5d011c..6dfeb0cc4603 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
@@ -172,53 +172,39 @@ class LowLatencyMemoryStream[A: Encoder](
   }
 }
 
-object LowLatencyMemoryStream extends LowPriorityLowLatencyMemoryStreamImplicits {
+object LowLatencyMemoryStream {
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
-    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
+  /** Creates a LowLatencyMemoryStream with an implicit SQLContext (backward compatible). */
+  def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
 
+  /** Creates a LowLatencyMemoryStream with specified partitions (SQLContext). */
   def apply[A: Encoder](numPartitions: Int)(
-      implicit
-      sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+      implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
     new LowLatencyMemoryStream[A](
       memoryStreamId.getAndIncrement(),
-      sparkSession,
-      numPartitions = numPartitions
-    )
-
-  def singlePartition[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
-    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
-}
-
-/**
- * Provides lower-priority implicits for LowLatencyMemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityLowLatencyMemoryStreamImplicits {
-  this: LowLatencyMemoryStream.type =>
+      sqlContext.sparkSession,
+      numPartitions = numPartitions)
 
-  // Deprecated: Used when an implicit SQLContext is in scope
-  @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def apply[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
-    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+  /** Creates a LowLatencyMemoryStream with explicit SparkSession. */
+  def apply[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
-  @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
-  LowLatencyMemoryStream[A] =
+  /** Creates a LowLatencyMemoryStream with specified partitions (SparkSession). */
+  def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): LowLatencyMemoryStream[A] =
     new LowLatencyMemoryStream[A](
       memoryStreamId.getAndIncrement(),
-      sqlContext.sparkSession,
-      numPartitions = numPartitions
-    )
+      sparkSession,
+      numPartitions = numPartitions)
 
-  @deprecated("Use LowLatencyMemoryStream.singlePartition with an implicit 
SparkSession " +
-    "instead of SQLContext", "4.1.0")
-  def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): 
LowLatencyMemoryStream[A] =
+  /** Creates a single partition LowLatencyMemoryStream (SQLContext). */
+  def singlePartition[A: Encoder](implicit sqlContext: SQLContext): 
LowLatencyMemoryStream[A] =
     new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), 
sqlContext.sparkSession, 1)
+
+  /** Creates a single partition LowLatencyMemoryStream (SparkSession). */
+  def singlePartition[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+    new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 3b3e8687858d..074a5cf9f2bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -859,7 +859,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase {
     val dataSource =
       createUserDefinedPythonDataSource(dataSourceName, simpleDataStreamWriterScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    val inputData = MemoryStream[Int](numPartitions = 3)
+    val inputData = MemoryStream[Int](spark, numPartitions = 3)
     val df = inputData.toDF()
     withTempDir { dir =>
       val path = dir.getAbsolutePath
@@ -943,7 +943,7 @@ class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase {
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    val inputData = MemoryStream[Int](numPartitions = 3)
+    val inputData = MemoryStream[Int](spark, numPartitions = 3)
     withTempDir { dir =>
       val path = dir.getAbsolutePath
       val checkpointDir = new File(path, "checkpoint")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index e0ec3fd1b907..4ec44eac22e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -343,84 +343,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
       intsToDF(expected)(schema))
   }
 
-  test("LowPriorityMemoryStreamImplicits works with implicit sqlContext") {
-    // Test that MemoryStream can be created using implicit sqlContext
-    implicit val sqlContext: SQLContext = spark.sqlContext
-
-    // Test MemoryStream[A]() with implicit sqlContext
-    val stream1 = MemoryStream[Int]()
-    assert(stream1 != null)
-
-    // Test MemoryStream[A](numPartitions) with implicit sqlContext
-    val stream2 = MemoryStream[String](3)
-    assert(stream2 != null)
-
-    // Verify the streams work correctly
-    stream1.addData(1, 2, 3)
-    val df1 = stream1.toDF()
-    assert(df1.schema.fieldNames.contains("value"))
-
-    stream2.addData("a", "b", "c")
-    val df2 = stream2.toDF()
-    assert(df2.schema.fieldNames.contains("value"))
-  }
-
-  test("LowPriorityContinuousMemoryStreamImplicits works with implicit 
sqlContext") {
-    import 
org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
-    // Test that ContinuousMemoryStream can be created using implicit 
sqlContext
-    implicit val sqlContext: SQLContext = spark.sqlContext
-
-    // Test ContinuousMemoryStream[A]() with implicit sqlContext
-    val stream1 = ContinuousMemoryStream[Int]()
-    assert(stream1 != null)
-
-    // Test ContinuousMemoryStream[A](numPartitions) with implicit sqlContext
-    val stream2 = ContinuousMemoryStream[String](3)
-    assert(stream2 != null)
-
-    // Test ContinuousMemoryStream.singlePartition with implicit sqlContext
-    val stream3 = ContinuousMemoryStream.singlePartition[Int]()
-    assert(stream3 != null)
-
-    // Verify the streams work correctly
-    stream1.addData(Seq(1, 2, 3))
-    stream2.addData(Seq("a", "b", "c"))
-    stream3.addData(Seq(10, 20))
-
-    // Basic verification that streams are functional
-    assert(stream1.initialOffset() != null)
-    assert(stream2.initialOffset() != null)
-    assert(stream3.initialOffset() != null)
-  }
-
-  test("LowPriorityLowLatencyMemoryStreamImplicits works with implicit 
sqlContext") {
-    import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream
-    // Test that LowLatencyMemoryStream can be created using implicit 
sqlContext
-    implicit val sqlContext: SQLContext = spark.sqlContext
-
-    // Test LowLatencyMemoryStream[A]() with implicit sqlContext
-    val stream1 = LowLatencyMemoryStream[Int]()
-    assert(stream1 != null)
-
-    // Test LowLatencyMemoryStream[A](numPartitions) with implicit sqlContext
-    val stream2 = LowLatencyMemoryStream[String](3)
-    assert(stream2 != null)
-
-    // Test LowLatencyMemoryStream.singlePartition with implicit sqlContext
-    val stream3 = LowLatencyMemoryStream.singlePartition[Int]()
-    assert(stream3 != null)
-
-    // Verify the streams work correctly
-    stream1.addData(Seq(1, 2, 3))
-    stream2.addData(Seq("a", "b", "c"))
-    stream3.addData(Seq(10, 20))
-
-    // Basic verification that streams are functional
-    assert(stream1.initialOffset() != null)
-    assert(stream2.initialOffset() != null)
-    assert(stream3.initialOffset() != null)
-  }
-
   private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = {
     require(schema.fields.length === 1)
     sqlContext.createDataset(seq).toDF(schema.fieldNames.head)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
index 79bcdbca9ec6..4f2b78404131 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -123,14 +123,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
   test("query stop deactivates related store providers") {
     var coordRef: StateStoreCoordinatorRef = null
     try {
-      implicit val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate()
+      val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate()
       SparkSession.setActiveSession(spark)
       import spark.implicits._
       coordRef = spark.streams.stateStoreCoordinator
       spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
 
       // Start a query and run a batch to load state stores
-      val inputData = MemoryStream[Int]
+      val inputData = MemoryStream[Int](spark)
       val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query
       val checkpointLocation = Utils.createTempDir().getAbsoluteFile
       val query = aggregated.writeStream
@@ -253,8 +253,8 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
       ) {
         case (coordRef, spark) =>
           import spark.implicits._
-          implicit val sparkSession: SparkSession = spark
-          val inputData = MemoryStream[Int]
+
+          val inputData = MemoryStream[Int](spark)
           val query = setUpStatefulQuery(inputData, "query")
           // Add, commit, and wait multiple times to force snapshot versions and time difference
           (0 until 6).foreach { _ =>
@@ -289,10 +289,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
       ) {
         case (coordRef, spark) =>
           import spark.implicits._
-          implicit val sparkSession: SparkSession = spark
+
           // Start a join query and run some data to force snapshot uploads
-          val input1 = MemoryStream[Int]
-          val input2 = MemoryStream[Int]
+          val input1 = MemoryStream[Int](spark)
+          val input2 = MemoryStream[Int](spark)
           val df1 = input1.toDF().select($"value" as "leftKey", ($"value" * 2) as "leftValue")
           val df2 = input2.toDF().select($"value" as "rightKey", ($"value" * 3) as "rightValue")
           val joined = df1.join(df2, expr("leftKey = rightKey"))
@@ -332,10 +332,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sparkSession: SparkSession = spark
+
         // Start and run two queries together with some data to force snapshot uploads
-        val input1 = MemoryStream[Int]
-        val input2 = MemoryStream[Int]
+        val input1 = MemoryStream[Int](spark)
+        val input2 = MemoryStream[Int](spark)
         val query1 = setUpStatefulQuery(input1, "query1")
         val query2 = setUpStatefulQuery(input2, "query2")
 
@@ -399,9 +399,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sparkSession: SparkSession = spark
+
         // Start a query and run some data to force snapshot uploads
-        val inputData = MemoryStream[Int]
+        val inputData = MemoryStream[Int](spark)
         val query = setUpStatefulQuery(inputData, "query")
 
         // Go through two batches to force two snapshot uploads.
@@ -443,9 +443,9 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sparkSession: SparkSession = spark
+
         // Start a query and run some data to force snapshot uploads
-        val inputData = MemoryStream[Int]
+        val inputData = MemoryStream[Int](spark)
         val query = setUpStatefulQuery(inputData, "query")
 
         // Go through several rounds of input to force snapshot uploads
@@ -486,7 +486,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest {
         SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
       ) {
         withTempDir { srcDir =>
-          val inputData = MemoryStream[Int]
+          val inputData = MemoryStream[Int](spark)
           val query = inputData.toDF().dropDuplicates()
           val numPartitions = query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)
           // Keep track of state checkpoint directory for the second run
@@ -608,7 +608,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest {
       SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
     ) {
       withTempDir { srcDir =>
-        val inputData = MemoryStream[Int]
+        val inputData = MemoryStream[Int](spark)
         val query = inputData.toDF().dropDuplicates()
 
         testStream(query)(
@@ -686,7 +686,7 @@ class StateStoreCoordinatorStreamingSuite extends StreamTest {
         SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key -> "0"
       ) {
         withTempDir { srcDir =>
-          val inputData = MemoryStream[Int]
+          val inputData = MemoryStream[Int](spark)
           val query = inputData.toDF().dropDuplicates()
 
           // Populate state stores with an initial snapshot, so that timestamp isn't marked
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index e839ccd35ec0..232332a6575a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1213,11 +1213,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("SPARK-21145: Restarted queries create new provider instances") {
     try {
       val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-      implicit val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate()
+      val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate()
       SparkSession.setActiveSession(spark)
       spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
       import spark.implicits._
-      val inputData = MemoryStream[Int]
+      val inputData = MemoryStream[Int](spark)
 
       def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = {
         val aggregated = inputData.toDF().groupBy("value").agg(count("*"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 93efbe3b3cf5..4cd3f849a594 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -827,7 +827,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
 
     def constructUnionDf(desiredPartitionsForInput1: Int)
       : (MemoryStream[String], MemoryStream[String], DataFrame) = {
-      val input1 = MemoryStream[String](desiredPartitionsForInput1)
+      val input1 = MemoryStream[String](spark, desiredPartitionsForInput1)
       val input2 = MemoryStream[String]
       val df1 = input1.toDF()
         .select($"value", $"value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 7825730d901d..f065f1de5cdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -347,7 +347,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     " shifted partition IDs") {
     def constructUnionDf(desiredPartitionsForInput1: Int)
       : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
-      val input1 = MemoryStream[Int](desiredPartitionsForInput1)
+      val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1)
       val input2 = MemoryStream[Int]
       val df1 = input1.toDF()
         .select($"value", $"value" + 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
index 832b22d6304f..11fc9cbfc484 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
@@ -334,7 +334,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest {
     " shifted partition IDs") {
     def constructUnionDf(desiredPartitionsForInput1: Int)
       : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
-      val input1 = MemoryStream[Int](desiredPartitionsForInput1)
+      val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1)
       val input2 = MemoryStream[Int]
       val df1 = input1.toDF().select($"value")
       val df2 = input2.toDF().dropDuplicates("value")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 22028a585e22..6cdca9fb5309 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -1609,7 +1609,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite {
   test("SPARK-29438: ensure UNION doesn't lead stream-stream join to use 
shifted partition IDs") {
     def constructUnionDf(desiredPartitionsForInput1: Int)
         : (MemoryStream[Int], MemoryStream[Int], MemoryStream[Int], DataFrame) 
= {
-      val input1 = MemoryStream[Int](desiredPartitionsForInput1)
+      val input1 = MemoryStream[Int](spark, desiredPartitionsForInput1)
       val df1 = input1.toDF()
         .select(
           $"value" as "key",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index baafdc1ea50a..86041d48cde7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -2655,7 +2655,7 @@ class HiveDDLSuite
           |SELECT word, number from t1
         """.stripMargin)
 
-      val inputData = MemoryStream[Int]
+      val inputData = MemoryStream[Int](spark)
       val joined = inputData.toDS().toDF()
         .join(spark.table("smallTable"), $"value" === $"number")
 
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
index 7c8181b5b72a..f37716b4a24d 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.pipelines.graph
 
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
 import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -423,7 +423,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     val p = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem = MemoryStream[Int]
       mem.addData(1)
       registerPersistedView("a", query = dfFlowFunc(mem.toDF()))
@@ -467,7 +466,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     val graph = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       registerMaterializedView("a", query = 
dfFlowFunc(MemoryStream[Int].toDF()))
     }.resolveToDataflowGraph()
 
@@ -491,7 +489,6 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession {
 
     val graph = new TestGraphRegistrationContext(spark) {
       registerTable("a")
-      implicit val sparkSession: SparkSession = spark
       registerFlow(
         destinationName = "a",
         name = "once_flow",
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
index a4bb7c067d87..3ac3c0901750 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.pipelines.graph
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.Union
@@ -159,7 +158,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     class P extends TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val ints = MemoryStream[Int]
       ints.addData(1, 2, 3, 4)
       registerPersistedView("a", query = dfFlowFunc(ints.toDF()))
@@ -201,7 +199,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     class P extends TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val ints1 = MemoryStream[Int]
       ints1.addData(1, 2, 3, 4)
       val ints2 = MemoryStream[Int]
@@ -362,7 +359,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     class P extends TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem = MemoryStream[Int]
       registerPersistedView("a", query = dfFlowFunc(mem.toDF()))
       registerTable("b")
@@ -406,7 +402,6 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession {
     import session.implicits._
 
     val graph = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem = MemoryStream[Int]
       mem.addData(1, 2)
       registerPersistedView("complete-view", query = dfFlowFunc(Seq(1, 
2).toDF("x")))
@@ -499,7 +494,6 @@ class ConnectValidPipelineSuite extends PipelineTest with 
SharedSparkSession {
     import session.implicits._
 
     val P = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem = MemoryStream[Int]
       mem.addData(1, 2)
       registerTemporaryView("a", query = dfFlowFunc(mem.toDF().select($"value" 
as "x")))
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index ba8419eb6e9c..72cc644e5768 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.pipelines.graph
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.SparkThrowable
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions, FieldReference}
 import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
@@ -269,7 +268,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest {
 
   test("invalid schema merge") {
     val session = spark
-    implicit val sparkSession: SparkSession = spark
+    implicit val sqlCtx: SQLContext = spark.sqlContext
     import session.implicits._
 
     val streamInts = MemoryStream[Int]
@@ -353,7 +352,6 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest {
 
     val ex = intercept[TableMaterializationException] {
       materializeGraph(new TestGraphRegistrationContext(spark) {
-        implicit val sparkSession: SparkSession = spark
         val source: MemoryStream[Int] = MemoryStream[Int]
         source.addData(1, 2)
         registerTable(
@@ -646,7 +644,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest {
       s"Streaming tables should evolve schema only if not full refresh = $isFullRefresh"
     ) {
       val session = spark
-      implicit val sparkSession: SparkSession = spark
+      implicit val sqlCtx: SQLContext = spark.sqlContext
       import session.implicits._
 
       val streamInts = MemoryStream[Int]
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
index c37a6fb52f95..71301c34c14e 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SystemMetadataSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.pipelines.graph
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryWrapper}
 import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext}
@@ -39,7 +38,6 @@ class SystemMetadataSuite
 
       // create a pipeline with only a single ST
       val graph = new TestGraphRegistrationContext(spark) {
-        implicit val sparkSession: SparkSession = spark
         val mem: MemoryStream[Int] = MemoryStream[Int]
         mem.addData(1, 2, 3)
         registerView("a", query = dfFlowFunc(mem.toDF()))
@@ -107,7 +105,6 @@ class SystemMetadataSuite
     import session.implicits._
 
     val graph = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem: MemoryStream[Int] = MemoryStream[Int]
       mem.addData(1, 2, 3)
       registerView("a", query = dfFlowFunc(mem.toDF()))
@@ -172,7 +169,6 @@ class SystemMetadataSuite
     import session.implicits._
 
     val graph = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem: MemoryStream[Int] = MemoryStream[Int]
       mem.addData(1, 2, 3)
       registerView("a", query = dfFlowFunc(mem.toDF()))
@@ -234,7 +230,6 @@ class SystemMetadataSuite
 
     // create a pipeline with only a single ST
     val graph = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       val mem: MemoryStream[Int] = MemoryStream[Int]
       mem.addData(1, 2, 3)
       registerView("a", query = dfFlowFunc(mem.toDF()))
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
index 36b749cc84d9..57baf4c2d5b1 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.pipelines.graph
 
 import org.scalatest.time.{Seconds, Span}
 
-import org.apache.spark.sql.{functions, Row, SparkSession}
+import org.apache.spark.sql.{functions, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog}
@@ -183,7 +183,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession
 
     // Construct pipeline
     val pipelineDef = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       private val ints = MemoryStream[Int]
       ints.addData(1 until 10: _*)
       registerView("input", query = dfFlowFunc(ints.toDF()))
@@ -260,7 +259,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession
 
     // Construct pipeline
     val pipelineDef = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       private val ints = MemoryStream[Int]
       registerView("input", query = dfFlowFunc(ints.toDF()))
       registerTable(
@@ -311,7 +309,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession
     })
 
     val pipelineDef = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       private val memoryStream = MemoryStream[Int]
       memoryStream.addData(1, 2)
       registerView("input_view", query = dfFlowFunc(memoryStream.toDF()))
@@ -551,7 +548,6 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession
 
     // Construct pipeline
     val pipelineDef = new TestGraphRegistrationContext(spark) {
-      implicit val sparkSession: SparkSession = spark
       private val memoryStream = MemoryStream[Int]
       memoryStream.addData(1, 2)
       registerView("input_view", query = dfFlowFunc(memoryStream.toDF()))
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
index e7c095638513..9ff92ee895b1 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/TestGraphRegistrationContext.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.pipelines.utils
 
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{LocalTempView, PersistedView => PersistedViewType, UnresolvedRelation, ViewType}
 import org.apache.spark.sql.classic.{DataFrame, SparkSession}
@@ -28,7 +29,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * A test class to simplify the creation of pipelines and datasets for unit testing.
  */
 class TestGraphRegistrationContext(
-    val spark: SparkSession,
+    val _spark: SparkSession,
     val sqlConf: Map[String, String] = Map.empty)
     extends GraphRegistrationContext(
       defaultCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG,
@@ -36,6 +37,10 @@ class TestGraphRegistrationContext(
       defaultSqlConf = sqlConf
     ) {
 
+  /** Re-expose as implicit so nested anonymous classes can use it without shadowing issues */
+  implicit def spark: SparkSession = _spark
+  implicit def sqlContext: SQLContext = _spark.sqlContext
+
   // scalastyle:off
   // Disable scalastyle to ignore argument count.
   /** Registers a streaming table in this [[TestGraphRegistrationContext]] */
@@ -145,7 +150,7 @@ class TestGraphRegistrationContext(
     val qualifiedIdentifier = GraphIdentifierManager
           .parseAndQualifyTableIdentifier(
             rawTableIdentifier = GraphIdentifierManager
-              .parseTableIdentifier(name, spark),
+              .parseTableIdentifier(name, _spark),
             currentCatalog = catalog.orElse(Some(defaultCatalog)),
             currentDatabase = database.orElse(Some(defaultDatabase)))
           .identifier
@@ -304,9 +309,9 @@ class TestGraphRegistrationContext(
       catalog: Option[String] = None,
       database: Option[String] = None
   ): Unit = {
-    val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, spark)
+    val rawFlowIdentifier = GraphIdentifierManager.parseTableIdentifier(name, _spark)
     val rawDestinationIdentifier =
-      GraphIdentifierManager.parseTableIdentifier(destinationName, spark)
+      GraphIdentifierManager.parseTableIdentifier(destinationName, _spark)
 
     val flowWritesToView = getViews
         .filter(_.isInstanceOf[TemporaryView])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

