itholic commented on code in PR #39695:
URL: https://github.com/apache/spark/pull/39695#discussion_r1083547431
##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
batches: List[pa.RecordBatch] = []
try:
- for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
- if b.client_id != self._session_id:
- raise SparkConnectException(
- "Received incorrect session identifier for request."
- )
- if b.metrics is not None:
- logger.debug("Received metric batch.")
- m = b.metrics
- if b.HasField("arrow_batch"):
- logger.debug(
- f"Received arrow batch rows={b.arrow_batch.row_count} "
- f"size={len(b.arrow_batch.data)}"
- )
-
- with pa.ipc.open_stream(b.arrow_batch.data) as reader:
- for batch in reader:
- assert isinstance(batch, pa.RecordBatch)
- batches.append(batch)
+ for attempt in Retrying(SparkConnectClient.retry_exception,
**self._retry_policy):
+ with attempt:
+ for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request."
Review Comment:
ditto ?
##########
python/pyspark/sql/connect/client.py:
##########
@@ -640,6 +669,136 @@ def _handle_error(self, rpc_error: grpc.RpcError) ->
NoReturn:
raise SparkConnectException(str(rpc_error)) from None
+class RetryState:
+ """
+ Simple state helper that captures the state between retries of the
exceptions. It
+ keeps track of the last exception thrown and how many in total. when the
task
+ finishes successfully done() returns True.
+ """
+
+ def __init__(self) -> None:
+ self._exception: Optional[BaseException] = None
+ self._done = False
+ self._count = 0
+
+ def set_exception(self, exc: Optional[BaseException]) -> None:
+ self._exception = exc
+ self._count += 1
+
+ def exception(self) -> Optional[BaseException]:
+ return self._exception
+
+ def set_done(self) -> None:
+ self._done = True
+
+ def count(self) -> int:
+ return self._count
+
+ def done(self) -> bool:
+ return self._done
+
+
+class AttemptManager:
+ """
+ Simple ContextManager that is used to capture the exception thrown inside
the context.
+ """
+
+ def __init__(self, check: Callable[..., bool], retry_state: RetryState) ->
None:
+ self._retry_state = retry_state
+ self._can_retry = check
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ if isinstance(exc_val, BaseException):
+ # Swallow the exception.
+ if self._can_retry(exc_val):
+ self._retry_state.set_exception(exc_val)
+ return True
+ # Bubble up the exception.
+ return False
+ else:
+ self._retry_state.set_done()
+ return None
+
+
+class Retrying:
+ """
+ This helper class is used as a generator together with a context manager to
+ allow retrying exceptions in particular code blocks. The Retrying can be
configured
+ with a lambda function that is can be filtered what kind of exceptions
should be
+ retried.
+
+ In addition, there are several parameters that are used to configure the
exponential
+ backoff behavior.
+
+ An example to use this class looks like this:
+
+ for attempt in Retrying(lambda x: isinstance(x, TransientError)):
+ with attempt:
+ # do the work.
+
Review Comment:
I think we can use `.. code-block:: python` here in the docstring for a better
example, as below:
```
An example to use this class looks like this:
.. code-block:: python
for attempt in Retrying(lambda x: isinstance(x, TransientError)):
with attempt:
# do the work.
```
##########
python/pyspark/sql/tests/connect/test_connect_basic.py:
##########
@@ -2591,6 +2591,73 @@ def test_unsupported_io_functions(self):
getattr(df.write, f)()
[email protected](not should_test_connect, connect_requirement_message)
+class ClientTests(unittest.TestCase):
+ def test_retry_error_handling(self):
+ # Helper class for wrapping the test.
+ class TestError(grpc.RpcError, Exception):
+ def __init__(self, code: grpc.StatusCode):
+ self._code = code
+
+ def code(self):
+ return self._code
+
+ def stub(retries, w, code):
+ w["counter"] += 1
+ if w["counter"] < retries:
+ raise TestError(code)
+
+ from pyspark.sql.connect.client import Retrying
Review Comment:
Can we move the imports to the top of the tests? Just per
https://peps.python.org/pep-0008/#imports.
##########
python/pyspark/sql/connect/client.py:
##########
@@ -365,6 +385,15 @@ def __init__(
# Parse the connection string.
self._builder = ChannelBuilder(connectionString, channelOptions)
self._user_id = None
+ self._retry_policy = {
+ "max_retries": 15,
+ "backoff_multiplier": 4,
+ "initial_backoff": 50,
+ "max_backoff": 60000,
+ }
Review Comment:
qq: just out of curiosity, where are these numbers from? Are they arbitrarily
set values?
##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
batches: List[pa.RecordBatch] = []
try:
- for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
- if b.client_id != self._session_id:
- raise SparkConnectException(
- "Received incorrect session identifier for request."
- )
- if b.metrics is not None:
- logger.debug("Received metric batch.")
- m = b.metrics
- if b.HasField("arrow_batch"):
- logger.debug(
- f"Received arrow batch rows={b.arrow_batch.row_count} "
- f"size={len(b.arrow_batch.data)}"
- )
-
- with pa.ipc.open_stream(b.arrow_batch.data) as reader:
- for batch in reader:
- assert isinstance(batch, pa.RecordBatch)
- batches.append(batch)
+ for attempt in Retrying(SparkConnectClient.retry_exception,
**self._retry_policy):
+ with attempt:
+ for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request."
+ )
+ if b.metrics is not None:
+ logger.debug("Received metric batch.")
+ m = b.metrics
+ if b.HasField("arrow_batch"):
+ logger.debug(
+ f"Received arrow batch
rows={b.arrow_batch.row_count} "
+ f"size={len(b.arrow_batch.data)}"
+ )
+
+ with pa.ipc.open_stream(b.arrow_batch.data) as
reader:
+ for batch in reader:
+ assert isinstance(batch, pa.RecordBatch)
+ batches.append(batch)
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
-
assert len(batches) > 0
-
table = pa.Table.from_batches(batches=batches)
-
metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None
else []
-
return table, metrics
def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server
side, certain
exceptions are enriched with additional RPC Status information. These
are
unpacked in this function and put into the exception.
-
To avoid overloading the user with GRPC errors, this message explicitly
swallows the error context from the call. This GRPC Error is logged
however,
and can be enabled.
-
Parameters
----------
rpc_error : grpc.RpcError
RPC Error containing the details of the exception.
-
Review Comment:
ditto
##########
python/pyspark/sql/connect/client.py:
##########
@@ -531,12 +560,16 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str =
"extended") -> AnalyzeRes
req.explain.explain_mode = pb2.Explain.ExplainMode.CODEGEN
else: # formatted
req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
-
try:
- resp = self._stub.AnalyzePlan(req,
metadata=self._builder.metadata())
- if resp.client_id != self._session_id:
- raise SparkConnectException("Received incorrect session
identifier for request.")
- return AnalyzeResult.fromProto(resp)
+ for attempt in Retrying(SparkConnectClient.retry_exception,
**self._retry_policy):
+ with attempt:
+ resp = self._stub.AnalyzePlan(req,
metadata=self._builder.metadata())
+ if resp.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request."
Review Comment:
Maybe providing `resp.client_id` and `self._session_id` in the error message
would be helpful? Or is it maybe just too verbose?
##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
batches: List[pa.RecordBatch] = []
try:
- for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
- if b.client_id != self._session_id:
- raise SparkConnectException(
- "Received incorrect session identifier for request."
- )
- if b.metrics is not None:
- logger.debug("Received metric batch.")
- m = b.metrics
- if b.HasField("arrow_batch"):
- logger.debug(
- f"Received arrow batch rows={b.arrow_batch.row_count} "
- f"size={len(b.arrow_batch.data)}"
- )
-
- with pa.ipc.open_stream(b.arrow_batch.data) as reader:
- for batch in reader:
- assert isinstance(batch, pa.RecordBatch)
- batches.append(batch)
+ for attempt in Retrying(SparkConnectClient.retry_exception,
**self._retry_policy):
+ with attempt:
+ for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request."
+ )
+ if b.metrics is not None:
+ logger.debug("Received metric batch.")
+ m = b.metrics
+ if b.HasField("arrow_batch"):
+ logger.debug(
+ f"Received arrow batch
rows={b.arrow_batch.row_count} "
+ f"size={len(b.arrow_batch.data)}"
+ )
+
+ with pa.ipc.open_stream(b.arrow_batch.data) as
reader:
+ for batch in reader:
+ assert isinstance(batch, pa.RecordBatch)
+ batches.append(batch)
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
-
assert len(batches) > 0
-
table = pa.Table.from_batches(batches=batches)
-
metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None
else []
-
return table, metrics
def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server
side, certain
exceptions are enriched with additional RPC Status information. These
are
unpacked in this function and put into the exception.
-
Review Comment:
Maybe a mistake? I think having a newline here would be better for readability.
##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
batches: List[pa.RecordBatch] = []
try:
- for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
- if b.client_id != self._session_id:
- raise SparkConnectException(
- "Received incorrect session identifier for request."
- )
- if b.metrics is not None:
- logger.debug("Received metric batch.")
- m = b.metrics
- if b.HasField("arrow_batch"):
- logger.debug(
- f"Received arrow batch rows={b.arrow_batch.row_count} "
- f"size={len(b.arrow_batch.data)}"
- )
-
- with pa.ipc.open_stream(b.arrow_batch.data) as reader:
- for batch in reader:
- assert isinstance(batch, pa.RecordBatch)
- batches.append(batch)
+ for attempt in Retrying(SparkConnectClient.retry_exception,
**self._retry_policy):
+ with attempt:
+ for b in self._stub.ExecutePlan(req,
metadata=self._builder.metadata()):
+ if b.client_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for
request."
+ )
+ if b.metrics is not None:
+ logger.debug("Received metric batch.")
+ m = b.metrics
+ if b.HasField("arrow_batch"):
+ logger.debug(
+ f"Received arrow batch
rows={b.arrow_batch.row_count} "
+ f"size={len(b.arrow_batch.data)}"
+ )
+
+ with pa.ipc.open_stream(b.arrow_batch.data) as
reader:
+ for batch in reader:
+ assert isinstance(batch, pa.RecordBatch)
+ batches.append(batch)
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
-
assert len(batches) > 0
-
table = pa.Table.from_batches(batches=batches)
-
metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None
else []
-
return table, metrics
def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server
side, certain
exceptions are enriched with additional RPC Status information. These
are
unpacked in this function and put into the exception.
-
To avoid overloading the user with GRPC errors, this message explicitly
swallows the error context from the call. This GRPC Error is logged
however,
and can be enabled.
-
Review Comment:
ditto
##########
python/pyspark/sql/connect/client.py:
##########
@@ -640,6 +669,136 @@ def _handle_error(self, rpc_error: grpc.RpcError) ->
NoReturn:
raise SparkConnectException(str(rpc_error)) from None
+class RetryState:
+ """
+ Simple state helper that captures the state between retries of the
exceptions. It
+ keeps track of the last exception thrown and how many in total. when the
task
+ finishes successfully done() returns True.
+ """
+
+ def __init__(self) -> None:
+ self._exception: Optional[BaseException] = None
+ self._done = False
+ self._count = 0
+
+ def set_exception(self, exc: Optional[BaseException]) -> None:
+ self._exception = exc
+ self._count += 1
+
+ def exception(self) -> Optional[BaseException]:
+ return self._exception
+
+ def set_done(self) -> None:
+ self._done = True
+
+ def count(self) -> int:
+ return self._count
+
+ def done(self) -> bool:
+ return self._done
+
+
+class AttemptManager:
+ """
+ Simple ContextManager that is used to capture the exception thrown inside
the context.
+ """
+
+ def __init__(self, check: Callable[..., bool], retry_state: RetryState) ->
None:
+ self._retry_state = retry_state
+ self._can_retry = check
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> Optional[bool]:
+ if isinstance(exc_val, BaseException):
+ # Swallow the exception.
+ if self._can_retry(exc_val):
+ self._retry_state.set_exception(exc_val)
+ return True
+ # Bubble up the exception.
+ return False
+ else:
+ self._retry_state.set_done()
+ return None
+
+
+class Retrying:
+ """
+ This helper class is used as a generator together with a context manager to
+ allow retrying exceptions in particular code blocks. The Retrying can be
configured
+ with a lambda function that is can be filtered what kind of exceptions
should be
+ retried.
+
+ In addition, there are several parameters that are used to configure the
exponential
+ backoff behavior.
+
+ An example to use this class looks like this:
+
+ for attempt in Retrying(lambda x: isinstance(x, TransientError)):
+ with attempt:
+ # do the work.
+
+ """
+
+ def __init__(
+ self,
+ can_retry: Callable[..., bool] = lambda x: True,
+ max_retries: int = 15,
+ initial_backoff: int = 50,
+ max_backoff: int = 60000,
+ backoff_multiplier: float = 4.0,
+ ) -> None:
+ self._can_retry = can_retry
+ self._max_retries = max_retries
+ self._initial_backoff = initial_backoff
+ self._max_backoff = max_backoff
+ self._backoff_multiplier = backoff_multiplier
+
+ def __iter__(self) -> Generator[AttemptManager, None, None]:
+ """
+ Generator function to wrap the exception producing code block.
+ Returns
+ -------
+
Review Comment:
Seems like something is missing here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]