This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 58863dfa1b4b [SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])
58863dfa1b4b is described below
commit 58863dfa1b4b84dee5a0d6323265f6f3bb71a763
Author: Alice Sayutina <[email protected]>
AuthorDate: Fri Oct 6 11:21:17 2023 +0900
[SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])
### What changes were proposed in this pull request?
1. Add retries to the `add_artifact` API in the client.
2. Slightly change the control flow within `artifact.py` so that client-side errors (e.g. `FileNotFoundError`) are properly raised. (Previously we attempted to add logging in https://github.com/apache/spark/pull/42949, but that was an imperfect solution; this should be much better. See the sketch after this list.)
3. Take proper ownership of the files opened through `LocalData`, and close those descriptors.
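To illustrate the idea behind items 2 and 3, here is a minimal, self-contained sketch, not the actual `artifact.py` code (the function name and chunk size are made up): validation runs eagerly in the caller's thread, while the request stream itself stays lazy, so a bad path raises `FileNotFoundError` immediately instead of deep inside the gRPC iterator.

```python
# Sketch of the "validate eagerly, stream lazily" pattern; create_requests
# and the 8192-byte chunk size are illustrative, not Spark's API.
from typing import Iterator


def create_requests(path: str) -> Iterator[bytes]:
    # Eager check in the caller's thread: a missing file raises
    # FileNotFoundError here, not when the stream is later consumed.
    with open(path, "rb"):
        pass

    def generator() -> Iterator[bytes]:
        # Lazy part: chunks are produced only when the consumer
        # (e.g. a gRPC request iterator) pulls from the result.
        with open(path, "rb") as stream:
            yield from iter(lambda: stream.read(8192), b"")

    # Crucially, this function is not a generator itself; it only returns
    # one, so the validation above has already run by this point.
    return generator()


requests = create_requests("some_file.bin")  # raises here if the file is missing
```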
### Why are the changes needed?
Improves user experience
### Does this PR introduce _any_ user-facing change?
Improves error handling and adds retries.
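For context, the retries follow the attempt/context-manager pattern of the client's existing `_retrying()` helper (visible in the `core.py` diff below). The toy below only illustrates that pattern under those assumptions; none of the names in it are PySpark APIs.

```python
# Toy attempt/context-manager retry loop; `Attempt` and `retrying` are
# illustrative stand-ins, not PySpark's actual retry implementation.
from typing import Iterator


class Attempt:
    def __init__(self, state: dict) -> None:
        self._state = state

    def __enter__(self) -> "Attempt":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        if exc is None:
            self._state["done"] = True
            return False
        self._state["failures"] += 1
        # Swallow the error (i.e. retry) while attempts remain;
        # otherwise let it propagate to the caller.
        return self._state["failures"] < self._state["max_attempts"]


def retrying(max_attempts: int = 3) -> Iterator[Attempt]:
    state = {"done": False, "failures": 0, "max_attempts": max_attempts}
    while not state["done"] and state["failures"] < max_attempts:
        yield Attempt(state)


# Usage: the body is re-run on failure until it succeeds or the attempt
# budget is exhausted (then the last error propagates).
for attempt in retrying():
    with attempt:
        pass  # e.g. artifact_manager.add_artifacts(path, ...)
```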
### How was this patch tested?
Added test coverage for `add_artifact` when the referenced artifact file does not exist.
### Was this patch authored or co-authored using generative AI tooling?
NO
Closes #43216 from cdkrot/SPARK-45394.
Authored-by: Alice Sayutina <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/artifact.py | 94 ++++++++++++----------
python/pyspark/sql/connect/client/core.py | 18 ++++-
.../sql/tests/connect/client/test_artifact.py | 7 ++
3 files changed, 72 insertions(+), 47 deletions(-)
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index fb31a57e0f62..5829ec9a8d4d 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -52,7 +52,6 @@ class LocalData(metaclass=abc.ABCMeta):
     Payload stored on this machine.
     """
 
-    @cached_property
     @abc.abstractmethod
     def stream(self) -> BinaryIO:
         pass
@@ -70,14 +69,18 @@ class LocalFile(LocalData):
 
     def __init__(self, path: str):
         self.path = path
-        self._size: int
-        self._stream: int
+
+        # Check that the file can be read
+        # so that incorrect references can be discovered during Artifact creation,
+        # and not at the point of consumption.
+
+        with self.stream():
+            pass
 
     @cached_property
     def size(self) -> int:
         return os.path.getsize(self.path)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return open(self.path, "rb")
 
@@ -89,14 +92,11 @@ class InMemory(LocalData):
 
     def __init__(self, blob: bytes):
         self.blob = blob
-        self._size: int
-        self._stream: int
 
     @cached_property
     def size(self) -> int:
         return len(self.blob)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return io.BytesIO(self.blob)
 
@@ -244,18 +244,23 @@ class ArtifactManager:
         self, *path: str, pyfile: bool, archive: bool, file: bool
     ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
-        try:
-            yield from self._add_artifacts(
-                chain(
-                    *(
-                        self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file)
-                        for p in path
-                    )
-                )
-            )
-        except Exception as e:
-            logger.error(f"Failed to submit addArtifacts request: {e}")
-            raise
+
+        # It's crucial that this function is not a generator, but only returns a generator.
+        # This way the artifact parsing happens in the original caller thread,
+        # and not inside the gRPC-consuming iterator, allowing for much better error reporting.
+
+        artifacts: Iterator[Artifact] = chain(
+            *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+        )
+
+        def generator() -> Iterator[proto.AddArtifactsRequest]:
+            try:
+                yield from self._add_artifacts(artifacts)
+            except Exception as e:
+                logger.error(f"Failed to submit addArtifacts request: {e}")
+                raise
+
+        return generator()
 
     def _retrieve_responses(
         self, requests: Iterator[proto.AddArtifactsRequest]
@@ -279,6 +284,7 @@ class ArtifactManager:
         requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
             *path, pyfile=pyfile, archive=archive, file=file
         )
+
         self._request_add_artifacts(requests)
 
     def _add_forward_to_fs_artifacts(self, local_path: str, dest_path: str) -> None:
@@ -337,7 +343,8 @@ class ArtifactManager:
 
         artifact_chunks = []
         for artifact in artifacts:
-            binary = artifact.storage.stream.read()
+            with artifact.storage.stream() as stream:
+                binary = stream.read()
             crc32 = zlib.crc32(binary)
             data = proto.AddArtifactsRequest.ArtifactChunk(data=binary, crc=crc32)
             artifact_chunks.append(
@@ -363,31 +370,32 @@ class ArtifactManager:
         )
 
         # Consume stream in chunks until there is no data left to read.
-        for chunk in iter(lambda: artifact.storage.stream.read(ArtifactManager.CHUNK_SIZE), b""):
-            if initial_batch:
-                # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
-                        name=artifact.path,
-                        total_bytes=artifact.size,
-                        num_chunks=get_num_chunks,
-                        initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+        with artifact.storage.stream() as stream:
+            for chunk in iter(lambda: stream.read(ArtifactManager.CHUNK_SIZE), b""):
+                if initial_batch:
+                    # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
+                            name=artifact.path,
+                            total_bytes=artifact.size,
+                            num_chunks=get_num_chunks,
+                            initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+                                data=chunk, crc=zlib.crc32(chunk)
+                            ),
+                        ),
+                    )
+                    initial_batch = False
+                else:
+                    # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        chunk=proto.AddArtifactsRequest.ArtifactChunk(
                             data=chunk, crc=zlib.crc32(chunk)
                         ),
-                    ),
-                )
-                initial_batch = False
-            else:
-                # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    chunk=proto.AddArtifactsRequest.ArtifactChunk(
-                        data=chunk, crc=zlib.crc32(chunk)
-                    ),
-                )
+                    )
 
     def is_cached_artifact(self, hash: str) -> bool:
         """
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index db7f8e6dc75c..9e47379c85e7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1538,14 +1538,24 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
+    def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
+        for path in paths:
+            for attempt in self._retrying():
+                with attempt:
+                    self._artifact_manager.add_artifacts(
+                        path, pyfile=pyfile, archive=archive, file=file
+                    )
 
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
-        self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
+        for attempt in self._retrying():
+            with attempt:
+                self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
 
     def cache_artifact(self, blob: bytes) -> str:
-        return self._artifact_manager.cache_artifact(blob)
+        for attempt in self._retrying():
+            with attempt:
+                return self._artifact_manager.cache_artifact(blob)
+        raise SparkConnectException("Invalid state during retry exception handling.")
 
 
 class RetryState:
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index d45230e926b1..7e9f9dbbf569 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -388,6 +388,13 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
         self.assertEqual(actualHash, expected_hash)
         self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)
 
+    def test_add_not_existing_artifact(self):
+        with tempfile.TemporaryDirectory() as d:
+            with self.assertRaises(FileNotFoundError):
+                self.artifact_manager.add_artifacts(
+                    os.path.join(d, "not_existing"), file=True, pyfile=False, archive=False
+                )
+
 
 class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]