This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 88696eb  [SPARK-38121][PYTHON][SQL] Use SparkSession instead of SQLContext inside PySpark
88696eb is described below

commit 88696ebcb72fd3057b1546831f653b29b7e0abb2
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Feb 15 09:42:31 2022 +0900

    [SPARK-38121][PYTHON][SQL] Use SparkSession instead of SQLContext inside PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to use `SparkSession` instead of `SQLContext` within PySpark. This is
    base work for respecting runtime configurations, etc. Currently, PySpark internally relies
    on the old, deprecated `SQLContext`, which does not respect the Spark session's runtime
    configurations correctly.
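
    Runtime (session) configurations are scoped to each `SparkSession`, which is why a shared
    `SQLContext` cannot be relied on. A minimal, illustrative sketch of the behavior at stake
    (not code from this patch):

    ```python
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    other = spark.newSession()

    # Runtime configurations are per-session, so each session must be consulted
    # directly rather than through a shared, deprecated SQLContext.
    other.conf.set("spark.sql.repl.eagerEval.enabled", "true")
    print(spark.conf.get("spark.sql.repl.eagerEval.enabled"))  # 'false' (default)
    print(other.conf.get("spark.sql.repl.eagerEval.enabled"))  # 'true'
    ```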
    
    This PR also contains related changes (and a bit of refactoring in the code this PR
    touches), as below:
    - Expose `DataFrame.sparkSession`, as the Scala API does (see the sketch below).
    - Move `SQLContext._conf` to `SparkSession._jconf`.
    - Rename `rdd_array` to `df_array` in `DataFrame.randomSplit`.
    - Issue warnings that discourage the use of `DataFrame.sql_ctx` and
      `DataFrame(..., sql_ctx)`.
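
    A minimal sketch of the exposed property (matching the doctest added in this patch):

    ```python
    df = spark.range(1)
    type(df.sparkSession)      # <class 'pyspark.sql.session.SparkSession'>
    df.sparkSession is spark   # True: the session that created this DataFrame
    ```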
    
    ### Why are the changes needed?
    
    - This is base work for PySpark to respect runtime configurations.
    - To expose the same API surface as the Scala API (`df.sparkSession`).
    - To avoid relying on the old, deprecated `SQLContext`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    - Issues warnings that discourage the use of `DataFrame.sql_ctx` and
      `DataFrame(..., sql_ctx)` (see the sketch below).
    - New API: `DataFrame.sparkSession`.
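
    A sketch of the warning a user would now see (warning text as added in this patch):

    ```python
    import warnings

    df = spark.range(1)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        df.sql_ctx  # deprecated internal accessor
    print(caught[0].message)
    # DataFrame.sql_ctx is an internal property, and will be removed in future
    # releases. Use DataFrame.sparkSession instead.
    ```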
    
    ### How was this patch tested?
    
    Existing test cases should cover them.
    
    Closes #35410 from HyukjinKwon/SPARK-38121.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/docs/source/reference/pyspark.sql.rst       |   1 +
 python/pyspark/ml/clustering.py                    |   2 +-
 python/pyspark/ml/common.py                        |   2 +-
 python/pyspark/ml/fpm.py                           |   2 +-
 python/pyspark/ml/wrapper.py                       |   2 +-
 python/pyspark/mllib/common.py                     |   2 +-
 python/pyspark/pandas/internal.py                  |   2 +-
 python/pyspark/shell.py                            |   3 +-
 python/pyspark/sql/catalog.py                      |   4 +-
 python/pyspark/sql/context.py                      |  42 +++--
 python/pyspark/sql/dataframe.py                    | 185 ++++++++++++++-------
 python/pyspark/sql/functions.py                    |   2 +-
 python/pyspark/sql/group.py                        |  14 +-
 python/pyspark/sql/observation.py                  |   2 +-
 python/pyspark/sql/pandas/conversion.py            |  33 ++--
 python/pyspark/sql/pandas/group_ops.py             |   5 +-
 python/pyspark/sql/pandas/map_ops.py               |   4 +-
 python/pyspark/sql/readwriter.py                   |  12 +-
 python/pyspark/sql/session.py                      |  53 +++---
 python/pyspark/sql/streaming.py                    |   8 +-
 python/pyspark/sql/tests/test_session.py           |  16 +-
 python/pyspark/sql/tests/test_streaming.py         |  16 +-
 python/pyspark/sql/tests/test_udf.py               |   8 +-
 python/pyspark/sql/tests/test_udf_profiler.py      |  11 +-
 python/pyspark/sql/utils.py                        |   8 +-
 .../spark/sql/api/python/PythonSQLUtils.scala      |  14 +-
 .../org/apache/spark/sql/api/r/SQLUtils.scala      |   4 +-
 .../sql/execution/arrow/ArrowConverters.scala      |  12 +-
 28 files changed, 266 insertions(+), 203 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 818814c..1d34961 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -201,6 +201,7 @@ DataFrame APIs
     DataFrame.show
     DataFrame.sort
     DataFrame.sortWithinPartitions
+    DataFrame.sparkSession
     DataFrame.stat
     DataFrame.storageLevel
     DataFrame.subtract
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index a66d6e3..c8b4c93 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -2107,7 +2107,7 @@ class PowerIterationClustering(
         assert self._java_obj is not None
 
         jdf = self._java_obj.assignClusters(dataset._jdf)
-        return DataFrame(jdf, dataset.sql_ctx)
+        return DataFrame(jdf, dataset.sparkSession)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index 2329421..32829c4 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -108,7 +108,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt
             return RDD(jrdd, sc)
 
         if clsName == "Dataset":
-            return DataFrame(r, SparkSession(sc)._wrapped)
+            return DataFrame(r, SparkSession._getActiveSessionOrCreate())
 
         if clsName in _picklable_classes:
             r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 0795ec2..b748a7d 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -510,7 +510,7 @@ class PrefixSpan(JavaParams):
         self._transfer_params_to_java()
         assert self._java_obj is not None
         jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
-        return DataFrame(jdf, dataset.sql_ctx)
+        return DataFrame(jdf, dataset.sparkSession)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7f03f64..385a243 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -393,7 +393,7 @@ class JavaTransformer(JavaParams, Transformer, metaclass=ABCMeta):
         assert self._java_obj is not None
 
         self._transfer_params_to_java()
-        return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
+        return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)
 
 
 @inherit_doc
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 24a3f41..00653aa 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -110,7 +110,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt
             return RDD(jrdd, sc)
 
         if clsName == "Dataset":
-            return DataFrame(r, SparkSession(sc)._wrapped)
+            return DataFrame(r, SparkSession._getActiveSessionOrCreate())
 
         if clsName in _picklable_classes:
             r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 1c32c43..71f8f6e 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -906,7 +906,7 @@ class InternalFrame:
         if len(sdf.columns) > 0:
             return SparkDataFrame(
                sdf._jdf.toDF().withSequenceColumn(column_name),  # type: ignore[operator]
-                sdf.sql_ctx,
+                sdf.sparkSession,
             )
         else:
             cnt = sdf.count()
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 4164e3a..e0a8c06 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -28,6 +28,7 @@ import warnings
 
 from pyspark.context import SparkContext
 from pyspark.sql import SparkSession
+from pyspark.sql.context import SQLContext
 
 if os.environ.get("SPARK_EXECUTOR_URI"):
     SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
@@ -49,7 +50,7 @@ sql = spark.sql
 atexit.register((lambda sc: lambda: sc.stop())(sc))
 
 # for compatibility
-sqlContext = spark._wrapped
+sqlContext = SQLContext._get_or_create(sc)
 sqlCtx = sqlContext
 
 print(
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index ea8bb97..3ececfa 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -345,7 +345,7 @@ class Catalog:
         if path is not None:
             options["path"] = path
         if source is None:
-            c = self._sparkSession._wrapped._conf
+            c = self._sparkSession._jconf
             source = c.defaultDataSourceName()  # type: ignore[attr-defined]
         if description is None:
             description = ""
@@ -356,7 +356,7 @@ class Catalog:
                 raise TypeError("schema should be StructType")
             scala_datatype = self._jsparkSession.parseDataType(schema.json())
             df = self._jcatalog.createTable(tableName, source, scala_datatype, description, options)
-        return DataFrame(df, self._sparkSession._wrapped)
+        return DataFrame(df, self._sparkSession)
 
     def dropTempView(self, viewName: str) -> None:
         """Drops the local temporary view with the given view name in the 
catalog.
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 6ab70ee..6f94e9a 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -46,7 +46,6 @@ from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.sql.types import AtomicType, DataType, StructType
 from pyspark.sql.streaming import StreamingQueryManager
-from pyspark.conf import SparkConf
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import (
@@ -121,7 +120,7 @@ class SQLContext:
         if sparkSession is None:
             sparkSession = SparkSession._getActiveSessionOrCreate()
         if jsqlContext is None:
-            jsqlContext = sparkSession._jwrapped
+            jsqlContext = sparkSession._jsparkSession.sqlContext()
         self.sparkSession = sparkSession
         self._jsqlContext = jsqlContext
         _monkey_patch_RDD(self.sparkSession)
@@ -141,11 +140,6 @@ class SQLContext:
         """
         return self._jsqlContext
 
-    @property
-    def _conf(self) -> SparkConf:
-        """Accessor for the JVM SQL-specific configurations"""
-        return self.sparkSession._jsparkSession.sessionState().conf()
-
     @classmethod
     def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext":
         """
@@ -164,17 +158,20 @@ class SQLContext:
             "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() 
instead.",
             FutureWarning,
         )
+        return cls._get_or_create(sc)
+
+    @classmethod
+    def _get_or_create(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext":
 
         if (
             cls._instantiatedContext is None
             or SQLContext._instantiatedContext._sc._jsc is None  # type: ignore[union-attr]
         ):
             assert sc._jvm is not None
-            jsqlContext = (
-                sc._jvm.SparkSession.builder().sparkContext(sc._jsc.sc()).getOrCreate().sqlContext()
-            )
-            sparkSession = SparkSession(sc, jsqlContext.sparkSession())
-            cls(sc, sparkSession, jsqlContext)
+            # There can be only one running Spark context. That will automatically
+            # be used in the Spark session internally.
+            session = SparkSession._getActiveSessionOrCreate()
+            cls(sc, session, session._jsparkSession.sqlContext())
         return cast(SQLContext, cls._instantiatedContext)
 
     def newSession(self) -> "SQLContext":
@@ -590,9 +587,9 @@ class SQLContext:
         Row(namespace='', tableName='table1', isTemporary=True)
         """
         if dbName is None:
-            return DataFrame(self._ssql_ctx.tables(), self)
+            return DataFrame(self._ssql_ctx.tables(), self.sparkSession)
         else:
-            return DataFrame(self._ssql_ctx.tables(dbName), self)
+            return DataFrame(self._ssql_ctx.tables(dbName), self.sparkSession)
 
     def tableNames(self, dbName: Optional[str] = None) -> List[str]:
         """Returns a list of names of tables in the database ``dbName``.
@@ -647,7 +644,7 @@ class SQLContext:
         -------
         :class:`DataFrameReader`
         """
-        return DataFrameReader(self)
+        return DataFrameReader(self.sparkSession)
 
     @property
     def readStream(self) -> DataStreamReader:
@@ -669,7 +666,7 @@ class SQLContext:
         >>> text_sdf.isStreaming
         True
         """
-        return DataStreamReader(self)
+        return DataStreamReader(self.sparkSession)
 
     @property
     def streams(self) -> StreamingQueryManager:
@@ -714,14 +711,13 @@ class HiveContext(SQLContext):
             + "SparkSession.builder.enableHiveSupport().getOrCreate() 
instead.",
             FutureWarning,
         )
+        static_conf = {}
         if jhiveContext is None:
-            sparkContext._conf.set(  # type: ignore[attr-defined]
-                "spark.sql.catalogImplementation", "hive"
-            )
-            sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()
-        else:
-            sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
-        SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
+            static_conf = {"spark.sql.catalogImplementation": "in-memory"}
+        # There can be only one running Spark context. That will automatically
+        # be used in the Spark session internally.
+        session = SparkSession._getActiveSessionOrCreate(**static_conf)
+        SQLContext.__init__(self, sparkContext, session, jhiveContext)
 
     @classmethod
     def _createForTesting(cls, sparkContext: SparkContext) -> "HiveContext":
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0372527..610b8d6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -16,6 +16,7 @@
 #
 
 import json
+import os
 import sys
 import random
 import warnings
@@ -70,6 +71,7 @@ if TYPE_CHECKING:
     from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
     from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
     from pyspark.sql.context import SQLContext
+    from pyspark.sql.session import SparkSession
     from pyspark.sql.group import GroupedData
     from pyspark.sql.observation import Observation
 
@@ -102,12 +104,34 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
           .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
 
     .. versionadded:: 1.3.0
+
+    .. note:: A DataFrame should only be created as described above. It should not be
+        directly created by using the constructor.
     """
 
-    def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
-        self._jdf = jdf
-        self.sql_ctx = sql_ctx
-        self._sc: SparkContext = cast(SparkContext, sql_ctx and sql_ctx._sc)
+    def __init__(
+        self,
+        jdf: JavaObject,
+        sql_ctx: Union["SQLContext", "SparkSession"],
+    ):
+        from pyspark.sql.context import SQLContext
+
+        self._session: Optional["SparkSession"] = None
+        self._sql_ctx: Optional["SQLContext"] = None
+
+        if isinstance(sql_ctx, SQLContext):
+            assert not os.environ.get("SPARK_TESTING")  # Sanity check for our internal usage.
+            assert isinstance(sql_ctx, SQLContext)
+            # We should remove this if-else branch in the future release, and rename
+            # sql_ctx to session in the constructor. This is an internal code path but
+            # was kept with a warning because it's used intensively by third-party libraries.
+            warnings.warn("DataFrame constructor is internal. Do not directly use it.")
+            self._sql_ctx = sql_ctx
+        else:
+            self._session = sql_ctx
+
+        self._sc: SparkContext = sql_ctx._sc
+        self._jdf: JavaObject = jdf
         self.is_cached = False
         # initialized lazily
         self._schema: Optional[StructType] = None
@@ -116,13 +140,45 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         # by __repr__ and _repr_html_ while eager evaluation opened.
         self._support_repr_html = False
 
+    @property
+    def sql_ctx(self) -> "SQLContext":
+        from pyspark.sql.context import SQLContext
+
+        warnings.warn(
+            "DataFrame.sql_ctx is an internal property, and will be removed "
+            "in future releases. Use DataFrame.sparkSession instead."
+        )
+        if self._sql_ctx is None:
+            self._sql_ctx = SQLContext._get_or_create(self._sc)
+        return self._sql_ctx
+
+    @property  # type: ignore[misc]
+    def sparkSession(self) -> "SparkSession":
+        """Returns Spark session that created this :class:`DataFrame`.
+
+        .. versionadded:: 3.3.0
+
+        Examples
+        --------
+        >>> df = spark.range(1)
+        >>> type(df.sparkSession)
+        <class 'pyspark.sql.session.SparkSession'>
+        """
+        from pyspark.sql.session import SparkSession
+
+        if self._session is None:
+            self._session = SparkSession._getActiveSessionOrCreate()
+        return self._session
+
     @property  # type: ignore[misc]
     @since(1.3)
     def rdd(self) -> "RDD[Row]":
         """Returns the content as an :class:`pyspark.RDD` of :class:`Row`."""
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(CPickleSerializer()))
+            self._lazy_rdd = RDD(
+                jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer())
+            )
         return self._lazy_rdd
 
     @property  # type: ignore[misc]
@@ -456,7 +512,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         +---+---+
 
         """
-        return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
+        return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)
 
     @since(1.3)
     def isLocal(self) -> bool:
@@ -563,12 +619,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
     def __repr__(self) -> str:
         if (
             not self._support_repr_html
-            and self.sql_ctx._conf.isReplEagerEvalEnabled()  # type: ignore[attr-defined]
+            and self.sparkSession._jconf.isReplEagerEvalEnabled()  # type: ignore[attr-defined]
         ):
             vertical = False
             return self._jdf.showString(
-                self.sql_ctx._conf.replEagerEvalMaxNumRows(),  # type: ignore[attr-defined]
-                self.sql_ctx._conf.replEagerEvalTruncate(),  # type: ignore[attr-defined]
+                self.sparkSession._jconf.replEagerEvalMaxNumRows(),  # type: ignore[attr-defined]
+                self.sparkSession._jconf.replEagerEvalTruncate(),  # type: ignore[attr-defined]
                 vertical,
             )  # type: ignore[attr-defined]
         else:
@@ -581,13 +637,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if not self._support_repr_html:
             self._support_repr_html = True
-        if self.sql_ctx._conf.isReplEagerEvalEnabled():  # type: ignore[attr-defined]
+        if self.sparkSession._jconf.isReplEagerEvalEnabled():  # type: ignore[attr-defined]
             max_num_rows = max(
-                self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0  # type: ignore[attr-defined]
+                self.sparkSession._jconf.replEagerEvalMaxNumRows(), 0  # type: ignore[attr-defined]
             )
             sock_info = self._jdf.getRowsToPython(
                 max_num_rows,
-                self.sql_ctx._conf.replEagerEvalTruncate(),  # type: ignore[attr-defined]
+                self.sparkSession._jconf.replEagerEvalTruncate(),  # type: ignore[attr-defined]
             )
             rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))
             head = rows[0]
@@ -631,7 +687,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         This API is experimental.
         """
         jdf = self._jdf.checkpoint(eager)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def localCheckpoint(self, eager: bool = True) -> "DataFrame":
         """Returns a locally checkpointed version of this :class:`DataFrame`. 
Checkpointing can be
@@ -651,7 +707,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         This API is experimental.
         """
         jdf = self._jdf.localCheckpoint(eager)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame":
         """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point
@@ -695,7 +751,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if not delayThreshold or type(delayThreshold) is not str:
             raise TypeError("delayThreshold should be provided as a string 
interval")
         jdf = self._jdf.withWatermark(eventTime, delayThreshold)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def hint(
         self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]]
@@ -740,7 +796,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 )
 
         jdf = self._jdf.hint(name, self._jseq(parameters))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def count(self) -> int:
         """Returns the number of rows in this :class:`DataFrame`.
@@ -804,7 +860,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         []
         """
         jdf = self._jdf.limit(num)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def take(self, num: int) -> List[Row]:
         """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
@@ -970,7 +1026,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         >>> df.coalesce(1).rdd.getNumPartitions()
         1
         """
-        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
+        return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)
 
     @overload
     def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame":
@@ -1041,14 +1097,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if isinstance(numPartitions, int):
             if len(cols) == 0:
-                return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
+                return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession)
             else:
                 return DataFrame(
-                    self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx
+                    self._jdf.repartition(numPartitions, self._jcols(*cols)),
+                    self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
-            return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
+            return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession)
         else:
             raise TypeError("numPartitions should be an int or Column")
 
@@ -1115,11 +1172,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("At least one partition-by expression must be 
specified.")
             else:
                 return DataFrame(
-                    self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx
+                    self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)),
+                    self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
-            return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sql_ctx)
+            return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession)
         else:
             raise TypeError("numPartitions should be an int, string or Column")
 
@@ -1133,7 +1191,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         >>> df.distinct().count()
         2
         """
-        return DataFrame(self._jdf.distinct(), self.sql_ctx)
+        return DataFrame(self._jdf.distinct(), self.sparkSession)
 
     @overload
     def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame":
@@ -1228,7 +1286,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         seed = int(seed) if seed is not None else None
         args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
         jdf = self._jdf.sample(*args)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: 
Optional[int] = None
@@ -1283,7 +1341,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             fractions[k] = float(v)
         col = col._jc
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(
+            self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession
+        )
 
     def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List["DataFrame"]:
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1311,10 +1371,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             if w < 0.0:
                 raise ValueError("Weights must be positive. Found weight 
value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        rdd_array = self._jdf.randomSplit(
-            _to_list(self.sql_ctx._sc, cast(List["ColumnOrName"], weights)), int(seed)
+        df_array = self._jdf.randomSplit(
+            _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed)
         )
-        return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+        return [DataFrame(df, self.sparkSession) for df in df_array]
 
     @property
     def dtypes(self) -> List[Tuple[str, str]]:
@@ -1392,7 +1452,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         [Row(name='Bob', name='Bob', age=5), Row(name='Alice', name='Alice', age=2)]
         """
         assert isinstance(alias, str), "alias should be a string"
-        return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
+        return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession)
 
     def crossJoin(self, other: "DataFrame") -> "DataFrame":
         """Returns the cartesian product with another :class:`DataFrame`.
@@ -1416,7 +1476,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
 
         jdf = self._jdf.crossJoin(other._jdf)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def join(
         self,
@@ -1486,7 +1546,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 on = self._jseq([])
             assert isinstance(how, str), "how should be a string"
             jdf = self._jdf.join(other._jdf, on, how)
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     # TODO(SPARK-22947): Fix the DataFrame API.
     def _joinAsOf(
@@ -1607,7 +1667,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             allowExactMatches,
             direction,
         )
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def sortWithinPartitions(
         self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
@@ -1639,7 +1699,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         +---+-----+
         """
         jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def sort(
         self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
@@ -1677,7 +1737,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         [Row(age=5, name='Bob'), Row(age=2, name='Alice')]
         """
         jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     orderBy = sort
 
@@ -1687,11 +1747,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None,
     ) -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or names"""
-        return _to_seq(self.sql_ctx._sc, cols, converter)
+        return _to_seq(self.sparkSession._sc, cols, converter)
 
     def _jmap(self, jm: Dict) -> JavaObject:
         """Return a JVM Scala Map from a dict"""
-        return _to_scala_map(self.sql_ctx._sc, jm)
+        return _to_scala_map(self.sparkSession._sc, jm)
 
     def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or column names
@@ -1767,7 +1827,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]  # type: ignore[assignment]
         jdf = self._jdf.describe(self._jseq(cols))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def summary(self, *statistics: str) -> "DataFrame":
         """Computes specified statistics for numeric and string columns. 
Available statistics are:
@@ -1832,7 +1892,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if len(statistics) == 1 and isinstance(statistics[0], list):
             statistics = statistics[0]
         jdf = self._jdf.summary(self._jseq(statistics))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     @overload
     def head(self) -> Optional[Row]:
@@ -1970,7 +2030,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         [Row(name='Alice', age=12), Row(name='Bob', age=15)]
         """
         jdf = self._jdf.select(self._jcols(*cols))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     @overload
     def selectExpr(self, *expr: str) -> "DataFrame":
@@ -1995,7 +2055,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if len(expr) == 1 and isinstance(expr[0], list):
             expr = expr[0]  # type: ignore[assignment]
         jdf = self._jdf.selectExpr(self._jseq(expr))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def filter(self, condition: "ColumnOrName") -> "DataFrame":
         """Filters rows using the given condition.
@@ -2028,7 +2088,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             jdf = self._jdf.filter(condition._jc)
         else:
             raise TypeError("condition should be string or Column")
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     @overload
     def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
@@ -2203,7 +2263,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         Also as standard in SQL, this function resolves columns by position (not by name).
         """
-        return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
+        return DataFrame(self._jdf.union(other._jdf), self.sparkSession)
 
     @since(1.3)
     def unionAll(self, other: "DataFrame") -> "DataFrame":
@@ -2260,7 +2320,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
           Added optional argument `allowMissingColumns` to specify whether to allow
            missing columns.
         """
-        return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sql_ctx)
+        return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession)
 
     @since(1.3)
     def intersect(self, other: "DataFrame") -> "DataFrame":
@@ -2269,7 +2329,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         This is equivalent to `INTERSECT` in SQL.
         """
-        return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+        return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession)
 
     def intersectAll(self, other: "DataFrame") -> "DataFrame":
         """Return a new :class:`DataFrame` containing rows in both this 
:class:`DataFrame`
@@ -2295,7 +2355,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         +---+---+
 
         """
-        return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
+        return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession)
 
     @since(1.3)
     def subtract(self, other: "DataFrame") -> "DataFrame":
@@ -2305,7 +2365,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         This is equivalent to `EXCEPT DISTINCT` in SQL.
 
         """
-        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession)
 
     def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
         """Return a new :class:`DataFrame` with duplicate rows removed,
@@ -2350,7 +2410,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             jdf = self._jdf.dropDuplicates()
         else:
             jdf = self._jdf.dropDuplicates(self._jseq(subset))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def dropna(
         self,
@@ -2398,7 +2458,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if thresh is None:
             thresh = len(subset) if how == "any" else 1
 
-        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
+        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)
 
     @overload
     def fillna(
@@ -2476,16 +2536,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             value = float(value)
 
         if isinstance(value, dict):
-            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+            return DataFrame(self._jdf.na().fill(value), self.sparkSession)
         elif subset is None:
-            return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
+            return DataFrame(self._jdf.na().fill(value), self.sparkSession)
         else:
             if isinstance(subset, str):
                 subset = [subset]
             elif not isinstance(subset, (list, tuple)):
                 raise TypeError("subset should be a list or tuple of column 
names")
 
-            return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
+            return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession)
 
     @overload
     def replace(
@@ -2686,10 +2746,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise ValueError("Mixed type replacements are not supported")
 
         if subset is None:
-            return DataFrame(self._jdf.na().replace("*", rep_dict), 
self.sql_ctx)
+            return DataFrame(self._jdf.na().replace("*", rep_dict), 
self.sparkSession)
         else:
             return DataFrame(
-                self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx
+                self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
+                self.sparkSession,
             )
 
     @overload
@@ -2875,7 +2936,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("col1 should be a string.")
         if not isinstance(col2, str):
             raise TypeError("col2 should be a string.")
-        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
+        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession)
 
     def freqItems(
         self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
@@ -2909,7 +2970,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("cols must be a list or tuple of column names as 
strings.")
         if not support:
             support = 0.01
-        return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx)
+        return DataFrame(
+            self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sparkSession
+        )
 
     def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame":
         """
@@ -2978,7 +3041,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if not isinstance(col, Column):
             raise TypeError("col should be Column")
-        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
+        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
 
     def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
         """Returns a new :class:`DataFrame` by renaming an existing column.
@@ -2998,7 +3061,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         >>> df.withColumnRenamed('age', 'age2').collect()
         [Row(age2=2, name='Alice'), Row(age2=5, name='Bob')]
         """
-        return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
+        return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sparkSession)
 
     def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column with metadata.
@@ -3023,7 +3086,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         sc = SparkContext._active_spark_context
         assert sc is not None and sc._jvm is not None
         jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(json.dumps(metadata))
-        return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sql_ctx)
+        return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)
 
     @overload
     def drop(self, cols: "ColumnOrName") -> "DataFrame":
@@ -3075,7 +3138,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                     raise TypeError("each col in the param list should be a 
string")
             jdf = self._jdf.drop(self._jseq(cols))
 
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def toDF(self, *cols: "ColumnOrName") -> "DataFrame":
         """Returns a new :class:`DataFrame` that with new specified column 
names
@@ -3091,7 +3154,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         [Row(f1=2, f2='Alice'), Row(f1=5, f2='Bob')]
         """
         jdf = self._jdf.toDF(self._jseq(cols))
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
         """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2dfaec8..d79da47 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1105,7 +1105,7 @@ def broadcast(df: DataFrame) -> DataFrame:
 
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
-    return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx)
+    return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sparkSession)
 
 
 def coalesce(*cols: "ColumnOrName") -> Column:
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 485e017..802d34d0 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -22,7 +22,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Unio
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 from pyspark.sql.column import Column, _to_seq
-from pyspark.sql.context import SQLContext
+from pyspark.sql.session import SparkSession
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType
@@ -37,7 +37,7 @@ def dfapi(f: Callable) -> Callable:
     def _api(self: "GroupedData") -> DataFrame:
         name = f.__name__
         jdf = getattr(self._jgd, name)()
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.session)
 
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -47,8 +47,8 @@ def dfapi(f: Callable) -> Callable:
 def df_varargs_api(f: Callable) -> Callable:
     def _api(self: "GroupedData", *cols: str) -> DataFrame:
         name = f.__name__
-        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
-        return DataFrame(jdf, self.sql_ctx)
+        jdf = getattr(self._jgd, name)(_to_seq(self.session._sc, cols))
+        return DataFrame(jdf, self.session)
 
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -66,7 +66,7 @@ class GroupedData(PandasGroupedOpsMixin):
     def __init__(self, jgd: JavaObject, df: DataFrame):
         self._jgd = jgd
         self._df = df
-        self.sql_ctx: SQLContext = df.sql_ctx
+        self.session: SparkSession = df.sparkSession
 
     @overload
     def agg(self, *exprs: Column) -> DataFrame:
@@ -134,8 +134,8 @@ class GroupedData(PandasGroupedOpsMixin):
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
             exprs = cast(Tuple[Column, ...], exprs)
-            jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
-        return DataFrame(jdf, self.sql_ctx)
+            jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.session._sc, [c._jc for c in exprs[1:]]))
+        return DataFrame(jdf, self.session)
 
     @dfapi
     def count(self) -> DataFrame:
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index e5d426a..951b0f4 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -109,7 +109,7 @@ class Observation:
         observed_df = self._jo.on(
             df._jdf, exprs[0]._jc, column._to_seq(df._sc, [c._jc for c in exprs[1:]])
         )
-        return DataFrame(observed_df, df.sql_ctx)
+        return DataFrame(observed_df, df.sparkSession)
 
     @property
     def get(self) -> Dict[str, Any]:
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 33a4058..fbb5183 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -43,6 +43,7 @@ from pyspark.traceback_utils import SCCallSiteSync
 if TYPE_CHECKING:
     import numpy as np
     import pyarrow as pa
+    from py4j.java_gateway import JavaObject
 
     from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
     from pyspark.sql import DataFrame
@@ -88,9 +89,10 @@ class PandasConversionMixin:
         import pandas as pd
         from pandas.core.dtypes.common import is_timedelta64_dtype
 
-        timezone = self.sql_ctx._conf.sessionLocalTimeZone()  # type: ignore[attr-defined]
+        jconf = self.sparkSession._jconf
+        timezone = jconf.sessionLocalTimeZone()
 
-        if self.sql_ctx._conf.arrowPySparkEnabled():  # type: ignore[attr-defined]
+        if jconf.arrowPySparkEnabled():  # type: ignore[attr-defined]
             use_arrow = True
             try:
                 from pyspark.sql.pandas.types import to_arrow_schema
@@ -100,7 +102,7 @@ class PandasConversionMixin:
                 to_arrow_schema(self.schema)
             except Exception as e:
 
-                if self.sql_ctx._conf.arrowPySparkFallbackEnabled():  # type: ignore[attr-defined]
+                if jconf.arrowPySparkFallbackEnabled():  # type: ignore[attr-defined]
                     msg = (
                         "toPandas attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to 
true; however, "
@@ -134,7 +136,7 @@ class PandasConversionMixin:
 
                     # Rename columns to avoid duplicated column names.
                     tmp_column_names = ["col_{}".format(i) for i in range(len(self.columns))]
-                    c = self.sql_ctx._conf
+                    c = self.sparkSession._jconf
                     self_destruct = (
                         c.arrowPySparkSelfDestructEnabled()  # type: ignore[attr-defined]
                     )
@@ -368,6 +370,8 @@ class SparkConversionMixin:
     can use this class.
     """
 
+    _jsparkSession: "JavaObject"
+
     @overload
     def createDataFrame(
         self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
@@ -398,20 +402,17 @@ class SparkConversionMixin:
 
         require_minimum_pandas_version()
 
-        timezone = self._wrapped._conf.sessionLocalTimeZone()  # type: ignore[attr-defined]
+        timezone = self._jconf.sessionLocalTimeZone()  # type: ignore[attr-defined]
 
         # If no schema supplied by user then get the names of columns only
         if schema is None:
             schema = [str(x) if not isinstance(x, str) else x for x in data.columns]
 
-        if (
-            self._wrapped._conf.arrowPySparkEnabled()  # type: ignore[attr-defined]
-            and len(data) > 0
-        ):
+        if self._jconf.arrowPySparkEnabled() and len(data) > 0:  # type: ignore[attr-defined]
             try:
                 return self._create_from_pandas_with_arrow(data, schema, timezone)
             except Exception as e:
-                if self._wrapped._conf.arrowPySparkFallbackEnabled():  # type: ignore[attr-defined]
+                if self._jconf.arrowPySparkFallbackEnabled():  # type: ignore[attr-defined]
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to 
true; however, "
@@ -603,25 +604,25 @@ class SparkConversionMixin:
             for pdf_slice in pdf_slices
         ]
 
-        jsqlContext = self._wrapped._jsqlContext  # type: ignore[attr-defined]
+        jsparkSession = self._jsparkSession
 
-        safecheck = self._wrapped._conf.arrowSafeTypeConversion()  # type: ignore[attr-defined]
+        safecheck = self._jconf.arrowSafeTypeConversion()  # type: ignore[attr-defined]
         col_by_name = True  # col by name only applies to StructType columns, can't happen here
         ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
 
         @no_type_check
         def reader_func(temp_filename):
-            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
+            return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsparkSession, temp_filename)
 
         @no_type_check
         def create_RDD_server():
-            return self._jvm.ArrowRDDServer(jsqlContext)
+            return self._jvm.ArrowRDDServer(jsparkSession)
 
         # Create Spark DataFrame from Arrow stream file, using one batch per partition
         jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
         assert self._jvm is not None
-        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
-        df = DataFrame(jdf, self._wrapped)
+        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsparkSession)
+        df = DataFrame(jdf, self)
         df._schema = schema
         return df
 
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 35f531f..e7599b1 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -214,7 +214,7 @@ class PandasGroupedOpsMixin:
         df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())  # type: ignore[attr-defined]
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.session)
 
     def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         """
@@ -246,7 +246,6 @@ class PandasCogroupedOps:
     def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
         self._gd1 = gd1
         self._gd2 = gd2
-        self.sql_ctx = gd1.sql_ctx
 
     def applyInPandas(
         self, func: "PandasCogroupedMapFunction", schema: Union[StructType, 
str]
@@ -345,7 +344,7 @@ class PandasCogroupedOps:
         jdf = self._gd1._jgd.flatMapCoGroupsInPandas(  # type: ignore[attr-defined]
             self._gd2._jgd, udf_column._jc.expr()  # type: ignore[attr-defined]
         )
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self._gd1.session)
 
     @staticmethod
     def _extract_cols(gd: "GroupedData") -> List[Column]:
diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py
index c1c29ec..c1bf6aa 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -90,7 +90,7 @@ class PandasMapOpsMixin:
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
         jdf = self._jdf.mapInPandas(udf_column._jc.expr())  # type: ignore[operator]
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
     def mapInArrow(
         self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
@@ -153,7 +153,7 @@ class PandasMapOpsMixin:
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
         jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
-        return DataFrame(jdf, self.sql_ctx)
+        return DataFrame(jdf, self.sparkSession)
 
 
 def _test() -> None:
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index df4a089..8c729c6 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,7 +27,7 @@ from pyspark.sql.utils import to_str
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType, ColumnOrName
-    from pyspark.sql.context import SQLContext
+    from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.streaming import StreamingQuery
 
@@ -62,8 +62,8 @@ class DataFrameReader(OptionUtils):
     .. versionadded:: 1.4
     """
 
-    def __init__(self, spark: "SQLContext"):
-        self._jreader = spark._ssql_ctx.read()  # type: ignore[attr-defined]
+    def __init__(self, spark: "SparkSession"):
+        self._jreader = spark._jsparkSession.read()  # type: ignore[attr-defined]
         self._spark = spark
 
     def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -560,7 +560,7 @@ class DataFrameReader(OptionUtils):
             # There aren't any jvm api for creating a dataframe from rdd storing csv.
             # We can do it through creating a jvm dataset firstly and using the jvm api
             # for creating a dataframe from dataset storing csv.
-            jdataset = self._spark._ssql_ctx.createDataset(
+            jdataset = self._spark._jsparkSession.createDataset(
                 jrdd.rdd(), self._spark._jvm.Encoders.STRING()
             )
             return self._df(self._jreader.csv(jdataset))
@@ -737,7 +737,7 @@ class DataFrameWriter(OptionUtils):
 
     def __init__(self, df: "DataFrame"):
         self._df = df
-        self._spark = df.sql_ctx
+        self._spark = df.sparkSession
         self._jwrite = df._jdf.write()  # type: ignore[operator]
 
     def _sq(self, jsq: JavaObject) -> "StreamingQuery":
@@ -1360,7 +1360,7 @@ class DataFrameWriterV2:
 
     def __init__(self, df: "DataFrame", table: str):
         self._df = df
-        self._spark = df.sql_ctx
+        self._spark = df.sparkSession
         self._jwriter = df._jdf.writeTo(table)  # type: ignore[operator]
 
     @since(3.1)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 233d529..a41ad15 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -230,11 +230,6 @@ class SparkSession(SparkConversionMixin):
             """
             return self.config("spark.sql.catalogImplementation", "hive")
 
-        def _sparkContext(self, sc: SparkContext) -> "SparkSession.Builder":
-            with self._lock:
-                self._sc = sc
-                return self
-
         def getOrCreate(self) -> "SparkSession":
             """Gets an existing :class:`SparkSession` or, if there is no 
existing one, creates a
             new one based on the options set in this builder.
@@ -267,14 +262,11 @@ class SparkSession(SparkConversionMixin):
 
                 session = SparkSession._instantiatedSession
                 if session is None or session._sc._jsc is None:  # type: ignore[attr-defined]
-                    if self._sc is not None:
-                        sc = self._sc
-                    else:
-                        sparkConf = SparkConf()
-                        for key, value in self._options.items():
-                            sparkConf.set(key, value)
-                        # This SparkContext may be an existing one.
-                        sc = SparkContext.getOrCreate(sparkConf)
+                    sparkConf = SparkConf()
+                    for key, value in self._options.items():
+                        sparkConf.set(key, value)
+                    # This SparkContext may be an existing one.
+                    sc = SparkContext.getOrCreate(sparkConf)
                     # Do not update `SparkConf` for existing `SparkContext`, as it's shared
                     # by all sessions.
                     session = SparkSession(sc, options=self._options)
@@ -296,8 +288,6 @@ class SparkSession(SparkConversionMixin):
         jsparkSession: Optional[JavaObject] = None,
         options: Dict[str, Any] = {},
     ):
-        from pyspark.sql.context import SQLContext
-
         self._sc = sparkContext
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
@@ -320,8 +310,6 @@ class SparkSession(SparkConversionMixin):
                 jsparkSession, options
             )
         self._jsparkSession = jsparkSession
-        self._jwrapped = self._jsparkSession.sqlContext()
-        self._wrapped = SQLContext(self._sc, self, self._jwrapped)
         _monkey_patch_RDD(self)
         install_exception_handler()
         # If we had an instantiated SparkSession attached with a SparkContext
@@ -348,6 +336,11 @@ class SparkSession(SparkConversionMixin):
             sc_HTML=self.sparkContext._repr_html_(),  # type: ignore[attr-defined]
         )
 
+    @property
+    def _jconf(self) -> "JavaObject":
+        """Accessor for the JVM SQL-specific configurations"""
+        return self._jsparkSession.sessionState().conf()
+
     @since(2.0)
     def newSession(self) -> "SparkSession":
         """
@@ -498,7 +491,7 @@ class SparkSession(SparkConversionMixin):
         else:
             jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))
 
-        return DataFrame(jdf, self._wrapped)
+        return DataFrame(jdf, self)
 
     def _inferSchemaFromList(
         self, data: Iterable[Any], names: Optional[List[str]] = None
@@ -519,7 +512,7 @@ class SparkSession(SparkConversionMixin):
         """
         if not data:
             raise ValueError("can not infer schema from empty dataset")
-        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()  # type: ignore[attr-defined]
+        infer_dict_as_struct = self._jconf.inferDictAsStruct()  # type: ignore[attr-defined]
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         schema = reduce(
             _merge_type,
@@ -554,7 +547,7 @@ class SparkSession(SparkConversionMixin):
         if not first:
             raise ValueError("The first row in RDD is empty, " "can not infer 
schema")
 
-        infer_dict_as_struct = self._wrapped._conf.inferDictAsStruct()  # type: ignore[attr-defined]
+        infer_dict_as_struct = self._jconf.inferDictAsStruct()  # type: ignore[attr-defined]
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         if samplingRatio is None:
             schema = _infer_schema(
@@ -684,14 +677,20 @@ class SparkSession(SparkConversionMixin):
         return SparkSession._getActiveSessionOrCreate()
 
     @staticmethod
-    def _getActiveSessionOrCreate() -> "SparkSession":
+    def _getActiveSessionOrCreate(**static_conf: Any) -> "SparkSession":
         """
         Returns the active :class:`SparkSession` for the current thread, returned by the builder,
         or if there is no existing one, creates a new one based on the options set in the builder.
+
+        NOTE that 'static_conf' might not be applied if there is an active or default
+        Spark session already running.
         """
         spark = SparkSession.getActiveSession()
         if spark is None:
-            spark = SparkSession.builder.getOrCreate()
+            builder = SparkSession.builder
+            for k, v in static_conf.items():
+                builder = builder.config(k, v)
+            spark = builder.getOrCreate()
         return spark
 
     @overload
@@ -940,7 +939,7 @@ class SparkSession(SparkConversionMixin):
             rdd._to_java_object_rdd()  # type: ignore[attr-defined]
         )
         jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json())
-        df = DataFrame(jdf, self._wrapped)
+        df = DataFrame(jdf, self)
         df._schema = struct
         return df
 
@@ -1034,7 +1033,7 @@ class SparkSession(SparkConversionMixin):
         if len(kwargs) > 0:
             sqlQuery = formatter.format(sqlQuery, **kwargs)
         try:
-            return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+            return DataFrame(self._jsparkSession.sql(sqlQuery), self)
         finally:
             if len(kwargs) > 0:
                 formatter.clear()
@@ -1055,7 +1054,7 @@ class SparkSession(SparkConversionMixin):
         >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
-        return DataFrame(self._jsparkSession.table(tableName), self._wrapped)
+        return DataFrame(self._jsparkSession.table(tableName), self)
 
     @property
     def read(self) -> DataFrameReader:
@@ -1069,7 +1068,7 @@ class SparkSession(SparkConversionMixin):
         -------
         :class:`DataFrameReader`
         """
-        return DataFrameReader(self._wrapped)
+        return DataFrameReader(self)
 
     @property
     def readStream(self) -> DataStreamReader:
@@ -1087,7 +1086,7 @@ class SparkSession(SparkConversionMixin):
         -------
         :class:`DataStreamReader`
         """
-        return DataStreamReader(self._wrapped)
+        return DataStreamReader(self)
 
     @property
     def streams(self) -> "StreamingQueryManager":
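
A minimal sketch of what the session.py changes above mean on the user side,
assuming only the public PySpark API (the master URL below is illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.range(3)
    # DataFrame.sparkSession is the newly exposed accessor; it should point
    # back at the session that created the DataFrame.
    assert df.sparkSession is spark
    # Readers are now constructed from the session itself rather than from
    # the deprecated SQLContext wrapper, so they see the session's runtime
    # configurations.
    reader = spark.read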
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index de68ccc..7cff8d0 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -29,7 +29,7 @@ from pyspark.sql.types import Row, StructType, StructField, StringType
 from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
 
 if TYPE_CHECKING:
-    from pyspark.sql import SQLContext
+    from pyspark.sql.session import SparkSession
     from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType
     from pyspark.sql.dataframe import DataFrame
 
@@ -316,8 +316,8 @@ class DataStreamReader(OptionUtils):
     This API is evolving.
     """
 
-    def __init__(self, spark: "SQLContext") -> None:
-        self._jreader = spark._ssql_ctx.readStream()
+    def __init__(self, spark: "SparkSession") -> None:
+        self._jreader = spark._jsparkSession.readStream()
         self._spark = spark
 
     def _df(self, jdf: JavaObject) -> "DataFrame":
@@ -856,7 +856,7 @@ class DataStreamWriter:
 
     def __init__(self, df: "DataFrame") -> None:
         self._df = df
-        self._spark = df.sql_ctx
+        self._spark = df.sparkSession
         self._jwrite = df._jdf.writeStream()
 
     def _sq(self, jsq: JavaObject) -> StreamingQuery:
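
On the streaming side, DataStreamReader and DataStreamWriter now carry the
SparkSession directly; code that sticks to the public API is unaffected. A
hedged sketch (the built-in "rate" source and "console" sink are used purely
for illustration):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    # readStream hands the SparkSession straight to DataStreamReader.
    stream_df = spark.readStream.format("rate").load()
    query = stream_df.writeStream.format("console").start()
    query.stop()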
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 1262e52..91aa923 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -224,28 +224,26 @@ class SparkSessionTests5(unittest.TestCase):
     def test_sqlcontext_with_stopped_sparksession(self):
         # SPARK-30856: test that SQLContext.getOrCreate() returns a usable instance after
         # the SparkSession is restarted.
-        sql_context = self.spark._wrapped
+        sql_context = SQLContext.getOrCreate(self.spark.sparkContext)
         self.spark.stop()
-        sc = SparkContext("local[4]", self.sc.appName)
-        spark = SparkSession(sc)  # Instantiate the underlying SQLContext
-        new_sql_context = spark._wrapped
+        spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate()
+        new_sql_context = SQLContext.getOrCreate(spark.sparkContext)
 
         self.assertIsNot(new_sql_context, sql_context)
-        self.assertIs(SQLContext.getOrCreate(sc).sparkSession, spark)
+        self.assertIs(SQLContext.getOrCreate(spark.sparkContext).sparkSession, spark)
         try:
             df = spark.createDataFrame([(1, 2)], ["c", "c"])
             df.collect()
         finally:
             spark.stop()
             self.assertIsNone(SQLContext._instantiatedContext)
-            sc.stop()
 
     def test_sqlcontext_with_stopped_sparkcontext(self):
         # SPARK-30856: test initialization via SparkSession when only the SparkContext is stopped
         self.sc.stop()
-        self.sc = SparkContext("local[4]", self.sc.appName)
-        self.spark = SparkSession(self.sc)
-        self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, self.spark)
+        spark = SparkSession.builder.master("local[4]").appName(self.sc.appName).getOrCreate()
+        self.sc = spark.sparkContext
+        self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, spark)
 
     def test_get_sqlcontext_with_stopped_sparkcontext(self):
         # SPARK-30856: test initialization via SQLContext.getOrCreate() when only the SparkContext
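
For code that previously reached for the private attribute, the pattern the
updated tests follow is the supported replacement: build the deprecated
SQLContext explicitly instead of reading spark._wrapped. A minimal sketch:

    from pyspark.sql import SparkSession, SQLContext

    spark = SparkSession.builder.getOrCreate()
    sql_context = SQLContext.getOrCreate(spark.sparkContext)
    # The deprecated wrapper should still point back at the live session.
    assert sql_context.sparkSession is spark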
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index 87e3564..4920423 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -86,7 +86,7 @@ class StreamingTests(ReusedSQLTestCase):
             .load("python/test_support/sql/streaming")
             .withColumn("id", lit(1))
         )
-        for q in self.spark._wrapped.streams.active:
+        for q in self.spark.streams.active:
             q.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -117,7 +117,7 @@ class StreamingTests(ReusedSQLTestCase):
 
     def test_stream_save_options_overwrite(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        for q in self.spark._wrapped.streams.active:
+        for q in self.spark.streams.active:
             q.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -154,7 +154,7 @@ class StreamingTests(ReusedSQLTestCase):
 
     def test_stream_status_and_progress(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        for q in self.spark._wrapped.streams.active:
+        for q in self.spark.streams.active:
             q.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -198,7 +198,7 @@ class StreamingTests(ReusedSQLTestCase):
 
     def test_stream_await_termination(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        for q in self.spark._wrapped.streams.active:
+        for q in self.spark.streams.active:
             q.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -267,7 +267,7 @@ class StreamingTests(ReusedSQLTestCase):
 
     def test_query_manager_await_termination(self):
         df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
-        for q in self.spark._wrapped.streams.active:
+        for q in self.spark.streams.active:
             q.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -280,13 +280,13 @@ class StreamingTests(ReusedSQLTestCase):
         try:
             self.assertTrue(q.isActive)
             try:
-                self.spark._wrapped.streams.awaitAnyTermination("hello")
+                self.spark.streams.awaitAnyTermination("hello")
                 self.fail("Expected a value exception")
             except ValueError:
                 pass
             now = time.time()
             # test should take at least 2 seconds
-            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
+            res = self.spark.streams.awaitAnyTermination(2.6)
             duration = time.time() - now
             self.assertTrue(duration >= 2)
             self.assertFalse(res)
@@ -347,7 +347,7 @@ class StreamingTests(ReusedSQLTestCase):
                 self.stop_all()
 
         def stop_all(self):
-            for q in self.spark._wrapped.streams.active:
+            for q in self.spark.streams.active:
                 q.stop()
 
         def _reset(self):
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index a092d67..0e9d766 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -22,7 +22,7 @@ import tempfile
 import unittest
 import datetime
 
-from pyspark import SparkContext
+from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
 from pyspark.sql.functions import udf, assert_true, lit
 from pyspark.sql.udf import UserDefinedFunction
@@ -79,7 +79,7 @@ class UDFTests(ReusedSQLTestCase):
         self.assertEqual(row[0], 5)
 
         # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
-        sqlContext = self.spark._wrapped
+        sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
         [row] = sqlContext.sql("SELECT oneArg('test')").collect()
         self.assertEqual(row[0], 4)
@@ -372,7 +372,7 @@ class UDFTests(ReusedSQLTestCase):
         )
 
         # This is to check if a 'SQLContext.udf' can call its alias.
-        sqlContext = self.spark._wrapped
+        sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
+        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
 
         self.assertListEqual(
@@ -419,7 +419,7 @@ class UDFTests(ReusedSQLTestCase):
         )
 
         # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
-        sqlContext = spark._wrapped
+        sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         self.assertRaisesRegex(
             AnalysisException,
             "Can not load class non_existed_udf",
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index 27d9458..136f423 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -21,7 +21,7 @@ import os
 import sys
 from io import StringIO
 
-from pyspark import SparkConf, SparkContext
+from pyspark import SparkConf
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import udf
 from pyspark.profiler import UDFBasicProfiler
@@ -32,8 +32,13 @@ class UDFProfilerTests(unittest.TestCase):
         self._old_sys_path = list(sys.path)
         class_name = self.__class__.__name__
         conf = SparkConf().set("spark.python.profile", "true")
-        self.sc = SparkContext("local[4]", class_name, conf=conf)
-        self.spark = SparkSession.builder._sparkContext(self.sc).getOrCreate()
+        self.spark = (
+            SparkSession.builder.master("local[4]")
+            .config(conf=conf)
+            .appName(class_name)
+            .getOrCreate()
+        )
+        self.sc = self.spark.sparkContext
 
     def tearDown(self):
         self.spark.stop()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 15645d0..b5abe68 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -30,7 +30,7 @@ from pyspark import SparkContext
 from pyspark.find_spark_home import _find_spark_home
 
 if TYPE_CHECKING:
-    from pyspark.sql.context import SQLContext
+    from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
 
 
@@ -258,15 +258,15 @@ class ForeachBatchFunction:
     the query is active.
     """
 
-    def __init__(self, sql_ctx: "SQLContext", func: Callable[["DataFrame", int], None]):
-        self.sql_ctx = sql_ctx
+    def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
         self.func = func
+        self.session = session
 
     def call(self, jdf: JavaObject, batch_id: int) -> None:
         from pyspark.sql.dataframe import DataFrame
 
         try:
-            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
+            self.func(DataFrame(jdf, self.session), batch_id)
         except Exception as e:
             self.error = e
             raise e
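
ForeachBatchFunction is the JVM callback shim behind
DataStreamWriter.foreachBatch, so each batch DataFrame is now built as
DataFrame(jdf, session). A hedged sketch of the call path it serves (the
"rate" source and "noop" sink are illustrative built-ins; Spark creates a
temporary checkpoint automatically here):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    def handle_batch(batch_df, batch_id):
        # batch_df arrives via ForeachBatchFunction.call above.
        batch_df.write.format("noop").mode("append").save()

    query = (
        spark.readStream.format("rate").load()
        .writeStream.foreachBatch(handle_batch)
        .start()
    )
    query.stop()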
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 490ab9f..ab43aa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.PythonRDDServer
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, SQLContext}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions.{CastTimestampNTZToLong, ExpressionInfo}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -59,8 +59,8 @@ private[sql] object PythonSQLUtils extends Logging {
    * Python callable function to read a file in Arrow stream format and create a [[RDD]]
    * using each serialized ArrowRecordBatch as a partition.
    */
-  def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(sqlContext, filename)
+  def readArrowStreamFromFile(session: SparkSession, filename: String): JavaRDD[Array[Byte]] = {
+    ArrowConverters.readArrowStreamFromFile(session, filename)
   }
 
   /**
@@ -70,8 +70,8 @@ private[sql] object PythonSQLUtils extends Logging {
   def toDataFrame(
       arrowBatchRDD: JavaRDD[Array[Byte]],
       schemaString: String,
-      sqlContext: SQLContext): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext)
+      session: SparkSession): DataFrame = {
+    ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, session)
   }
 
   def explainString(queryExecution: QueryExecution, mode: String): String = {
@@ -85,13 +85,13 @@ private[sql] object PythonSQLUtils extends Logging {
  * Helper for making a dataframe from arrow data from data sent from python over a socket.  This is
  * used when encryption is enabled, and we don't want to write data to a file.
  */
-private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer {
+private[sql] class ArrowRDDServer(session: SparkSession) extends PythonRDDServer {
 
   override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
     // Create array to consume iterator so that we can safely close the inputStream
     val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
     // Parallelize the record batches to create an RDD
-    JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+    JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
   }
 
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index befaea2..7831dde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -230,7 +230,7 @@ private[sql] object SQLUtils extends Logging {
   def readArrowStreamFromFile(
       sparkSession: SparkSession,
       filename: String): JavaRDD[Array[Byte]] = {
-    ArrowConverters.readArrowStreamFromFile(sparkSession.sqlContext, filename)
+    ArrowConverters.readArrowStreamFromFile(sparkSession, filename)
   }
 
   /**
@@ -241,6 +241,6 @@ private[sql] object SQLUtils extends Logging {
       arrowBatchRDD: JavaRDD[Array[Byte]],
       schema: StructType,
       sparkSession: SparkSession): DataFrame = {
-    ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession.sqlContext)
+    ArrowConverters.toDataFrame(arrowBatchRDD, schema.json, sparkSession)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 8e22c42..93ff276 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -31,7 +31,7 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, Message
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
@@ -195,27 +195,27 @@ private[sql] object ArrowConverters {
   private[sql] def toDataFrame(
       arrowBatchRDD: JavaRDD[Array[Byte]],
       schemaString: String,
-      sqlContext: SQLContext): DataFrame = {
+      session: SparkSession): DataFrame = {
     val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
-    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+    val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
     val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
       val context = TaskContext.get()
       ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
     }
-    sqlContext.internalCreateDataFrame(rdd.setName("arrow"), schema)
+    session.internalCreateDataFrame(rdd.setName("arrow"), schema)
   }
 
   /**
    * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
    */
   private[sql] def readArrowStreamFromFile(
-      sqlContext: SQLContext,
+      session: SparkSession,
       filename: String): JavaRDD[Array[Byte]] = {
     Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
       // Create array to consume iterator so that we can safely close the file
       val batches = getBatchesFromStream(fileStream.getChannel).toArray
       // Parallelize the record batches to create an RDD
-      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+      JavaRDD.fromRDD(session.sparkContext.parallelize(batches, batches.length))
     }
   }
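
On the Python side, the Arrow code paths these Scala changes serve include the
pandas interchange; a minimal sketch, assuming pandas and PyArrow are
installed:

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    pdf = pd.DataFrame({"id": [1, 2, 3]})
    # With Arrow enabled, createDataFrame/toPandas route through
    # PythonSQLUtils and ArrowConverters, which now take the SparkSession.
    df = spark.createDataFrame(pdf)
    round_tripped = df.toPandas()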
 
