This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e7fc4003b246 [SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker
e7fc4003b246 is described below

commit e7fc4003b246bab743ab82d9e7bb77c0e2e5946e
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Sat Apr 13 10:30:23 2024 +0900

    [SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker
    
    ### What changes were proposed in this pull request?
    
    This patch adds support for registering custom dispatch handlers when serializing
    objects with the bundled cloudpickle library. This is necessary for compatibility
    when executing ForEachBatch functions in Structured Streaming.
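
    As a rough, self-contained illustration of what such a dispatch handler looks like in
    general, the following sketch uses only the standard-library copyreg/pickle machinery;
    the names are hypothetical and this is not the code added by this patch.

    ```python
    import copyreg
    import pickle


    class Connection:
        """Stand-in for an unpicklable resource such as an open gRPC channel."""

        def __reduce__(self):
            raise TypeError("Connection objects cannot be pickled")


    class Session:
        def __init__(self, session_id: str):
            self.session_id = session_id
            self.connection = Connection()  # would normally break pickling


    def _rebuild_session(session_id: str) -> Session:
        # A real handler would look up an already-active session instead of
        # constructing a fresh one; constructing keeps this sketch self-contained.
        return Session(session_id)


    def _reduce_session(session: Session):
        # Serialize only the session ID plus the callable used to rebuild it,
        # so the unpicklable connection is never touched.
        return _rebuild_session, (session.session_id,)


    # Register the dispatch handler for Session objects.
    copyreg.pickle(Session, _reduce_session)

    restored = pickle.loads(pickle.dumps(Session("abc-123")))
    assert restored.session_id == "abc-123"
    ```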
    
    A typical example of this behavior is the following test case:
    
    ```python
    def curried_function(df):
        def inner(batch_df, batch_id):
            df.createOrReplaceTempView("updates")
            batch_df.createOrReplaceTempView("batch_updates")
    
        return inner
    
    df = spark.readStream.format("text").load("python/test_support/sql/streaming")
    other_df = spark.range(100)
    df.writeStream.foreachBatch(curried_function(other_df)).start()
    ```
    Here we curry a DataFrame into the function called during ForEachBatch, effectively
    passing state into it. Until now, serializing DataFrames and SparkSessions in Spark
    Connect was not possible: the SparkSession carries the open gRPC connection, and the
    DataFrame overrides certain magic methods that make pickling fail.
    
    To make serializing SparkSessions possible, we register a custom session constructor
    that simply returns the currently active session; it is attached while the ForEachBatch
    function is serialized. By the time the ForEachBatch worker starts executing, it has
    already created and registered an active SparkSession for the constructor to return.
    To serialize and reconstruct a DataFrame, we only have to pass in the session and the
    plan; the remaining attributes do not carry persistent state.
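
    A minimal sketch of this __reduce__-based pattern, using hypothetical stand-in classes
    rather than the actual pyspark ones, looks like this:

    ```python
    import pickle


    class _Registry:
        """Process-local holder standing in for SparkSession.getActiveSession()."""

        active = None


    class FakeSession:
        def __init__(self, session_id: str):
            self.session_id = session_id
            _Registry.active = self

        def __reduce__(self):
            # Pickle only the session ID; on load, resolve it against the session
            # that the worker has already created and registered as active.
            return _get_active_session_if_matches, (self.session_id,)


    def _get_active_session_if_matches(session_id: str) -> "FakeSession":
        session = _Registry.active
        if session is None:
            raise RuntimeError("NO_ACTIVE_SESSION")
        if session.session_id != session_id:
            raise AssertionError(
                "Expected session ID does not match active session ID: "
                f"{session_id} != {session.session_id}"
            )
        return session


    # The worker creates and registers an active session first, then unpickles
    # the user function whose closure captured a session.
    active = FakeSession("abc-123")
    assert pickle.loads(pickle.dumps(active)) is active
    ```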
    
    To avoid modifying global behavior, the serialization handlers are not registered
    unconditionally but only when the ForEachBatch and ForEach handlers are called, so
    that existing behavior does not change unexpectedly.
    
    ### Why are the changes needed?
    Compatibility and Ease of Use
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added and updated tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46002 from grundprinzip/SPARK-47812.
    
    Lead-authored-by: Martin Grund <martin.gr...@databricks.com>
    Co-authored-by: Martin Grund <grundprin...@gmail.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            | 22 +++++++
 python/pyspark/sql/connect/session.py              | 37 ++++++++++++
 .../streaming/worker/foreach_batch_worker.py       | 15 ++++-
 .../connect/streaming/worker/listener_worker.py    | 15 ++++-
 .../connect/streaming/test_parity_foreach_batch.py | 70 +++++++++++++++++-----
 .../connect/streaming/test_parity_listener.py      | 23 ++-----
 .../pyspark/sql/tests/connect/test_parity_udtf.py  | 18 +++++-
 7 files changed, 163 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 1dddcc078810..f0dc412760a4 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -122,6 +122,28 @@ class DataFrame:
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
 
+    def __reduce__(self) -> Tuple:
+        """
+        Custom method for serializing the DataFrame object using Pickle. Since the DataFrame
+        overrides "__getattr__" method, the default serialization method does not work.
+
+        Returns
+        -------
+        The tuple containing the information needed to reconstruct the object.
+
+        """
+        return (
+            DataFrame,
+            (
+                self._plan,
+                self._session,
+            ),
+            {
+                "_support_repr_html": self._support_repr_html,
+                "_cached_schema": self._cached_schema,
+            },
+        )
+
     def __repr__(self) -> str:
         if not self._support_repr_html:
             (
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 07fe8a62f082..3be6c83cf13b 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -96,6 +96,7 @@ from pyspark.errors import (
     PySparkRuntimeError,
     PySparkValueError,
     PySparkTypeError,
+    PySparkAssertionError,
 )
 
 if TYPE_CHECKING:
@@ -288,6 +289,26 @@ class SparkSession:
     def getActiveSession(cls) -> Optional["SparkSession"]:
         return getattr(cls._active_session, "session", None)
 
+    @classmethod
+    def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
+        """
+        Internal use only. This method is called from the custom handler
+        generated by __reduce__. To avoid serializing a WeakRef, we create a
+        custom classmethod to instantiate the SparkSession.
+        """
+        session = SparkSession.getActiveSession()
+        if session is None:
+            raise PySparkRuntimeError(
+                error_class="NO_ACTIVE_SESSION",
+                message_parameters={},
+            )
+        if session._session_id != session_id:
+            raise PySparkAssertionError(
+                "Expected session ID does not match active session ID: "
+                f"{session_id} != {session._session_id}"
+            )
+        return session
+
     getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__
 
     @classmethod
@@ -1034,6 +1055,22 @@ class SparkSession:
 
     profile.__doc__ = PySparkSession.profile.__doc__
 
+    def __reduce__(self) -> Tuple:
+        """
+        This method is called when the object is pickled. It returns a tuple of the object's
+        constructor function, arguments to it and the local state of the object.
+        This function is supposed to only be used when the active spark session that is pickled
+        is the same active spark session that is unpickled.
+        """
+
+        def creator(old_session_id: str) -> "SparkSession":
+            # We cannot perform the checks for session matching here because accessing the
+            # session ID property causes the serialization of a WeakRef and in turn breaks
+            # the serialization.
+            return SparkSession._getActiveSessionIfMatches(old_session_id)
+
+        return creator, (self._session_id,)
+
 
 SparkSession.__doc__ = PySparkSession.__doc__
 
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index c4cf52b9996d..92ed7a4aaff5 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -29,7 +29,7 @@ from pyspark.serializers import (
     CPickleSerializer,
 )
 from pyspark import worker
-from pyspark.sql import SparkSession
+from pyspark.sql.connect.session import SparkSession
 from pyspark.util import handle_worker_exception
 from typing import IO
 from pyspark.worker_util import check_python_version
@@ -38,9 +38,16 @@ pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
+spark = None
+
+
 def main(infile: IO, outfile: IO) -> None:
+    global spark
     check_python_version(infile)
 
+    # Enable Spark Connect Mode
+    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
     session_id = utf8_deserializer.loads(infile)
 
@@ -49,8 +56,11 @@ def main(infile: IO, outfile: IO) -> None:
         f"url {connect_url} and sessionId {session_id}."
     )
 
+    # To attach to the existing SparkSession, we're setting the session_id in the URL.
+    connect_url = connect_url + ";session_id=" + session_id
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
-    spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
+    assert spark_connect_session.session_id == session_id
+    spark = spark_connect_session
 
     # TODO(SPARK-44461): Enable Process Isolation
 
@@ -62,6 +72,7 @@ def main(infile: IO, outfile: IO) -> None:
     log_name = "Streaming ForeachBatch worker"
 
     def process(df_id, batch_id):  # type: ignore[no-untyped-def]
+        global spark
         print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
         batch_df = spark_connect_session._create_remote_dataframe(df_id)
         func(batch_df, batch_id)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 69e0d8a46248..d3efb5894fc0 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -30,7 +30,7 @@ from pyspark.serializers import (
     CPickleSerializer,
 )
 from pyspark import worker
-from pyspark.sql import SparkSession
+from pyspark.sql.connect.session import SparkSession
 from pyspark.util import handle_worker_exception
 from typing import IO
 
@@ -46,9 +46,16 @@ pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
+spark = None
+
+
 def main(infile: IO, outfile: IO) -> None:
+    global spark
     check_python_version(infile)
 
+    # Enable Spark Connect Mode
+    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
     session_id = utf8_deserializer.loads(infile)
 
@@ -57,8 +64,11 @@ def main(infile: IO, outfile: IO) -> None:
         f"url {connect_url} and sessionId {session_id}."
     )
 
+    # To attach to the existing SparkSession, we're setting the session_id in the URL.
+    connect_url = connect_url + ";session_id=" + session_id
     spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
-    spark_connect_session._client._session_id = session_id  # type: ignore[attr-defined]
+    assert spark_connect_session.session_id == session_id
+    spark = spark_connect_session
 
     # TODO(SPARK-44461): Enable Process Isolation
 
@@ -71,6 +81,7 @@ def main(infile: IO, outfile: IO) -> None:
     assert listener.spark == spark_connect_session
 
     def process(listener_event_str, listener_event_type):  # type: ignore[no-untyped-def]
+        global spark
         listener_event = json.loads(listener_event_str)
         if listener_event_type == 0:
             listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event))
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index 30f7bb8c2df9..4598cbbdca4e 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -30,33 +30,73 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo
     def test_streaming_foreach_batch_graceful_stop(self):
         super().test_streaming_foreach_batch_graceful_stop()
 
+    def test_nested_dataframes(self):
+        def curried_function(df):
+            def inner(batch_df, batch_id):
+                df.createOrReplaceTempView("updates")
+                batch_df.createOrReplaceTempView("batch_updates")
+
+            return inner
+
+        try:
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            other_df = self.spark.range(100)
+            q = df.writeStream.foreachBatch(curried_function(other_df)).start()
+            q.processAllAvailable()
+            collected = self.spark.sql("select * from batch_updates").collect()
+            self.assertTrue(len(collected), 2)
+            self.assertEqual(100, self.spark.sql("select * from updates").count())
+        finally:
+            if q:
+                q.stop()
+
+    def test_pickling_error(self):
+        class NoPickle:
+            def __reduce__(self):
+                raise ValueError("No pickle")
+
+        no_pickle = NoPickle()
+
+        def func(df, _):
+            print(no_pickle)
+            df.count()
+
+        with self.assertRaises(PySparkPicklingError):
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(func).start()
+            q.processAllAvailable()
+
     def test_accessing_spark_session(self):
         spark = self.spark
 
         def func(df, _):
-            spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()
+            spark.createDataFrame([("you", "can"), ("serialize", "spark")]).createOrReplaceTempView(
+                "test_accessing_spark_session"
+            )
 
-        error_thrown = False
         try:
-            self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
-        except PySparkPicklingError as e:
-            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
-            error_thrown = True
-        self.assertTrue(error_thrown)
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(func).start()
+            q.processAllAvailable()
+            self.assertEqual(2, spark.table("test_accessing_spark_session").count())
+        finally:
+            if q:
+                q.stop()
 
     def test_accessing_spark_session_through_df(self):
-        dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])
+        dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")])
 
         def func(df, _):
-            dataframe.collect()
+            dataframe.createOrReplaceTempView("test_accessing_spark_session_through_df")
 
-        error_thrown = False
         try:
-            self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
-        except PySparkPicklingError as e:
-            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
-            error_thrown = True
-        self.assertTrue(error_thrown)
+            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+            q = df.writeStream.foreachBatch(func).start()
+            q.processAllAvailable()
+            self.assertEqual(2, self.spark.table("test_accessing_spark_session_through_df").count())
+        finally:
+            if q:
+                q.stop()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index f5ffa0154df1..a15e4547f67a 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -19,7 +19,6 @@ import unittest
 import time
 
 import pyspark.cloudpickle
-from pyspark.errors import PySparkPicklingError
 from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
 from pyspark.sql.streaming.listener import StreamingQueryListener
 from pyspark.sql.functions import count, lit
@@ -138,7 +137,9 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
 
         class TestListener(StreamingQueryListener):
             def onQueryStarted(self, event):
-                spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()
+                spark.createDataFrame(
+                    [("you", "can"), ("serialize", "spark")]
+                ).createOrReplaceTempView("test_accessing_spark_session")
 
             def onQueryProgress(self, event):
                 pass
@@ -149,16 +150,10 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             def onQueryTerminated(self, event):
                 pass
 
-        error_thrown = False
-        try:
-            self.spark.streams.addListener(TestListener())
-        except PySparkPicklingError as e:
-            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
-            error_thrown = True
-        self.assertTrue(error_thrown)
+        self.spark.streams.addListener(TestListener())
 
     def test_accessing_spark_session_through_df(self):
-        dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])
+        dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")])
 
         class TestListener(StreamingQueryListener):
             def onQueryStarted(self, event):
@@ -173,13 +168,7 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             def onQueryTerminated(self, event):
                 pass
 
-        error_thrown = False
-        try:
-            self.spark.streams.addListener(TestListener())
-        except PySparkPicklingError as e:
-            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
-            error_thrown = True
-        self.assertTrue(error_thrown)
+        self.spark.streams.addListener(TestListener())
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 02570ac9efa7..5071b69060a1 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -28,7 +28,7 @@ if should_test_connect:
 from pyspark.util import is_remote_only
 from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.errors.exceptions.connect import SparkConnectGrpcException
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException
 
 
 class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
@@ -76,6 +76,10 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
     def test_udtf_with_analyze_using_file(self):
         super().test_udtf_with_analyze_using_file()
 
+    @unittest.skip("pyspark-connect can serialize SparkSession, but fails on executor")
+    def test_udtf_access_spark_session(self):
+        super().test_udtf_access_spark_session()
+
     def _add_pyfile(self, path):
         self.spark.addArtifacts(path, pyfile=True)
 
@@ -99,6 +103,18 @@ class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
         finally:
             super(ArrowUDTFParityTests, cls).tearDownClass()
 
+    def test_udtf_access_spark_session_connect(self):
+        df = self.spark.range(10)
+
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                df.collect()
+                yield 1,
+
+        with self.assertRaisesRegex(PythonException, "NO_ACTIVE_SESSION"):
+            TestUDTF().collect()
+
 
 if __name__ == "__main__":
     import unittest


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
