This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c291564c7d4 [SPARK-42437][CONNECT][PYTHON][FOLLOW-UP] Storage level proto converters
c291564c7d4 is described below

commit c291564c7d493a9da5e2315d6bab28796dfce7ce
Author: khalidmammadov <khalidmammad...@gmail.com>
AuthorDate: Wed Apr 19 13:30:53 2023 -0700

    [SPARK-42437][CONNECT][PYTHON][FOLLOW-UP] Storage level proto converters
    
    ### What changes were proposed in this pull request?
    Adds converters between the StorageLevel proto message and StorageLevel to avoid code duplication.
    This is a follow-up to https://github.com/apache/spark/pull/40015.
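
    For illustration, a minimal sketch of the round trip the two helpers are
    meant to provide (not part of this patch; it assumes a PySpark build where
    the Spark Connect protos are importable):

        from pyspark.storagelevel import StorageLevel
        from pyspark.sql.connect.conversion import (
            storage_level_to_proto,
            proto_to_storage_level,
        )

        # Convert a client-side StorageLevel to its proto message and back;
        # every field should survive the round trip.
        level = StorageLevel.MEMORY_AND_DISK
        restored = proto_to_storage_level(storage_level_to_proto(level))
        assert restored.useDisk == level.useDisk
        assert restored.replication == level.replication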
    
    ### Why are the changes needed?
    Code deduplication
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests
    
    Closes #40859 from khalidmammadov/storage_level_converter.
    
    Authored-by: khalidmammadov <khalidmammad...@gmail.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/connect/client.py     | 19 +++----------------
 python/pyspark/sql/connect/conversion.py | 24 ++++++++++++++++++++++++
 python/pyspark/sql/connect/plan.py       | 11 ++---------
 3 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 780c5702477..60f3f1ac2ba 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -60,6 +60,7 @@ from google.protobuf import text_format
 from google.rpc import error_details_pb2
 
 from pyspark.resource.information import ResourceInformation
+from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 import pyspark.sql.connect.types as types
@@ -469,13 +470,7 @@ class AnalyzeResult:
         elif pb.HasField("unpersist"):
             pass
         elif pb.HasField("get_storage_level"):
-            storage_level = StorageLevel(
-                useDisk=pb.get_storage_level.storage_level.use_disk,
-                useMemory=pb.get_storage_level.storage_level.use_memory,
-                useOffHeap=pb.get_storage_level.storage_level.use_off_heap,
-                deserialized=pb.get_storage_level.storage_level.deserialized,
-                replication=pb.get_storage_level.storage_level.replication,
-            )
+            storage_level = proto_to_storage_level(pb.get_storage_level.storage_level)
         else:
             raise SparkConnectException("No analyze result found!")
 
@@ -877,15 +872,7 @@ class SparkConnectClient(object):
             req.persist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
             if kwargs.get("storage_level", None) is not None:
                 storage_level = cast(StorageLevel, kwargs.get("storage_level"))
-                req.persist.storage_level.CopyFrom(
-                    pb2.StorageLevel(
-                        use_disk=storage_level.useDisk,
-                        use_memory=storage_level.useMemory,
-                        use_off_heap=storage_level.useOffHeap,
-                        deserialized=storage_level.deserialized,
-                        replication=storage_level.replication,
-                    )
-                )
+                req.persist.storage_level.CopyFrom(storage_level_to_proto(storage_level))
         elif method == "unpersist":
             req.unpersist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
             if kwargs.get("blocking", None) is not None:
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 5a31d1df67e..a6fe0c00e09 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -43,7 +43,9 @@ from pyspark.sql.types import (
     cast,
 )
 
+from pyspark.storagelevel import StorageLevel
 from pyspark.sql.connect.types import to_arrow_schema
+import pyspark.sql.connect.proto as pb2
 
 from typing import (
     Any,
@@ -486,3 +488,25 @@ class ArrowTableToRowsConversion:
             values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)]
             rows.append(_create_row(fields=schema.fieldNames(), values=values))
         return rows
+
+
+def storage_level_to_proto(storage_level: StorageLevel) -> pb2.StorageLevel:
+    assert storage_level is not None and isinstance(storage_level, StorageLevel)
+    return pb2.StorageLevel(
+        use_disk=storage_level.useDisk,
+        use_memory=storage_level.useMemory,
+        use_off_heap=storage_level.useOffHeap,
+        deserialized=storage_level.deserialized,
+        replication=storage_level.replication,
+    )
+
+
+def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
+    assert storage_level is not None and isinstance(storage_level, pb2.StorageLevel)
+    return StorageLevel(
+        useDisk=storage_level.use_disk,
+        useMemory=storage_level.use_memory,
+        useOffHeap=storage_level.use_off_heap,
+        deserialized=storage_level.deserialized,
+        replication=storage_level.replication,
+    )
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c3b81cf80f0..9e221814f12 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -30,6 +30,7 @@ from pyspark.storagelevel import StorageLevel
 from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
     SortOrder,
@@ -1896,15 +1897,7 @@ class CacheTable(LogicalPlan):
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         _cache_table = proto.CacheTable(table_name=self._table_name)
         if self._storage_level:
-            _cache_table.storage_level.CopyFrom(
-                proto.StorageLevel(
-                    use_disk=self._storage_level.useDisk,
-                    use_memory=self._storage_level.useMemory,
-                    use_off_heap=self._storage_level.useOffHeap,
-                    deserialized=self._storage_level.deserialized,
-                    replication=self._storage_level.replication,
-                )
-            )
+            _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
         plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table))
         return plan
 

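A quick way to exercise the new helpers outside of the client (illustrative
only; assumes a pyspark installation with the Spark Connect protos available):

    import pyspark.sql.connect.proto as pb2
    from pyspark.sql.connect.conversion import proto_to_storage_level

    # Build the proto message by hand and convert it to the client-side type,
    # mirroring what AnalyzeResult now does for get_storage_level responses.
    msg = pb2.StorageLevel(
        use_disk=True,
        use_memory=True,
        use_off_heap=False,
        deserialized=True,
        replication=2,
    )
    print(proto_to_storage_level(msg))  # e.g. StorageLevel(True, True, False, True, 2)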

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
