This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f6a38cfcb3 [SPARK-45922][CONNECT][CLIENT] Minor retries refactoring (follow-up to multiple policies)
2f6a38cfcb3 is described below

commit 2f6a38cfcb384b4f504e1c08264887ae90d441bc
Author: Alice Sayutina <[email protected]>
AuthorDate: Sat Nov 25 09:52:27 2023 +0900

    [SPARK-45922][CONNECT][CLIENT] Minor retries refactoring (follow-up to multiple policies)
    
    ### What changes were proposed in this pull request?
    
    Follow up to https://github.com/apache/spark/pull/43591.
    
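    Move the default retry policy arguments out of core.py and make them the default values of
    the `DefaultPolicy` constructor in retries.py. Also move the handling of `RetryException`
    from `DefaultPolicy.can_retry` into `Retrying`, so that a `RetryException` is always treated
    as retryable (without waiting), even when none of the configured policies would retry it.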
    
    ### Why are the changes needed?
    General refactoring; it also makes it easier to derive other policies from `DefaultPolicy`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
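    Covered by existing tests (the retry test in test_client.py now raises a retryable gRPC error instead of `RetryException`).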
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43800 from cdkrot/SPARK-45922.
    
    Authored-by: Alice Sayutina <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/connect/client/RetryPolicy.scala     |  2 +-
 python/pyspark/sql/connect/client/core.py          | 19 ++---------
 python/pyspark/sql/connect/client/retries.py       | 37 +++++++++++++++++++---
 .../sql/tests/connect/client/test_client.py        |  3 +-
 4 files changed, 36 insertions(+), 25 deletions(-)

diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
index cb5b97f2e4a..8c8472d780d 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
@@ -55,7 +55,7 @@ object RetryPolicy {
   def defaultPolicy(): RetryPolicy = RetryPolicy(
     name = "DefaultPolicy",
     // Please synchronize changes here with Python side:
-    // pyspark/sql/connect/client/core.py
+    // pyspark/sql/connect/client/retries.py
     //
     // Note: these constants are selected so that the maximum tolerated wait 
is guaranteed
     // to be at least 10 minutes
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 58b48bd69ba..5d8db69c641 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -595,23 +595,8 @@ class SparkConnectClient(object):
         self._user_id = None
         self._retry_policies: List[RetryPolicy] = []
 
-        default_policy_args = {
-            # Please synchronize changes here with Scala side
-            # GrpcRetryHandler.scala
-            #
-            # Note: the number of retries is selected so that the maximum 
tolerated wait
-            # is guaranteed to be at least 10 minutes
-            "max_retries": 15,
-            "backoff_multiplier": 4.0,
-            "initial_backoff": 50,
-            "max_backoff": 60000,
-            "jitter": 500,
-            "min_jitter_threshold": 2000,
-        }
-        if retry_policy:
-            default_policy_args.update(retry_policy)
-
-        default_policy = DefaultPolicy(**default_policy_args)
+        retry_policy_args = retry_policy or dict()
+        default_policy = DefaultPolicy(**retry_policy_args)
         self.set_retry_policies([default_policy])
 
         if self._builder.session_id is None:
diff --git a/python/pyspark/sql/connect/client/retries.py 
b/python/pyspark/sql/connect/client/retries.py
index 6aa959e09b5..26aa6893dfa 100644
--- a/python/pyspark/sql/connect/client/retries.py
+++ b/python/pyspark/sql/connect/client/retries.py
@@ -185,6 +185,9 @@ class Retrying:
         self._done = False
 
     def can_retry(self, exception: BaseException) -> bool:
+        if isinstance(exception, RetryException):
+            return True
+
         return any(policy.can_retry(exception) for policy in self._policies)
 
     def accept_exception(self, exception: BaseException) -> bool:
@@ -204,8 +207,12 @@ class Retrying:
     def _wait(self) -> None:
         exception = self._last_exception()
 
-        # Attempt to find a policy to wait with
+        if isinstance(exception, RetryException):
+            # Considered immediately retriable
+            logger.debug(f"Got error: {repr(exception)}. Retrying.")
+            return
 
+        # Attempt to find a policy to wait with
         for policy in self._policies:
             if not policy.can_retry(exception):
                 continue
@@ -244,12 +251,34 @@ class Retrying:
 class RetryException(Exception):
     """
     An exception that can be thrown upstream when inside retry and which is 
always retryable
+    even without policies
     """
 
 
 class DefaultPolicy(RetryPolicy):
-    def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
-        super().__init__(**kwargs)
+    # Please synchronize changes here with Scala side in
+    # org.apache.spark.sql.connect.client.RetryPolicy
+    #
+    # Note: the number of retries is selected so that the maximum tolerated 
wait
+    # is guaranteed to be at least 10 minutes
+
+    def __init__(
+        self,
+        max_retries: Optional[int] = 15,
+        backoff_multiplier: float = 4.0,
+        initial_backoff: int = 50,
+        max_backoff: Optional[int] = 60000,
+        jitter: int = 500,
+        min_jitter_threshold: int = 2000,
+    ):
+        super().__init__(
+            max_retries=max_retries,
+            backoff_multiplier=backoff_multiplier,
+            initial_backoff=initial_backoff,
+            max_backoff=max_backoff,
+            jitter=jitter,
+            min_jitter_threshold=min_jitter_threshold,
+        )
 
     def can_retry(self, e: BaseException) -> bool:
         """
@@ -267,8 +296,6 @@ class DefaultPolicy(RetryPolicy):
         True if the exception can be retried, False otherwise.
 
         """
-        if isinstance(e, RetryException):
-            return True
 
         if not isinstance(e, grpc.RpcError):
             return False
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 580ebc3965b..12e690c3a30 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -31,7 +31,6 @@ if should_test_connect:
     from pyspark.sql.connect.client.retries import (
         Retrying,
         DefaultPolicy,
-        RetryException,
         RetriesExceeded,
     )
     from pyspark.sql.connect.client.reattach import 
ExecutePlanResponseReattachableIterator
@@ -111,7 +110,7 @@ class SparkConnectClientTestCase(unittest.TestCase):
         try:
             for attempt in Retrying(client._retry_policies, sleep=sleep):
                 with attempt:
-                    raise RetryException()
+                    raise TestException("Retryable error", 
grpc.StatusCode.UNAVAILABLE)
         except RetriesExceeded:
             pass
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to