This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 93e0acbf7d9 [SPARK-43971][CONNECT][PYTHON] Support Python's createDataFrame in streaming manner
93e0acbf7d9 is described below

commit 93e0acbf7d9fcf3422860b2a5d39379bebf7bc43
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Sat Jun 10 01:25:04 2023 +0300

    [SPARK-43971][CONNECT][PYTHON] Support Python's createDataFrame in streaming manner
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to transfer a local relation from **the Python connect client** to the server in a streaming way when it exceeds the size defined by the SQL config `spark.sql.session.localRelationCacheThreshold`. The implementation is similar to https://github.com/apache/spark/pull/40827. In particular:
    1. The client applies the `sha256` function over **the proto form** of the local relation;
    2. It checks the presence of the relation on the server side by sending the relation hash to the server;
    3. If the server doesn't have the local relation, the client transfers it as an artifact named `cache/<sha256>`;
    4. Once the relation is already present on the server, or has just been transferred, the client transforms the logical plan by replacing the `LocalRelation` node with a `CachedLocalRelation` carrying the hash;
    5. On the server side, the `CachedLocalRelation` is converted back to a `LocalRelation` by retrieving the relation body from the local cache (a minimal sketch of this flow follows the list).
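    Below is a minimal, hedged sketch of the client-side flow above; `FakeClient`, `has_artifact`, and `add_artifact` are hypothetical stand-ins for the real Spark Connect artifact manager, and the threshold value is only an assumed default:
    ```python
    # Sketch only, NOT the actual implementation: the client/artifact API here is hypothetical.
    import hashlib

    CACHE_THRESHOLD = 64 * 1024 * 1024  # assumed value of spark.sql.session.localRelationCacheThreshold


    class FakeClient:
        """Hypothetical stand-in for the server-side artifact store."""

        def __init__(self):
            self._artifacts = {}

        def has_artifact(self, name):
            return name in self._artifacts

        def add_artifact(self, name, blob):
            self._artifacts[name] = blob


    def plan_local_relation(serialized_relation: bytes, client: FakeClient) -> str:
        """Describe the plan node the client would send for the given proto bytes."""
        if len(serialized_relation) < CACHE_THRESHOLD:
            # Small relation: keep it inline in the plan.
            return "LocalRelation(<inline bytes>)"
        # 1. Hash the proto form of the relation.
        relation_hash = hashlib.sha256(serialized_relation).hexdigest()
        # 2./3. Check presence on the server and upload `cache/<sha256>` if missing.
        name = f"cache/{relation_hash}"
        if not client.has_artifact(name):
            client.add_artifact(name, serialized_relation)
        # 4./5. Reference the cached blob instead of inlining the data; the server
        #       later resolves it back to a LocalRelation from its local cache.
        return f"CachedLocalRelation(hash={relation_hash})"
    ```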
    
    ### Why are the changes needed?
    To fix failures when creating a large DataFrame from a local collection:
    ```python
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
            status = StatusCode.RESOURCE_EXHAUSTED
            details = "Sent message larger than max (134218508 vs. 134217728)"
            debug_error_string = "UNKNOWN:Error received from peer localhost:50982 {grpc_message:"Sent message larger than max (134218508 vs. 134217728)", grpc_status:8, created_time:"2023-06-09T15:34:08.362797+03:00"}
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running the new test:
    ```
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_streaming_local_relation'
    ```
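    As a manual check (not part of the patch), one could lower the threshold and create a DataFrame bigger than it; the connection URL below and the session variable `spark` are assumptions:
    ```python
    # Hypothetical manual check against a locally running Spark Connect server.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    spark.conf.set("spark.sql.session.localRelationCacheThreshold", 1024 * 1024)

    rows = [(i, "x" * 1024) for i in range(2048)]  # roughly 2 MB of local data
    df = spark.createDataFrame(rows, ["id", "payload"])
    assert df.count() == len(rows)  # the data round-trips through the cached relation
    ```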
    
    Closes #41537 from MaxGekk/streaming-createDataFrame-python-4.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 python/pyspark/sql/connect/client/core.py          |  3 ++
 python/pyspark/sql/connect/plan.py                 | 34 ++++++++++++++++++++++
 python/pyspark/sql/connect/session.py              | 26 +++++++++++++++--
 .../sql/tests/connect/test_connect_basic.py        | 19 ++++++++++++
 4 files changed, 79 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 25e395356d5..7368521259a 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1257,6 +1257,9 @@ class SparkConnectClient(object):
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
         self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
 
+    def cache_artifact(self, blob: bytes) -> str:
+        return self._artifact_manager.cache_artifact(blob)
+
 
 class RetryState:
     """
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index fc8b37b102c..406f65080d1 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -363,6 +363,10 @@ class LocalRelation(LogicalPlan):
             plan.local_relation.schema = self._schema
         return plan
 
+    def serialize(self, session: "SparkConnectClient") -> bytes:
+        p = self.plan(session)
+        return bytes(p.local_relation.SerializeToString())
+
     def print(self, indent: int = 0) -> str:
         return f"{' ' * indent}<LocalRelation>\n"
 
@@ -374,6 +378,36 @@ class LocalRelation(LogicalPlan):
         """
 
 
+class CachedLocalRelation(LogicalPlan):
+    """Creates a CachedLocalRelation plan object based on a hash of a 
LocalRelation."""
+
+    def __init__(self, hash: str) -> None:
+        super().__init__(None)
+
+        self._hash = hash
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        clr = plan.cached_local_relation
+
+        if session._user_id:
+            clr.userId = session._user_id
+        clr.sessionId = session._session_id
+        clr.hash = self._hash
+
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        return f"{' ' * indent}<CachedLocalRelation>\n"
+
+    def _repr_html_(self) -> str:
+        return """
+        <ul>
+            <li><b>CachedLocalRelation</b></li>
+        </ul>
+        """
+
+
 class ShowString(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 5877181963c..2b35ca3d7ea 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -51,7 +51,14 @@ from pyspark import SparkContext, SparkConf, __version__
 from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
 from pyspark.sql.connect.conf import RuntimeConf
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
+from pyspark.sql.connect.plan import (
+    SQL,
+    Range,
+    LocalRelation,
+    LogicalPlan,
+    CachedLocalRelation,
+    CachedRelation,
+)
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
@@ -466,10 +473,16 @@ class SparkSession:
             )
 
         if _schema is not None:
-            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
+            local_relation = LocalRelation(_table, schema=_schema.json())
         else:
-            df = DataFrame.withPlan(LocalRelation(_table), self)
+            local_relation = LocalRelation(_table)
+
+        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        plan: LogicalPlan = local_relation
+        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+            plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
+        df = DataFrame.withPlan(plan, self)
         if _cols is not None and len(_cols) > 0:
             df = df.toDF(*_cols)
         return df
@@ -643,6 +656,13 @@ class SparkSession:
 
     addArtifact = addArtifacts
 
+    def _cache_local_relation(self, local_relation: LocalRelation) -> str:
+        """
+        Cache the local relation at the server side if it has not been cached yet.
+        """
+        serialized = local_relation.serialize(self._client)
+        return self._client.cache_artifact(serialized)
+
     def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
         """
         Copy file from local to cloud storage file system.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a5139603919..18a7d8f19b4 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,7 +19,9 @@ import array
 import datetime
 import os
 import unittest
+import random
 import shutil
+import string
 import tempfile
 from collections import defaultdict
 
@@ -649,6 +651,23 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assertEqual(sdf.schema, cdf.schema)
             self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
+    def test_streaming_local_relation(self):
+        threshold_conf = "spark.sql.session.localRelationCacheThreshold"
+        old_threshold = self.connect.conf.get(threshold_conf)
+        threshold = 1024 * 1024
+        self.connect.conf.set(threshold_conf, threshold)
+        try:
+            suffix = "abcdef"
+            letters = string.ascii_lowercase
+            str = "".join(random.choice(letters) for i in range(threshold)) + 
suffix
+            data = [[0, str], [1, str]]
+            for i in range(0, 2):
+                cdf = self.connect.createDataFrame(data, ["a", "b"])
+                self.assert_eq(cdf.count(), len(data))
+                self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
+        finally:
+            self.connect.conf.set(threshold_conf, old_threshold)
+
     def test_with_atom_type(self):
         for data in [[(1), (2), (3)], [1, 2, 3]]:
             for schema in ["long", "int", "short"]:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
