itholic commented on code in PR #39695:
URL: https://github.com/apache/spark/pull/39695#discussion_r1083547431


##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
         batches: List[pa.RecordBatch] = []
 
         try:
-            for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                if b.metrics is not None:
-                    logger.debug("Received metric batch.")
-                    m = b.metrics
-                if b.HasField("arrow_batch"):
-                    logger.debug(
-                        f"Received arrow batch rows={b.arrow_batch.row_count} "
-                        f"size={len(b.arrow_batch.data)}"
-                    )
-
-                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                        for batch in reader:
-                            assert isinstance(batch, pa.RecordBatch)
-                            batches.append(batch)
+            for attempt in Retrying(SparkConnectClient.retry_exception, **self._retry_policy):
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for 
request."

Review Comment:
   ditto?



##########
python/pyspark/sql/connect/client.py:
##########
@@ -640,6 +669,136 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
             raise SparkConnectException(str(rpc_error)) from None
 
 
+class RetryState:
+    """
+    Simple state helper that captures the state between retries of exceptions. It
+    keeps track of the last exception thrown and how many retries have happened in
+    total. When the task finishes successfully, done() returns True.
+    """
+
+    def __init__(self) -> None:
+        self._exception: Optional[BaseException] = None
+        self._done = False
+        self._count = 0
+
+    def set_exception(self, exc: Optional[BaseException]) -> None:
+        self._exception = exc
+        self._count += 1
+
+    def exception(self) -> Optional[BaseException]:
+        return self._exception
+
+    def set_done(self) -> None:
+        self._done = True
+
+    def count(self) -> int:
+        return self._count
+
+    def done(self) -> bool:
+        return self._done
+
+
+class AttemptManager:
+    """
+    Simple ContextManager that is used to capture the exception thrown inside the context.
+    """
+
+    def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
+        self._retry_state = retry_state
+        self._can_retry = check
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        if isinstance(exc_val, BaseException):
+            # Swallow the exception.
+            if self._can_retry(exc_val):
+                self._retry_state.set_exception(exc_val)
+                return True
+            # Bubble up the exception.
+            return False
+        else:
+            self._retry_state.set_done()
+            return None
+
+
+class Retrying:
+    """
+    This helper class is used as a generator together with a context manager to
+    allow retrying exceptions in particular code blocks. The Retrying can be
+    configured with a lambda function that filters which kinds of exceptions
+    should be retried.
+
+    In addition, there are several parameters that are used to configure the
+    exponential backoff behavior.
+
+    An example to use this class looks like this:
+
+        for attempt in Retrying(lambda x: isinstance(x, TransientError)):
+            with attempt:
+                # do the work.
+

Review Comment:
   I think we can use `.. code-block:: python` here in the docstring so the example renders properly, as below:
   ```
   An example to use this class looks like this:
   
   .. code-block:: python
   
           for attempt in Retrying(lambda x: isinstance(x, TransientError)):
               with attempt:
                   # do the work.
   ```



##########
python/pyspark/sql/tests/connect/test_connect_basic.py:
##########
@@ -2591,6 +2591,73 @@ def test_unsupported_io_functions(self):
                 getattr(df.write, f)()
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class ClientTests(unittest.TestCase):
+    def test_retry_error_handling(self):
+        # Helper class for wrapping the test.
+        class TestError(grpc.RpcError, Exception):
+            def __init__(self, code: grpc.StatusCode):
+                self._code = code
+
+            def code(self):
+                return self._code
+
+        def stub(retries, w, code):
+            w["counter"] += 1
+            if w["counter"] < retries:
+                raise TestError(code)
+
+        from pyspark.sql.connect.client import Retrying

Review Comment:
   Can we move the imports to the top of the test file? Just per https://peps.python.org/pep-0008/#imports.



##########
python/pyspark/sql/connect/client.py:
##########
@@ -365,6 +385,15 @@ def __init__(
         # Parse the connection string.
         self._builder = ChannelBuilder(connectionString, channelOptions)
         self._user_id = None
+        self._retry_policy = {
+            "max_retries": 15,
+            "backoff_multiplier": 4,
+            "initial_backoff": 50,
+            "max_backoff": 60000,
+        }

Review Comment:
   qq: just out of curiosity, where do these numbers come from? Are they arbitrarily chosen values?



##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
         batches: List[pa.RecordBatch] = []
 
         try:
-            for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                if b.metrics is not None:
-                    logger.debug("Received metric batch.")
-                    m = b.metrics
-                if b.HasField("arrow_batch"):
-                    logger.debug(
-                        f"Received arrow batch rows={b.arrow_batch.row_count} "
-                        f"size={len(b.arrow_batch.data)}"
-                    )
-
-                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                        for batch in reader:
-                            assert isinstance(batch, pa.RecordBatch)
-                            batches.append(batch)
+            for attempt in Retrying(SparkConnectClient.retry_exception, **self._retry_policy):
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for request."
+                            )
+                        if b.metrics is not None:
+                            logger.debug("Received metric batch.")
+                            m = b.metrics
+                        if b.HasField("arrow_batch"):
+                            logger.debug(
+                                f"Received arrow batch 
rows={b.arrow_batch.row_count} "
+                                f"size={len(b.arrow_batch.data)}"
+                            )
+
+                            with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                                for batch in reader:
+                                    assert isinstance(batch, pa.RecordBatch)
+                                    batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-
         assert len(batches) > 0
-
         table = pa.Table.from_batches(batches=batches)
-
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
-
         return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server side, certain
         exceptions are enriched with additional RPC Status information. These are
         unpacked in this function and put into the exception.
-
         To avoid overloading the user with GRPC errors, this message explicitly
         swallows the error context from the call. This GRPC Error is logged however,
         and can be enabled.
-
         Parameters
         ----------
         rpc_error : grpc.RpcError
            RPC Error containing the details of the exception.
-

Review Comment:
   ditto



##########
python/pyspark/sql/connect/client.py:
##########
@@ -531,12 +560,16 @@ def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeRes
             req.explain.explain_mode = pb2.Explain.ExplainMode.CODEGEN
         else:  # formatted
             req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
-
         try:
-            resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
-            if resp.client_id != self._session_id:
-                raise SparkConnectException("Received incorrect session identifier for request.")
-            return AnalyzeResult.fromProto(resp)
+            for attempt in Retrying(SparkConnectClient.retry_exception, **self._retry_policy):
+                with attempt:
+                    resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
+                    if resp.client_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for request."

Review Comment:
   Maybe providing `resp.client_id` and `self._session_id` in the error message would be helpful? Or maybe it's just verbose?



##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
         batches: List[pa.RecordBatch] = []
 
         try:
-            for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                if b.metrics is not None:
-                    logger.debug("Received metric batch.")
-                    m = b.metrics
-                if b.HasField("arrow_batch"):
-                    logger.debug(
-                        f"Received arrow batch rows={b.arrow_batch.row_count} "
-                        f"size={len(b.arrow_batch.data)}"
-                    )
-
-                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                        for batch in reader:
-                            assert isinstance(batch, pa.RecordBatch)
-                            batches.append(batch)
+            for attempt in Retrying(SparkConnectClient.retry_exception, **self._retry_policy):
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for request."
+                            )
+                        if b.metrics is not None:
+                            logger.debug("Received metric batch.")
+                            m = b.metrics
+                        if b.HasField("arrow_batch"):
+                            logger.debug(
+                                f"Received arrow batch 
rows={b.arrow_batch.row_count} "
+                                f"size={len(b.arrow_batch.data)}"
+                            )
+
+                            with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                                for batch in reader:
+                                    assert isinstance(batch, pa.RecordBatch)
+                                    batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-
         assert len(batches) > 0
-
         table = pa.Table.from_batches(batches=batches)
-
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
-
         return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server side, certain
         exceptions are enriched with additional RPC Status information. These are
         unpacked in this function and put into the exception.
-

Review Comment:
   Maybe a mistake? I think keeping the blank line here would be better for readability.



##########
python/pyspark/sql/connect/client.py:
##########
@@ -567,54 +602,48 @@ def _execute_and_fetch(
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
         batches: List[pa.RecordBatch] = []
 
         try:
-            for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                if b.metrics is not None:
-                    logger.debug("Received metric batch.")
-                    m = b.metrics
-                if b.HasField("arrow_batch"):
-                    logger.debug(
-                        f"Received arrow batch rows={b.arrow_batch.row_count} "
-                        f"size={len(b.arrow_batch.data)}"
-                    )
-
-                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                        for batch in reader:
-                            assert isinstance(batch, pa.RecordBatch)
-                            batches.append(batch)
+            for attempt in Retrying(SparkConnectClient.retry_exception, **self._retry_policy):
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for request."
+                            )
+                        if b.metrics is not None:
+                            logger.debug("Received metric batch.")
+                            m = b.metrics
+                        if b.HasField("arrow_batch"):
+                            logger.debug(
+                                f"Received arrow batch 
rows={b.arrow_batch.row_count} "
+                                f"size={len(b.arrow_batch.data)}"
+                            )
+
+                            with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+                                for batch in reader:
+                                    assert isinstance(batch, pa.RecordBatch)
+                                    batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-
         assert len(batches) > 0
-
         table = pa.Table.from_batches(batches=batches)
-
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
-
         return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server side, certain
         exceptions are enriched with additional RPC Status information. These are
         unpacked in this function and put into the exception.
-
         To avoid overloading the user with GRPC errors, this message explicitly
         swallows the error context from the call. This GRPC Error is logged however,
         and can be enabled.
-

Review Comment:
   ditto



##########
python/pyspark/sql/connect/client.py:
##########
@@ -640,6 +669,136 @@ def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
             raise SparkConnectException(str(rpc_error)) from None
 
 
+class RetryState:
+    """
+    Simple state helper that captures the state between retries of exceptions. It
+    keeps track of the last exception thrown and how many retries have happened in
+    total. When the task finishes successfully, done() returns True.
+    """
+
+    def __init__(self) -> None:
+        self._exception: Optional[BaseException] = None
+        self._done = False
+        self._count = 0
+
+    def set_exception(self, exc: Optional[BaseException]) -> None:
+        self._exception = exc
+        self._count += 1
+
+    def exception(self) -> Optional[BaseException]:
+        return self._exception
+
+    def set_done(self) -> None:
+        self._done = True
+
+    def count(self) -> int:
+        return self._count
+
+    def done(self) -> bool:
+        return self._done
+
+
+class AttemptManager:
+    """
+    Simple ContextManager that is used to capture the exception thrown inside the context.
+    """
+
+    def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> None:
+        self._retry_state = retry_state
+        self._can_retry = check
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        if isinstance(exc_val, BaseException):
+            # Swallow the exception.
+            if self._can_retry(exc_val):
+                self._retry_state.set_exception(exc_val)
+                return True
+            # Bubble up the exception.
+            return False
+        else:
+            self._retry_state.set_done()
+            return None
+
+
+class Retrying:
+    """
+    This helper class is used as a generator together with a context manager to
+    allow retrying exceptions in particular code blocks. The Retrying can be
+    configured with a lambda function that filters which kinds of exceptions
+    should be retried.
+
+    In addition, there are several parameters that are used to configure the
+    exponential backoff behavior.
+
+    An example to use this class looks like this:
+
+        for attempt in Retrying(lambda x: isinstance(x, TransientError)):
+            with attempt:
+                # do the work.
+
+    """
+
+    def __init__(
+        self,
+        can_retry: Callable[..., bool] = lambda x: True,
+        max_retries: int = 15,
+        initial_backoff: int = 50,
+        max_backoff: int = 60000,
+        backoff_multiplier: float = 4.0,
+    ) -> None:
+        self._can_retry = can_retry
+        self._max_retries = max_retries
+        self._initial_backoff = initial_backoff
+        self._max_backoff = max_backoff
+        self._backoff_multiplier = backoff_multiplier
+
+    def __iter__(self) -> Generator[AttemptManager, None, None]:
+        """
+        Generator function to wrap the exception producing code block.
+        Returns
+        -------
+

Review Comment:
   Seems like something is missing here?
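   For illustration, one way the section could be completed (the description is a guess based on the `AttemptManager` semantics above, not taken from this PR):
   ```python
   def __iter__(self) -> Generator[AttemptManager, None, None]:
       """
       Generator function to wrap the exception-producing code block.

       Returns
       -------
       Iterator of :class:`AttemptManager`, one per attempt. Each attempt
       swallows retryable exceptions and backs off exponentially until the
       block succeeds or the retry budget is exhausted.
       """
   ```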


