This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new d646ca3de6f [SPARK-42156][CONNECT] SparkConnectClient supports 
RetryPolicies now
d646ca3de6f is described below

commit d646ca3de6f81bb32f4ec5d3f19176c4be445de6
Author: Martin Grund <[email protected]>
AuthorDate: Tue Jan 31 16:05:18 2023 +0900

    [SPARK-42156][CONNECT] SparkConnectClient supports RetryPolicies now
    
    ### What changes were proposed in this pull request?
    By default all exceptions thrown via the gRPC service are treated as fatal. 
However, this is not always the case: exceptions that return a gRPC 
status code of "UNAVAILABLE" should be retried.
    
    This patch adds support for retrying "UNAVAILABLE" errors and builds a 
default retry policy in the Python client that can be further modified and 
configured if necessary. The retry policy is a simple exponential backoff 
policy with jitter.
    
    ### Why are the changes needed?
    Stability
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests (added `ClientTests.test_retry_error_handling` in test_connect_basic.py).
    
    Closes #39695 from grundprinzip/retry_policies.
    
    Lead-authored-by: Martin Grund <[email protected]>
    Co-authored-by: Martin Grund <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 068111f61fae87afd1f81521a47897a28436be3f)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client.py               | 246 ++++++++++++++++++---
 .../sql/tests/connect/test_connect_basic.py        | 113 ++++++++++
 2 files changed, 323 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py 
b/python/pyspark/sql/connect/client.py
index efc970d6a4c..2da63a8add9 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -14,13 +14,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import logging
 import os
+import random
+import time
 import urllib.parse
 import uuid
 import json
-from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, 
NoReturn, cast
+from types import TracebackType
+from typing import (
+    Iterable,
+    Optional,
+    Any,
+    Union,
+    List,
+    Tuple,
+    Dict,
+    NoReturn,
+    cast,
+    Callable,
+    Generator,
+    Type,
+)
 
 import pandas as pd
 import pyarrow as pa
@@ -327,11 +342,16 @@ class AnalyzeResult:
 class SparkConnectClient(object):
     """Conceptually the remote spark session that communicates with the 
server"""
 
