This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 21e7fe81a66 [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API
21e7fe81a66 is described below

commit 21e7fe81a662eacebcbc1971a3586a6d470376f3
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Aug 8 11:03:05 2023 +0900

    [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API
    
    This PR proposes to (mostly) refactor all the internal workarounds used to get the active session correctly.
    
    There are a few things to note:
    
    - _PySpark without Spark Connect does not currently support a hierarchy of active sessions_. With pinned thread mode (enabled by default), PySpark maps each Python thread to a JVM thread, but the thread creation happens within the gateway server, which does not respect the thread hierarchy. Therefore, this PR follows exactly the same behaviour.
      - A new thread will not have an active session by default (see the sketch after this list).
      - Other behaviours are the same as in PySpark without Spark Connect; see also https://github.com/apache/spark/pull/42367.
    - While I am here, I piggyback a few documentation changes. We missed documenting `SparkSession.readStream`, `SparkSession.streams`, `SparkSession.udtf`, `SparkSession.conf` and `SparkSession.version` in Spark Connect.
    - The changes here are mostly refactoring that reuses existing unit tests, while exposing two methods:
      - `SparkSession.getActiveSession` (only for Spark Connect)
      - `SparkSession.active` (for PySpark both with and without Spark Connect)
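
    Below is a minimal sketch of the thread behaviour described above. It is illustrative only:
    the connection string is a placeholder, and it assumes a Spark Connect session created
    through the builder as in this PR.

        import threading

        from pyspark.sql.connect.session import SparkSession

        spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

        # In the creating thread, the session is both the default and the active session.
        assert SparkSession.getActiveSession() is spark

        def worker() -> None:
            # A new thread does not inherit the thread-local active session ...
            assert SparkSession.getActiveSession() is None
            # ... but SparkSession.active() falls back to the global default session.
            assert SparkSession.active() is spark

        t = threading.Thread(target=worker)
        t.start()
        t.join()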
    
    For Spark Connect users to be able to work with active and default sessions in Python.
    
    Yes, it adds new APIs (a short usage sketch follows below):
      - `SparkSession.getActiveSession` (only for Spark Connect)
      - `SparkSession.active` (for PySpark both with and without Spark Connect)
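
    A short usage sketch of the new APIs and the new `NO_ACTIVE_OR_DEFAULT_SESSION` error class
    (illustrative only; the remote address is a placeholder and a Spark Connect session is assumed):

        from pyspark.errors import PySparkRuntimeError
        from pyspark.sql.connect.session import SparkSession

        spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
        assert SparkSession.getActiveSession() is spark
        assert SparkSession.active() is spark

        spark.stop()
        # stop() clears both the default and the thread-local active session,
        # so SparkSession.active() now raises NO_ACTIVE_OR_DEFAULT_SESSION.
        try:
            SparkSession.active()
        except PySparkRuntimeError as e:
            assert e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION"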
    
    Existing unit tests should cover all the changes.
    
    Closes #42371 from HyukjinKwon/SPARK-44694.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 9368a0f0c1001fb6fd64799a2e744874b6cd27e4)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../source/reference/pyspark.sql/spark_session.rst |   1 +
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/ml/connect/io_utils.py              |   8 +-
 python/pyspark/ml/connect/tuning.py                |  11 ++-
 python/pyspark/ml/torch/distributor.py             |   3 +-
 python/pyspark/ml/util.py                          |  13 ---
 python/pyspark/pandas/utils.py                     |   7 +-
 python/pyspark/sql/connect/session.py              | 107 ++++++++++++++-------
 python/pyspark/sql/connect/udf.py                  |  25 +++--
 python/pyspark/sql/connect/udtf.py                 |  27 +++---
 python/pyspark/sql/session.py                      |  65 +++++++++++--
 .../sql/tests/connect/test_connect_basic.py        |   4 +-
 python/pyspark/sql/utils.py                        |  18 ++++
 13 files changed, 197 insertions(+), 97 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index a5f8bc47d44..74315a0aacc 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -28,6 +28,7 @@ See also :class:`SparkSession`.
 .. autosummary::
     :toctree: api/
 
+    SparkSession.active
     SparkSession.builder.appName
     SparkSession.builder.config
     SparkSession.builder.enableHiveSupport
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 971dc59bbb2..937a8758404 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -607,6 +607,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be a WindowSpec, got <arg_type>."
     ]
   },
+  "NO_ACTIVE_OR_DEFAULT_SESSION" : {
+    "message" : [
+      "No active or default Spark session found. Please create a new Spark 
session before running the code."
+    ]
+  },
   "NO_ACTIVE_SESSION" : {
     "message" : [
       "No active Spark session found. Please create a new Spark session before 
running the code."
diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py
index 9a963086aaf..a09a244862c 100644
--- a/python/pyspark/ml/connect/io_utils.py
+++ b/python/pyspark/ml/connect/io_utils.py
@@ -23,7 +23,7 @@ import time
 from urllib.parse import urlparse
 from typing import Any, Dict, List
 from pyspark.ml.base import Params
-from pyspark.ml.util import _get_active_session
+from pyspark.sql import SparkSession
 from pyspark.sql.utils import is_remote
 
 
@@ -34,7 +34,7 @@ _META_DATA_FILE_NAME = "metadata.json"
 
 
 def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None:
-    session = _get_active_session(is_remote())
+    session = SparkSession.active()
     if is_remote():
         session.copyFromLocalToFs(local_path, dest_path)
     else:
@@ -228,7 +228,7 @@ class ParamsReadWrite(Params):
 
         .. versionadded:: 3.5.0
         """
-        session = _get_active_session(is_remote())
+        session = SparkSession.active()
         path_exist = True
         try:
             session.read.format("binaryFile").load(path).head()
@@ -256,7 +256,7 @@ class ParamsReadWrite(Params):
 
         .. versionadded:: 3.5.0
         """
-        session = _get_active_session(is_remote())
+        session = SparkSession.active()
 
         tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
         try:
diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py
index 6d539933e1d..c22c31e84e8 100644
--- a/python/pyspark/ml/connect/tuning.py
+++ b/python/pyspark/ml/connect/tuning.py
@@ -178,11 +178,12 @@ def _parallelFitTasks(
 
     def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]:
         def single_task() -> Tuple[int, float]:
-            # Active session is thread-local variable, in background thread the active session
-            # is not set, the following line sets it as the main thread active session.
-            active_session._jvm.SparkSession.setActiveSession(  # type: ignore[union-attr]
-                active_session._jsparkSession  # type: ignore[union-attr]
-            )
+            if not is_remote():
+                # Active session is thread-local variable, in background thread the active session
+                # is not set, the following line sets it as the main thread active session.
+                active_session._jvm.SparkSession.setActiveSession(  # type: ignore[union-attr]
+                    active_session._jsparkSession  # type: ignore[union-attr]
+                )
 
             model = estimator.fit(train, param_map)
             metric = evaluator.evaluate(
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 71bcde3b48e..b407672ac48 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -48,7 +48,6 @@ from pyspark.ml.torch.log_communication import (  # type: ignore
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.ml.util import _get_active_session
 
 
 def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]:
@@ -164,7 +163,7 @@ class Distributor:
         from pyspark.sql.utils import is_remote
 
         self.is_remote = is_remote()
-        self.spark = _get_active_session(self.is_remote)
+        self.spark = SparkSession.active()
 
         # indicate whether the server side is local mode
         self.is_spark_local_master = False
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 2c90ff3cb7b..64676947017 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -747,16 +747,3 @@ def try_remote_functions(f: FuncT) -> FuncT:
             return f(*args, **kwargs)
 
     return cast(FuncT, wrapped)
-
-
-def _get_active_session(is_remote: bool) -> SparkSession:
-    if not is_remote:
-        spark = SparkSession.getActiveSession()
-    else:
-        import pyspark.sql.connect.session
-
-        spark = pyspark.sql.connect.session._active_spark_session  # type: ignore[assignment]
-
-    if spark is None:
-        raise RuntimeError("An active SparkSession is required for the distributor.")
-    return spark
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index c66b3359e77..55b9a57ef61 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -478,12 +478,7 @@ def is_testing() -> bool:
 
 
 def default_session() -> SparkSession:
-    if not is_remote():
-        spark = SparkSession.getActiveSession()
-    else:
-        from pyspark.sql.connect.session import _active_spark_session
-
-        spark = _active_spark_session  # type: ignore[assignment]
+    spark = SparkSession.getActiveSession()
     if spark is None:
         spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate()
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 9bba0db05e4..d75a30c561f 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -18,6 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
+import threading
 import os
 import warnings
 from collections.abc import Sized
@@ -36,6 +37,7 @@ from typing import (
     overload,
     Iterable,
     TYPE_CHECKING,
+    ClassVar,
 )
 
 import numpy as np
@@ -93,14 +95,13 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.udtf import UDTFRegistration
 
 
-# `_active_spark_session` stores the active spark connect session created by
-# `SparkSession.builder.getOrCreate`. It is used by ML code.
-#  If sessions are created with `SparkSession.builder.create`, it stores
-#  The last created session
-_active_spark_session = None
-
-
 class SparkSession:
+    # The active SparkSession for the current thread
+    _active_session: ClassVar[threading.local] = threading.local()
+    # Reference to the root SparkSession
+    _default_session: ClassVar[Optional["SparkSession"]] = None
+    _lock: ClassVar[RLock] = RLock()
+
     class Builder:
         """Builder for :class:`SparkSession`."""
 
@@ -176,8 +177,6 @@ class SparkSession:
             )
 
         def create(self) -> "SparkSession":
-            global _active_spark_session
-
             has_channel_builder = self._channel_builder is not None
             has_spark_remote = "spark.remote" in self._options
 
@@ -200,23 +199,26 @@ class SparkSession:
                 assert spark_remote is not None
                 session = SparkSession(connection=spark_remote)
 
-            _active_spark_session = session
+            SparkSession._set_default_and_active_session(session)
             return session
 
         def getOrCreate(self) -> "SparkSession":
-            global _active_spark_session
-            if _active_spark_session is not None:
-                return _active_spark_session
-            _active_spark_session = self.create()
-            return _active_spark_session
+            with SparkSession._lock:
+                session = SparkSession.getActiveSession()
+                if session is None:
+                    session = SparkSession._default_session
+                    if session is None:
+                        session = self.create()
+                return session
 
     _client: SparkConnectClient
 
     @classproperty
     def builder(cls) -> Builder:
-        """Creates a :class:`Builder` for constructing a 
:class:`SparkSession`."""
         return cls.Builder()
 
+    builder.__doc__ = PySparkSession.builder.__doc__
+
     def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] = None):
         """
         Creates a new SparkSession for the Spark Connect interface.
@@ -236,6 +238,38 @@ class SparkSession:
         self._client = SparkConnectClient(connection=connection, user_id=userId)
         self._session_id = self._client._session_id
 
+    @classmethod
+    def _set_default_and_active_session(cls, session: "SparkSession") -> None:
+        """
+        Set the (global) default :class:`SparkSession`, and (thread-local)
+        active :class:`SparkSession` when they are not set yet.
+        """
+        with cls._lock:
+            if cls._default_session is None:
+                cls._default_session = session
+        if getattr(cls._active_session, "session", None) is None:
+            cls._active_session.session = session
+
+    @classmethod
+    def getActiveSession(cls) -> Optional["SparkSession"]:
+        return getattr(cls._active_session, "session", None)
+
+    getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__
+
+    @classmethod
+    def active(cls) -> "SparkSession":
+        session = cls.getActiveSession()
+        if session is None:
+            session = cls._default_session
+            if session is None:
+                raise PySparkRuntimeError(
+                    error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
+                    message_parameters={},
+                )
+        return session
+
+    active.__doc__ = PySparkSession.active.__doc__
+
     def table(self, tableName: str) -> DataFrame:
         return self.read.table(tableName)
 
@@ -251,6 +285,8 @@ class SparkSession:
     def readStream(self) -> "DataStreamReader":
         return DataStreamReader(self)
 
+    readStream.__doc__ = PySparkSession.readStream.__doc__
+
     def _inferSchemaFromList(
         self, data: Iterable[Any], names: Optional[List[str]] = None
     ) -> StructType:
@@ -601,19 +637,20 @@ class SparkSession:
         # specifically in Spark Connect the Spark Connect server is designed for
         # multi-tenancy - the remote client side cannot just stop the server and stop
         # other remote clients being used from other users.
-        global _active_spark_session
-        self.client.close()
-        _active_spark_session = None
-
-        if "SPARK_LOCAL_REMOTE" in os.environ:
-            # When local mode is in use, follow the regular Spark session's
-            # behavior by terminating the Spark Connect server,
-            # meaning that you can stop local mode, and restart the Spark Connect
-            # client with a different remote address.
-            active_session = PySparkSession.getActiveSession()
-            if active_session is not None:
-                active_session.stop()
-            with SparkContext._lock:
+        with SparkSession._lock:
+            self.client.close()
+            if self is SparkSession._default_session:
+                SparkSession._default_session = None
+            if self is getattr(SparkSession._active_session, "session", None):
+                SparkSession._active_session.session = None
+
+            if "SPARK_LOCAL_REMOTE" in os.environ:
+                # When local mode is in use, follow the regular Spark session's
+                # behavior by terminating the Spark Connect server,
+                # meaning that you can stop local mode, and restart the Spark Connect
+                # client with a different remote address.
+                if PySparkSession._activeSession is not None:
+                    PySparkSession._activeSession.stop()
                 del os.environ["SPARK_LOCAL_REMOTE"]
                 del os.environ["SPARK_CONNECT_MODE_ENABLED"]
                 if "SPARK_REMOTE" in os.environ:
@@ -628,20 +665,18 @@ class SparkSession:
         """
         return self.client.is_closed
 
-    @classmethod
-    def getActiveSession(cls) -> Any:
-        raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"getActiveSession()"}
-        )
-
     @property
     def conf(self) -> RuntimeConf:
         return RuntimeConf(self.client)
 
+    conf.__doc__ = PySparkSession.conf.__doc__
+
     @property
     def streams(self) -> "StreamingQueryManager":
         return StreamingQueryManager(self)
 
+    streams.__doc__ = PySparkSession.streams.__doc__
+
     def __getattr__(self, name: str) -> Any:
         if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession"]:
             raise PySparkAttributeError(
@@ -675,6 +710,8 @@ class SparkSession:
         assert result is not None
         return result
 
+    version.__doc__ = PySparkSession.version.__doc__
+
     @property
     def client(self) -> "SparkConnectClient":
         return self._client
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 2d7e423d3d5..eb0541b9369 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -37,8 +37,7 @@ from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
 from pyspark.sql.types import DataType, StringType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
-from pyspark.errors import PySparkTypeError
-
+from pyspark.errors import PySparkTypeError, PySparkRuntimeError
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
@@ -58,14 +57,20 @@ def _create_py_udf(
     from pyspark.sql.udf import _create_arrow_py_udf
 
     if useArrow is None:
-        from pyspark.sql.connect.session import _active_spark_session
-
-        is_arrow_enabled = (
-            False
-            if _active_spark_session is None
-            else _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")
-            == "true"
-        )
+        is_arrow_enabled = False
+        try:
+            from pyspark.sql.connect.session import SparkSession
+
+            session = SparkSession.active()
+            is_arrow_enabled = (
+                str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower()
+                == "true"
+            )
+        except PySparkRuntimeError as e:
+            if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+                pass  # Just uses the default if no session found.
+            else:
+                raise e
     else:
         is_arrow_enabled = useArrow
 
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index 07e2bad6ec7..850ffe2b9b4 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -68,13 +68,20 @@ def _create_py_udtf(
     if useArrow is not None:
         arrow_enabled = useArrow
     else:
-        from pyspark.sql.connect.session import _active_spark_session
+        from pyspark.sql.connect.session import SparkSession
 
         arrow_enabled = False
-        if _active_spark_session is not None:
-            value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
-            if isinstance(value, str) and value.lower() == "true":
-                arrow_enabled = True
+        try:
+            session = SparkSession.active()
+            arrow_enabled = (
+                str(session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")).lower()
+                == "true"
+            )
+        except PySparkRuntimeError as e:
+            if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+                pass  # Just uses the default if no session found.
+            else:
+                raise e
 
     # Create a regular Python UDTF and check for invalid handler class.
     regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
@@ -156,17 +163,13 @@ class UserDefinedTableFunction:
         )
 
     def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+        from pyspark.sql.connect.session import SparkSession
         from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.connect.session import _active_spark_session
 
-        if _active_spark_session is None:
-            raise PySparkRuntimeError(
-                "An active SparkSession is required for "
-                "executing a Python user-defined table function."
-            )
+        session = SparkSession.active()
 
         plan = self._build_common_inline_user_defined_table_function(*cols)
-        return DataFrame.withPlan(plan, _active_spark_session)
+        return DataFrame.withPlan(plan, session)
 
     def asNondeterministic(self) -> "UserDefinedTableFunction":
         self.deterministic = False
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index ede6318782e..9141051fdf8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -64,8 +64,8 @@ from pyspark.sql.types import (
     _from_numpy_type,
 )
 from pyspark.errors.exceptions.captured import install_exception_handler
-from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str
-from pyspark.errors import PySparkValueError, PySparkTypeError
+from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod
+from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType
@@ -500,7 +500,7 @@ class SparkSession(SparkConversionMixin):
                     ).applyModifiableSettings(session._jsparkSession, self._options)
                 return session
 
-        # SparkConnect-specific API
+        # Spark Connect-specific API
         def create(self) -> "SparkSession":
             """Creates a new SparkSession. Can only be used in the context of 
Spark Connect
             and will throw an exception otherwise.
@@ -510,6 +510,10 @@ class SparkSession(SparkConversionMixin):
             Returns
             -------
             :class:`SparkSession`
+
+            Notes
+            -----
+            This method will update the default and/or active session if they are not set.
             """
             opts = dict(self._options)
             if "SPARK_REMOTE" in os.environ or "spark.remote" in opts:
@@ -546,7 +550,11 @@ class SparkSession(SparkConversionMixin):
     # to Python 3.9.6 (https://github.com/python/cpython/pull/28838)
     @classproperty
     def builder(cls) -> Builder:
-        """Creates a :class:`Builder` for constructing a 
:class:`SparkSession`."""
+        """Creates a :class:`Builder` for constructing a :class:`SparkSession`.
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+        """
         return cls.Builder()
 
     _instantiatedSession: ClassVar[Optional["SparkSession"]] = None
@@ -632,12 +640,16 @@ class SparkSession(SparkConversionMixin):
         return self.__class__(self._sc, self._jsparkSession.newSession())
 
     @classmethod
+    @try_remote_session_classmethod
     def getActiveSession(cls) -> Optional["SparkSession"]:
         """
         Returns the active :class:`SparkSession` for the current thread, returned by the builder
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Returns
         -------
         :class:`SparkSession`
@@ -667,6 +679,30 @@ class SparkSession(SparkConversionMixin):
             else:
                 return None
 
+    @classmethod
+    @try_remote_session_classmethod
+    def active(cls) -> "SparkSession":
+        """
+        Returns the active or default :class:`SparkSession` for the current thread, returned by
+        the builder.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        :class:`SparkSession`
+            Spark session if an active or default session exists for the current thread.
+        """
+        session = cls.getActiveSession()
+        if session is None:
+            session = cls._instantiatedSession
+            if session is None:
+                raise PySparkRuntimeError(
+                    error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
+                    message_parameters={},
+                )
+        return session
+
     @property
     def sparkContext(self) -> SparkContext:
         """
@@ -698,6 +734,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
         Returns
         -------
         str
@@ -719,6 +758,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
         Returns
         -------
         :class:`pyspark.sql.conf.RuntimeConfig`
@@ -726,7 +768,7 @@ class SparkSession(SparkConversionMixin):
         Examples
         --------
         >>> spark.conf
-        <pyspark.sql.conf.RuntimeConfig object ...>
+        <pyspark...RuntimeConf...>
 
         Set a runtime configuration for the session
 
@@ -805,6 +847,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Returns
         -------
         :class:`UDTFRegistration`
@@ -1639,6 +1684,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Notes
         -----
         This API is evolving.
@@ -1650,7 +1698,7 @@ class SparkSession(SparkConversionMixin):
         Examples
         --------
         >>> spark.readStream
-        <pyspark.sql.streaming.readwriter.DataStreamReader object ...>
+        <pyspark...DataStreamReader object ...>
 
         The example below uses Rate source that generates rows continuously.
         After that, we operate a modulo by 3, and then write the stream out to the console.
@@ -1672,6 +1720,9 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Notes
         -----
         This API is evolving.
@@ -1683,7 +1734,7 @@ class SparkSession(SparkConversionMixin):
         Examples
         --------
         >>> spark.streams
-        <pyspark.sql.streaming.query.StreamingQueryManager object ...>
+        <pyspark...StreamingQueryManager object ...>
 
         Get the list of active streaming queries
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 065f1585a9f..0687fc9f313 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3043,9 +3043,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_unsupported_session_functions(self):
         # SPARK-41934: Disable unsupported functions.
 
-        with self.assertRaises(NotImplementedError):
-            RemoteSparkSession.getActiveSession()
-
         with self.assertRaises(NotImplementedError):
             RemoteSparkSession.builder.enableHiveSupport()
 
@@ -3331,6 +3328,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         spark.stop()
 
     def test_can_create_multiple_sessions_to_different_remotes(self):
+        self.spark.stop()
         self.assertIsNotNone(self.spark._client)
         # Creates a new remote session.
         other = PySparkSession.builder.remote("sc://other.remote:114/").create()
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 608ed7e9ac9..b72d8d9a7c8 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import inspect
 import functools
 import os
 from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type
@@ -239,6 +240,23 @@ def try_remote_observation(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def try_remote_session_classmethod(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.session import SparkSession  # type: ignore[misc]
+
+            assert inspect.isclass(args[0])
+            return getattr(SparkSession, f.__name__)(*args[1:], **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def pyspark_column_op(
     func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
 ) -> Union["SeriesOrIndex", None]:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to