This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c291564c7d4 [SPARK-42437][CONNECT][PYTHON][FOLLOW-UP] Storage level proto converters c291564c7d4 is described below commit c291564c7d493a9da5e2315d6bab28796dfce7ce Author: khalidmammadov <khalidmammad...@gmail.com> AuthorDate: Wed Apr 19 13:30:53 2023 -0700 [SPARK-42437][CONNECT][PYTHON][FOLLOW-UP] Storage level proto converters ### What changes were proposed in this pull request? Converters between Proto and StorageLevel to avoid code duplication. It's a follow-up to https://github.com/apache/spark/pull/40015 ### Why are the changes needed? Code deduplication ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #40859 from khalidmammadov/storage_level_converter. Authored-by: khalidmammadov <khalidmammad...@gmail.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/connect/client.py | 19 +++---------------- python/pyspark/sql/connect/conversion.py | 24 ++++++++++++++++++++++++ python/pyspark/sql/connect/plan.py | 11 ++--------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 780c5702477..60f3f1ac2ba 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -60,6 +60,7 @@ from google.protobuf import text_format from google.rpc import error_details_pb2 from pyspark.resource.information import ResourceInformation +from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib import pyspark.sql.connect.types as types @@ -469,13 +470,7 @@ class AnalyzeResult: elif pb.HasField("unpersist"): pass elif pb.HasField("get_storage_level"): - storage_level = StorageLevel( - useDisk=pb.get_storage_level.storage_level.use_disk, - 
useMemory=pb.get_storage_level.storage_level.use_memory, - useOffHeap=pb.get_storage_level.storage_level.use_off_heap, - deserialized=pb.get_storage_level.storage_level.deserialized, - replication=pb.get_storage_level.storage_level.replication, - ) + storage_level = proto_to_storage_level(pb.get_storage_level.storage_level) else: raise SparkConnectException("No analyze result found!") @@ -877,15 +872,7 @@ class SparkConnectClient(object): req.persist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation"))) if kwargs.get("storage_level", None) is not None: storage_level = cast(StorageLevel, kwargs.get("storage_level")) - req.persist.storage_level.CopyFrom( - pb2.StorageLevel( - use_disk=storage_level.useDisk, - use_memory=storage_level.useMemory, - use_off_heap=storage_level.useOffHeap, - deserialized=storage_level.deserialized, - replication=storage_level.replication, - ) - ) + req.persist.storage_level.CopyFrom(storage_level_to_proto(storage_level)) elif method == "unpersist": req.unpersist.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation"))) if kwargs.get("blocking", None) is not None: diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index 5a31d1df67e..a6fe0c00e09 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -43,7 +43,9 @@ from pyspark.sql.types import ( cast, ) +from pyspark.storagelevel import StorageLevel from pyspark.sql.connect.types import to_arrow_schema +import pyspark.sql.connect.proto as pb2 from typing import ( Any, @@ -486,3 +488,25 @@ class ArrowTableToRowsConversion: values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)] rows.append(_create_row(fields=schema.fieldNames(), values=values)) return rows + + +def storage_level_to_proto(storage_level: StorageLevel) -> pb2.StorageLevel: + assert storage_level is not None and isinstance(storage_level, StorageLevel) + return pb2.StorageLevel( + 
use_disk=storage_level.useDisk, + use_memory=storage_level.useMemory, + use_off_heap=storage_level.useOffHeap, + deserialized=storage_level.deserialized, + replication=storage_level.replication, + ) + + +def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel: + assert storage_level is not None and isinstance(storage_level, pb2.StorageLevel) + return StorageLevel( + useDisk=storage_level.use_disk, + useMemory=storage_level.use_memory, + useOffHeap=storage_level.use_off_heap, + deserialized=storage_level.deserialized, + replication=storage_level.replication, + ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index c3b81cf80f0..9e221814f12 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -30,6 +30,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.sql.types import DataType import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( SortOrder, @@ -1896,15 +1897,7 @@ class CacheTable(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: _cache_table = proto.CacheTable(table_name=self._table_name) if self._storage_level: - _cache_table.storage_level.CopyFrom( - proto.StorageLevel( - use_disk=self._storage_level.useDisk, - use_memory=self._storage_level.useMemory, - use_off_heap=self._storage_level.useOffHeap, - deserialized=self._storage_level.deserialized, - replication=self._storage_level.replication, - ) - ) + _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level)) plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table)) return plan --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org