+    @classmethod
+    def retry_exception(cls, e: grpc.RpcError) -> bool:
+        return e.code() == grpc.StatusCode.UNAVAILABLE
+
     def __init__(
         self,
         connectionString: str,
         userId: Optional[str] = None,
         channelOptions: Optional[List[Tuple[str, Any]]] = None,
+        retryPolicy: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a new SparkSession for the Spark Connect interface.
@@ -350,6 +370,15 @@ class SparkConnectClient(object):
         # Parse the connection string.
         self._builder = ChannelBuilder(connectionString, channelOptions)
         self._user_id = None
+        self._retry_policy = {
+            "max_retries": 15,
+            "backoff_multiplier": 4,
+            "initial_backoff": 50,
+            "max_backoff": 60000,
+        }
+        if retryPolicy:
+            self._retry_policy.update(retryPolicy)
+
         # Generate a unique session ID for this client. This UUID must be 
unique to allow
         # concurrent Spark sessions of the same user. If the channel is 
closed, creating
         # a new client will create a new session ID.
@@ -516,12 +545,19 @@ class SparkConnectClient(object):
             req.explain.explain_mode = pb2.Explain.ExplainMode.CODEGEN
         else:  # formatted
             req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
-
         try:
-            resp = self._stub.AnalyzePlan(req, 
metadata=self._builder.metadata())
-            if resp.client_id != self._session_id:
-                raise SparkConnectException("Received incorrect session 
identifier for request.")
-            return AnalyzeResult.fromProto(resp)
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.AnalyzePlan(req, 
metadata=self._builder.metadata())
+                    if resp.client_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for 
request:"
+                            f"{resp.client_id} != {self._session_id}"
+                        )
+                    return AnalyzeResult.fromProto(resp)
+            raise SparkConnectException("Invalid state during retry exception 
handling.")
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
@@ -537,12 +573,16 @@ class SparkConnectClient(object):
         """
         logger.info("Execute")
         try:
-            for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                continue
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+            ):
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for 
request: "
+                                f"{b.client_id} != {self._session_id}"
+                            )
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
@@ -552,37 +592,38 @@ class SparkConnectClient(object):
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-
         batches: List[pa.RecordBatch] = []
 
         try:
-            for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-                if b.client_id != self._session_id:
-                    raise SparkConnectException(
-                        "Received incorrect session identifier for request."
-                    )
-                if b.metrics is not None:
-                    logger.debug("Received metric batch.")
-                    m = b.metrics
-                if b.HasField("arrow_batch"):
-                    logger.debug(
-                        f"Received arrow batch rows={b.arrow_batch.row_count} "
-                        f"size={len(b.arrow_batch.data)}"
-                    )
-
-                    with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                        for batch in reader:
-                            assert isinstance(batch, pa.RecordBatch)
-                            batches.append(batch)
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+            ):
+                with attempt:
+                    batches = []
+                    for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                        if b.client_id != self._session_id:
+                            raise SparkConnectException(
+                                "Received incorrect session identifier for 
request: "
+                                f"{b.client_id} != {self._session_id}"
+                            )
+                        if b.metrics is not None:
+                            logger.debug("Received metric batch.")
+                            m = b.metrics
+                        if b.HasField("arrow_batch"):
+                            logger.debug(
+                                f"Received arrow batch 
rows={b.arrow_batch.row_count} "
+                                f"size={len(b.arrow_batch.data)}"
+                            )
+
+                            with pa.ipc.open_stream(b.arrow_batch.data) as 
reader:
+                                for batch in reader:
+                                    assert isinstance(batch, pa.RecordBatch)
+                                    batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-
         assert len(batches) > 0
-
         table = pa.Table.from_batches(batches=batches)
-
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None 
else []
-
         return table, metrics
 
     def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
@@ -641,6 +682,139 @@ class SparkConnectClient(object):
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
 
+class RetryState:
+    """
+    Simple state helper that captures the state between retries of the 
exceptions. It
+    keeps track of the last exception thrown and how many in total. When the 
task
+    finishes successfully done() returns True.
+    """
+
+    def __init__(self) -> None:
+        self._exception: Optional[BaseException] = None
+        self._done = False
+        self._count = 0
+
+    def set_exception(self, exc: Optional[BaseException]) -> None:
+        self._exception = exc
+        self._count += 1
+
+    def exception(self) -> Optional[BaseException]:
+        return self._exception
+
+    def set_done(self) -> None:
+        self._done = True
+
+    def count(self) -> int:
+        return self._count
+
+    def done(self) -> bool:
+        return self._done
+
+
+class AttemptManager:
+    """
+    Simple ContextManager that is used to capture the exception thrown inside 
the context.
+    """
+
+    def __init__(self, check: Callable[..., bool], retry_state: RetryState) -> 
None:
+        self._retry_state = retry_state
+        self._can_retry = check
+
+    def __enter__(self) -> None:
+        pass
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        if isinstance(exc_val, BaseException):
+            # Swallow the exception.
+            if self._can_retry(exc_val):
+                self._retry_state.set_exception(exc_val)
+                return True
+            # Bubble up the exception.
+            return False
+        else:
+            self._retry_state.set_done()
+            return None
+
+
+class Retrying:
+    """
+    This helper class is used as a generator together with a context manager to
+    allow retrying exceptions in particular code blocks. The Retrying can be 
configured
+    with a lambda function that is can be filtered what kind of exceptions 
should be
+    retried.
+
+    In addition, there are several parameters that are used to configure the 
exponential
+    backoff behavior.
+
+    An example to use this class looks like this:
+
+    .. code-block:: python
+
+        for attempt in Retrying(can_retry=lambda x: isinstance(x, 
TransientError)):
+            with attempt:
+                # do the work.
+
+    """
+
+    def __init__(
+        self,
+        max_retries: int,
+        initial_backoff: int,
+        max_backoff: int,
+        backoff_multiplier: float,
+        can_retry: Callable[..., bool] = lambda x: True,
+    ) -> None:
+        self._can_retry = can_retry
+        self._max_retries = max_retries
+        self._initial_backoff = initial_backoff
+        self._max_backoff = max_backoff
+        self._backoff_multiplier = backoff_multiplier
+
+    def __iter__(self) -> Generator[AttemptManager, None, None]:
+        """
+        Generator function to wrap the exception producing code block.
+
+        Returns
+        -------
+        A generator that yields the current attempt.
+        """
+        retry_state = RetryState()
+        while True:
+            # Check if the operation was completed successfully.
+            if retry_state.done():
+                break
+
+            # If the number of retries have exceeded the maximum allowed 
retries.
+            if retry_state.count() > self._max_retries:
+                e = retry_state.exception()
+                if e is not None:
+                    raise e
+                else:
+                    raise ValueError("Retries exceeded but no exception 
caught.")
+
+            # Do backoff
+            if retry_state.count() > 0:
+                backoff = random.randrange(
+                    0,
+                    int(
+                        min(
+                            self._initial_backoff * self._backoff_multiplier 
** retry_state.count(),
+                            self._max_backoff,
+                        )
+                    ),
+                )
+                logger.debug(f"Retrying call after {backoff} ms sleep")
+                # Pythons sleep takes seconds as arguments.
+                time.sleep(backoff / 1000.0)
+
+            yield AttemptManager(self._can_retry, retry_state)
+
+
 __all__ = [
     "ChannelBuilder",
     "SparkConnectClient",
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d51de331f7a..044740a881e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -21,10 +21,12 @@ import os
 import unittest
 import shutil
 import tempfile
+from collections import defaultdict
 
 from pyspark.errors import PySparkTypeError
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.sql import SparkSession as PySparkSession, Row
+from pyspark.sql.connect.client import Retrying
 from pyspark.sql.types import (
     StructType,
     StructField,
@@ -2625,6 +2627,117 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
                 getattr(df.write, f)()
 
 
+class ClientTests(unittest.TestCase):
+    def test_retry_error_handling(self):
+        # Helper class for wrapping the test.
+        class TestError(grpc.RpcError, Exception):
+            def __init__(self, code: grpc.StatusCode):
+                self._code = code
+
+            def code(self):
+                return self._code
+
+        def stub(retries, w, code):
+            w["attempts"] += 1
+            if w["attempts"] < retries:
+                w["raised"] += 1
+                raise TestError(code)
+
+        # Check that max_retries 1 is only one retry so two attempts.
+        call_wrap = defaultdict(int)
+        for attempt in Retrying(
+            can_retry=lambda x: True,
+            max_retries=1,
+            backoff_multiplier=1,
+            initial_backoff=1,
+            max_backoff=10,
+        ):
+            with attempt:
+                stub(2, call_wrap, grpc.StatusCode.INTERNAL)
+
+        self.assertEqual(2, call_wrap["attempts"])
+        self.assertEqual(1, call_wrap["raised"])
+
+        # Check that if we have less than 4 retries all is ok.
+        call_wrap = defaultdict(int)
+        for attempt in Retrying(
+            can_retry=lambda x: True,
+            max_retries=4,
+            backoff_multiplier=1,
+            initial_backoff=1,
+            max_backoff=10,
+        ):
+            with attempt:
+                stub(2, call_wrap, grpc.StatusCode.INTERNAL)
+
+        self.assertTrue(call_wrap["attempts"] < 4)
+        self.assertEqual(call_wrap["raised"], 1)
+
+        # Exceed the retries.
+        call_wrap = defaultdict(int)
+        with self.assertRaises(TestError):
+            for attempt in Retrying(
+                can_retry=lambda x: True,
+                max_retries=2,
+                max_backoff=50,
+                backoff_multiplier=1,
+                initial_backoff=50,
+            ):
+                with attempt:
+                    stub(5, call_wrap, grpc.StatusCode.INTERNAL)
+
+        self.assertTrue(call_wrap["attempts"] < 5)
+        self.assertEqual(call_wrap["raised"], 3)
+
+        # Check that only specific exceptions are retried.
+        # Check that if we have less than 4 retries all is ok.
+        call_wrap = defaultdict(int)
+        for attempt in Retrying(
+            can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
+            max_retries=4,
+            backoff_multiplier=1,
+            initial_backoff=1,
+            max_backoff=10,
+        ):
+            with attempt:
+                stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
+
+        self.assertTrue(call_wrap["attempts"] < 4)
+        self.assertEqual(call_wrap["raised"], 1)
+
+        # Exceed the retries.
+        call_wrap = defaultdict(int)
+        with self.assertRaises(TestError):
+            for attempt in Retrying(
+                can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
+                max_retries=2,
+                max_backoff=50,
+                backoff_multiplier=1,
+                initial_backoff=50,
+            ):
+                with attempt:
+                    stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
+
+        self.assertTrue(call_wrap["attempts"] < 4)
+        self.assertEqual(call_wrap["raised"], 3)
+
+        # Test that another error is always thrown.
+        call_wrap = defaultdict(int)
+        with self.assertRaises(TestError):
+            for attempt in Retrying(
+                can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
+                max_retries=4,
+                backoff_multiplier=1,
+                initial_backoff=1,
+                max_backoff=10,
+            ):
+                with attempt:
+                    stub(5, call_wrap, grpc.StatusCode.INTERNAL)
+
+        self.assertEqual(call_wrap["attempts"], 1)
+        self.assertEqual(call_wrap["raised"], 1)
+
+
 class ChannelBuilderTests(unittest.TestCase):
     def test_invalid_connection_strings(self):
         invalid = [


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to