This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 59303f7e54eb [SPARK-52673][CONNECT][CLIENT] Add grpc RetryInfo 
handling to Spark Connect retry policies
59303f7e54eb is described below

commit 59303f7e54ebee342d36920e091300b4e05d3b28
Author: Alex Khakhlyuk <alex.khakhl...@gmail.com>
AuthorDate: Mon Jul 14 18:04:46 2025 +0900

    [SPARK-52673][CONNECT][CLIENT] Add grpc RetryInfo handling to Spark Connect retry policies
    
    ### What changes were proposed in this pull request?
    
    The Spark Connect client has a set of retry policies that specify which errors coming from the server can be retried.
    This change adds the capability for the Spark Connect client to use server-provided retry information according to the gRPC standard: https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
    The server can include a `RetryInfo` gRPC message containing a `retry_delay` field in its error response. The client will now use the `RetryInfo` message to classify the error as retryable and will use `retry_delay` to calculate how long to wait before the next attempt. This behavior is in line with the gRPC standard for client-server communication.
    The change is needed for two reasons:
    1) If the server is under heavy load or a task takes more time, it can tell the client to wait longer using the `retry_delay` field.
    2) If the server needs to introduce a new retryable error, it can simply include `RetryInfo` in the error response. The error will then be retried automatically by the client; no changes to the client-side retry policies are needed to retry the new error.
    
    #### Changes in detail
    
    - Add new `recognize_server_retry_delay` and `max_server_retry_delay` options to the `RetryPolicy` classes in the Python and Scala clients.
    - All policies with `recognize_server_retry_delay=True` will take 
`RetryInfo.retry_delay` into account when calculating the next backoff.
    - `retry_delay` can override the client's `max_backoff`.
    - `retry_delay` is capped at `max_server_retry_delay` (10 minutes by default).
    - When the server stops sending large `retry_delay` values, the client falls back to its own backoff policy, limited by `max_backoff`.
    - `DefaultPolicy` has `recognize_server_retry_delay=True` and will use `retry_delay` in the backoff calculation.
    - Additionally, `DefaultPolicy` classifies all errors that carry `RetryInfo` as retryable.
    - If an error can be retried by several policies, it is retried only by the first (highest-priority) matching policy, and the remaining policies are skipped. This change is needed because `DefaultPolicy` now retries all errors with `RetryInfo`. With the previous behaviour, an error that both carries `RetryInfo` and is matched by a different `CustomPolicy` would be retried by both the `DefaultPolicy` and the `CustomPolicy` (once one matching policy exhausted its retries, the next matching policy took over). This can lead to excessively long retry periods and complicates the planning of total [...]
    - Move retry-policy-related tests from `test_client.py` to a new `test_client_retries.py` file, and likewise for Scala (from `SparkConnectClientSuite.scala` to a new `SparkConnectClientRetriesSuite.scala`).
    - Extend docstrings.
    
    ### Why are the changes needed?
    
    See above
    
    ### Does this PR introduce _any_ user-facing change?
    
    1. The clients retry all errors that carry a `RetryInfo` gRPC message, using the `DefaultPolicy`.
    2. An error is retried only by the first policy that matches it.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51363 from khakhlyuk/retryinfo.
    
    Authored-by: Alex Khakhlyuk <alex.khakhl...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/retries.py       |  89 ++++++-
 .../sql/tests/connect/client/test_client.py        |  31 +--
 .../tests/connect/client/test_client_retries.py    | 225 ++++++++++++++++
 .../sql/tests/connect/test_connect_retry.py        |   4 +-
 .../client/SparkConnectClientRetriesSuite.scala    | 282 +++++++++++++++++++++
 .../connect/client/SparkConnectClientSuite.scala   | 124 ---------
 .../sql/connect/client/GrpcRetryHandler.scala      |  10 +-
 .../spark/sql/connect/client/RetryPolicy.scala     |  69 ++++-
 8 files changed, 662 insertions(+), 172 deletions(-)

diff --git a/python/pyspark/sql/connect/client/retries.py 
b/python/pyspark/sql/connect/client/retries.py
index e27100133b5a..436da250d791 100644
--- a/python/pyspark/sql/connect/client/retries.py
+++ b/python/pyspark/sql/connect/client/retries.py
@@ -19,7 +19,9 @@ import grpc
 import random
 import time
 import typing
-from typing import Optional, Callable, Generator, List, Type
+from google.rpc import error_details_pb2
+from grpc_status import rpc_status
+from typing import Optional, Callable, Generator, List, Type, cast
 from types import TracebackType
 from pyspark.sql.connect.logging import logger
 from pyspark.errors import PySparkRuntimeError, RetriesExceeded
@@ -45,6 +47,34 @@ class RetryPolicy:
     Describes key aspects of RetryPolicy.
 
     It's advised that different policies are implemented as different 
subclasses.
+
+    Parameters
+    ----------
+    max_retries: int, optional
+        Maximum number of retries.
+    initial_backoff: int
+        Start value of the exponential backoff.
+    max_backoff: int, optional
+        Maximal value of the exponential backoff.
+    backoff_multiplier: float
+        Multiplicative base of the exponential backoff.
+    jitter: int
+        Sample a random value uniformly from the range [0, jitter] and add it 
to the backoff.
+    min_jitter_threshold: int
+        Minimal value of the backoff to add random jitter.
+    recognize_server_retry_delay: bool
+        Per gRPC standard, the server can send error messages that contain 
`RetryInfo` message
+        with `retry_delay` field indicating that the client should wait for at 
least `retry_delay`
+        amount of time before retrying again, see:
+        
https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
+
+        If this flag is set to true, RetryPolicy will use 
`RetryInfo.retry_delay` field
+        in the backoff computation. Server's `retry_delay` can override 
client's `max_backoff`.
+
+        This flag does not change which errors are retried, only how the 
backoff is computed.
+        `DefaultPolicy` additionally has a rule for retrying any error that 
contains `RetryInfo`.
+    max_server_retry_delay: int, optional
+        Limit for the server-provided `retry_delay`.
     """
 
     def __init__(
@@ -55,6 +85,8 @@ class RetryPolicy:
         backoff_multiplier: float = 1.0,
         jitter: int = 0,
         min_jitter_threshold: int = 0,
+        recognize_server_retry_delay: bool = False,
+        max_server_retry_delay: Optional[int] = None,
     ):
         self.max_retries = max_retries
         self.initial_backoff = initial_backoff
@@ -62,6 +94,8 @@ class RetryPolicy:
         self.backoff_multiplier = backoff_multiplier
         self.jitter = jitter
         self.min_jitter_threshold = min_jitter_threshold
+        self.recognize_server_retry_delay = recognize_server_retry_delay
+        self.max_server_retry_delay = max_server_retry_delay
         self._name = self.__class__.__name__
 
     @property
@@ -98,7 +132,7 @@ class RetryPolicyState:
     def can_retry(self, exception: BaseException) -> bool:
         return self.policy.can_retry(exception)
 
-    def next_attempt(self) -> Optional[int]:
+    def next_attempt(self, exception: Optional[BaseException] = None) -> 
Optional[int]:
         """
         Returns
         -------
@@ -119,6 +153,14 @@ class RetryPolicyState:
                 float(self.policy.max_backoff), wait_time * 
self.policy.backoff_multiplier
             )
 
+        if exception is not None and self.policy.recognize_server_retry_delay:
+            retry_delay = extract_retry_delay(exception)
+            if retry_delay is not None:
+                logger.debug(f"The server has sent a retry delay of 
{retry_delay} ms.")
+                if self.policy.max_server_retry_delay is not None:
+                    retry_delay = min(retry_delay, 
self.policy.max_server_retry_delay)
+                wait_time = max(wait_time, retry_delay)
+
         # Jitter current backoff, after the future backoff was computed
         if wait_time >= self.policy.min_jitter_threshold:
             wait_time += random.uniform(0, self.policy.jitter)
@@ -160,6 +202,7 @@ class Retrying:
     This class is a point of entry into the retry logic.
     The class accepts a list of retry policies and applies them in given order.
     The first policy accepting an exception will be used.
+    If the error was matched by one policy, the other policies will be skipped.
 
     The usage of the class should be as follows:
     for attempt in Retrying(...):
@@ -217,17 +260,18 @@ class Retrying:
             return
 
         # Attempt to find a policy to wait with
+        matched_policy = None
         for policy in self._policies:
-            if not policy.can_retry(exception):
-                continue
-
-            wait_time = policy.next_attempt()
+            if policy.can_retry(exception):
+                matched_policy = policy
+                break
+        if matched_policy is not None:
+            wait_time = matched_policy.next_attempt(exception)
             if wait_time is not None:
                 logger.debug(
                     f"Got error: {repr(exception)}. "
-                    + f"Will retry after {wait_time} ms (policy: 
{policy.name})"
+                    + f"Will retry after {wait_time} ms (policy: 
{matched_policy.name})"
                 )
-
                 self._sleep(wait_time / 1000)
                 return
 
@@ -274,6 +318,8 @@ class DefaultPolicy(RetryPolicy):
         max_backoff: Optional[int] = 60000,
         jitter: int = 500,
         min_jitter_threshold: int = 2000,
+        recognize_server_retry_delay: bool = True,
+        max_server_retry_delay: Optional[int] = 10 * 60 * 1000,  # 10 minutes
     ):
         super().__init__(
             max_retries=max_retries,
@@ -282,6 +328,8 @@ class DefaultPolicy(RetryPolicy):
             max_backoff=max_backoff,
             jitter=jitter,
             min_jitter_threshold=min_jitter_threshold,
+            recognize_server_retry_delay=recognize_server_retry_delay,
+            max_server_retry_delay=max_server_retry_delay,
         )
 
     def can_retry(self, e: BaseException) -> bool:
@@ -314,4 +362,29 @@ class DefaultPolicy(RetryPolicy):
         if e.code() == grpc.StatusCode.UNAVAILABLE:
             return True
 
+        if extract_retry_info(e) is not None:
+            # All errors messages containing `RetryInfo` should be retried.
+            return True
+
         return False
+
+
+def extract_retry_info(exception: BaseException) -> 
Optional[error_details_pb2.RetryInfo]:
+    """Extract and return RetryInfo from the grpc.RpcError"""
+    if isinstance(exception, grpc.RpcError):
+        status = rpc_status.from_call(cast(grpc.Call, exception))
+        if status:
+            for d in status.details:
+                if d.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
+                    info = error_details_pb2.RetryInfo()
+                    d.Unpack(info)
+                    return info
+    return None
+
+
+def extract_retry_delay(exception: BaseException) -> Optional[int]:
+    """Extract and return RetryInfo.retry_delay in milliseconds from 
grpc.RpcError if present."""
+    retry_info = extract_retry_info(exception)
+    if retry_info is not None:
+        return retry_info.retry_delay.ToMilliseconds()
+    return None
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 37ed9207ed05..c3954827bae5 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -35,7 +35,7 @@ if should_test_connect:
     )
     from pyspark.sql.connect.client.reattach import 
ExecutePlanResponseReattachableIterator
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
-    from pyspark.errors import PySparkRuntimeError, RetriesExceeded
+    from pyspark.errors import PySparkRuntimeError
     import pyspark.sql.connect.proto as proto
 
     class TestPolicy(DefaultPolicy):
@@ -227,35 +227,6 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client.close()
         self.assertTrue(client.is_closed)
 
-    def test_retry(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
-
-        total_sleep = 0
-
-        def sleep(t):
-            nonlocal total_sleep
-            total_sleep += t
-
-        try:
-            for attempt in Retrying(client._retry_policies, sleep=sleep):
-                with attempt:
-                    raise TestException("Retryable error", 
grpc.StatusCode.UNAVAILABLE)
-        except RetriesExceeded:
-            pass
-
-        # tolerated at least 10 mins of fails
-        self.assertGreaterEqual(total_sleep, 600)
-
-    def test_retry_client_unit(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
-
-        policyA = TestPolicy()
-        policyB = DefaultPolicy()
-
-        client.set_retry_policies([policyA, policyB])
-
-        self.assertEqual(client.get_retry_policies(), [policyA, policyB])
-
     def test_channel_builder_with_session(self):
         dummy = str(uuid.uuid4())
         chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}")
diff --git a/python/pyspark/sql/tests/connect/client/test_client_retries.py 
b/python/pyspark/sql/tests/connect/client/test_client_retries.py
new file mode 100644
index 000000000000..400442363b47
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/client/test_client_retries.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
+
+if should_test_connect:
+    import grpc
+    import google.protobuf.any_pb2 as any_pb2
+    import google.protobuf.duration_pb2 as duration_pb2
+    from google.rpc import status_pb2
+    from google.rpc import error_details_pb2
+    from pyspark.sql.connect.client import SparkConnectClient
+    from pyspark.sql.connect.client.retries import (
+        Retrying,
+        DefaultPolicy,
+    )
+    from pyspark.errors import RetriesExceeded
+    from pyspark.sql.tests.connect.client.test_client import (
+        TestPolicy,
+        TestException,
+    )
+
+    class SleepTimeTracker:
+        """Tracks sleep times in ms for testing purposes."""
+
+        def __init__(self):
+            self._times = []
+
+        def sleep(self, t: float):
+            self._times.append(int(1000 * t))
+
+        @property
+        def times(self):
+            return list(self._times)
+
+    def create_test_exception_with_details(
+        msg: str,
+        code: grpc.StatusCode = grpc.StatusCode.INTERNAL,
+        retry_delay: int = 0,
+    ) -> TestException:
+        """Helper function for creating TestException with additional error 
details
+        like retry_delay.
+        """
+        retry_delay_msg = duration_pb2.Duration()
+        retry_delay_msg.FromMilliseconds(retry_delay)
+        retry_info = error_details_pb2.RetryInfo()
+        retry_info.retry_delay.CopyFrom(retry_delay_msg)
+
+        # Pack RetryInfo into an Any type
+        retry_info_any = any_pb2.Any()
+        retry_info_any.Pack(retry_info)
+        status = status_pb2.Status(
+            code=code.value[0],
+            message=msg,
+            details=[retry_info_any],
+        )
+        return TestException(msg=msg, code=code, trailing_status=status)
+
+    def get_client_policies_map(client: SparkConnectClient) -> dict:
+        return {type(policy): policy for policy in client.get_retry_policies()}
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class SparkConnectClientRetriesTestCase(unittest.TestCase):
+    def assertListsAlmostEqual(self, first, second, places=None, msg=None, 
delta=None):
+        self.assertEqual(len(first), len(second), msg)
+        for i in range(len(first)):
+            self.assertAlmostEqual(first[i], second[i], places, msg, delta)
+
+    def test_retry(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+
+        sleep_tracker = SleepTimeTracker()
+        try:
+            for attempt in Retrying(client._retry_policies, 
sleep=sleep_tracker.sleep):
+                with attempt:
+                    raise TestException("Retryable error", 
grpc.StatusCode.UNAVAILABLE)
+        except RetriesExceeded:
+            pass
+
+        # tolerated at least 10 mins of fails
+        self.assertGreaterEqual(sum(sleep_tracker.times), 600)
+
+    def test_retry_client_unit(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+
+        policyA = TestPolicy()
+        policyB = DefaultPolicy()
+
+        client.set_retry_policies([policyA, policyB])
+
+        self.assertEqual(client.get_retry_policies(), [policyA, policyB])
+
+    def test_default_policy_retries_retry_info(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+        policy = get_client_policies_map(client).get(DefaultPolicy)
+        self.assertIsNotNone(policy)
+
+        # retry delay = 0, error code not matched by any policy.
+        # Testing if errors with RetryInfo are being retried by the 
DefaultPolicy.
+        retry_delay = 0
+        sleep_tracker = SleepTimeTracker()
+        try:
+            for attempt in Retrying(client._retry_policies, 
sleep=sleep_tracker.sleep):
+                with attempt:
+                    raise create_test_exception_with_details(
+                        msg="Some error message",
+                        code=grpc.StatusCode.UNIMPLEMENTED,
+                        retry_delay=retry_delay,
+                    )
+        except RetriesExceeded:
+            pass
+        expected_times = [
+            min(policy.max_backoff, policy.initial_backoff * 
policy.backoff_multiplier**i)
+            for i in range(policy.max_retries)
+        ]
+        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, 
delta=policy.jitter)
+
+    def test_retry_delay_overrides_max_backoff(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+        policy = get_client_policies_map(client).get(DefaultPolicy)
+        self.assertIsNotNone(policy)
+
+        # retry delay = 5 mins.
+        # Testing if retry_delay overrides max_backoff.
+        retry_delay = 5 * 60 * 1000
+        sleep_tracker = SleepTimeTracker()
+        # assert that retry_delay is greater than max_backoff to make sure the 
test is valid
+        self.assertGreaterEqual(retry_delay, policy.max_backoff)
+        try:
+            for attempt in Retrying(client._retry_policies, 
sleep=sleep_tracker.sleep):
+                with attempt:
+                    raise create_test_exception_with_details(
+                        "Some error message",
+                        grpc.StatusCode.UNAVAILABLE,
+                        retry_delay,
+                    )
+        except RetriesExceeded:
+            pass
+        expected_times = [retry_delay] * policy.max_retries
+        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, 
delta=policy.jitter)
+
+    def test_max_server_retry_delay(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+        policy = get_client_policies_map(client).get(DefaultPolicy)
+        self.assertIsNotNone(policy)
+
+        # retry delay = 10 hours
+        # Testing if max_server_retry_delay limit works.
+        retry_delay = 10 * 60 * 60 * 1000
+        sleep_tracker = SleepTimeTracker()
+        try:
+            for attempt in Retrying(client._retry_policies, 
sleep=sleep_tracker.sleep):
+                with attempt:
+                    raise create_test_exception_with_details(
+                        "Some error message",
+                        grpc.StatusCode.UNAVAILABLE,
+                        retry_delay,
+                    )
+        except RetriesExceeded:
+            pass
+
+        expected_times = [policy.max_server_retry_delay] * policy.max_retries
+        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, 
delta=policy.jitter)
+
+    def test_return_to_exponential_backoff(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+        policy = get_client_policies_map(client).get(DefaultPolicy)
+        self.assertIsNotNone(policy)
+
+        # Start with retry_delay = 5 mins, then set it to zero.
+        # Test if backoff goes back to client's exponential strategy.
+        initial_retry_delay = 5 * 60 * 1000
+        sleep_tracker = SleepTimeTracker()
+        try:
+            for i, attempt in enumerate(
+                Retrying(client._retry_policies, sleep=sleep_tracker.sleep)
+            ):
+                if i < 2:
+                    retry_delay = initial_retry_delay
+                elif i < 5:
+                    retry_delay = 0
+                else:
+                    break
+                with attempt:
+                    raise create_test_exception_with_details(
+                        "Some error message",
+                        grpc.StatusCode.UNAVAILABLE,
+                        retry_delay,
+                    )
+        except RetriesExceeded:
+            pass
+
+        expected_times = [initial_retry_delay] * 2 + [
+            policy.initial_backoff * policy.backoff_multiplier**i for i in 
range(2, 5)
+        ]
+        self.assertListsAlmostEqual(sleep_tracker.times, expected_times, 
delta=policy.jitter)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.client.test_client_retries import *  # 
noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_retry.py 
b/python/pyspark/sql/tests/connect/test_connect_retry.py
index f51e06247928..61ab0dcea862 100644
--- a/python/pyspark/sql/tests/connect/test_connect_retry.py
+++ b/python/pyspark/sql/tests/connect/test_connect_retry.py
@@ -162,8 +162,8 @@ class RetryTests(unittest.TestCase):
                 with attempt:
                     self.stub(10, grpc.StatusCode.INTERNAL)
 
-        self.assertEqual(self.call_wrap["attempts"], 7)
-        self.assertEqual(self.call_wrap["raised"], 7)
+        self.assertEqual(self.call_wrap["attempts"], 3)
+        self.assertEqual(self.call_wrap["raised"], 3)
 
 
 if __name__ == "__main__":
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala
new file mode 100644
index 000000000000..3408c15b73f0
--- /dev/null
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import scala.concurrent.duration.FiniteDuration
+
+import com.google.protobuf.{Any, Duration}
+import com.google.rpc
+import io.grpc.{Status, StatusRuntimeException}
+import io.grpc.protobuf.StatusProto
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.sql.connect.test.ConnectFunSuite
+
+class SparkConnectClientRetriesSuite
+    extends ConnectFunSuite
+    with BeforeAndAfterEach
+    with Eventually {
+
+  private class DummyFn(e: => Throwable, numFails: Int = 3) {
+    var counter = 0
+    def fn(): Int = {
+      if (counter < numFails) {
+        counter += 1
+        throw e
+      } else {
+        42
+      }
+    }
+  }
+
+  /** Tracks sleep times in milliseconds for testing purposes. */
+  private class SleepTimeTracker {
+    private val data = scala.collection.mutable.ListBuffer[Long]()
+    def sleep(t: Long): Unit = data.append(t)
+    def times: List[Long] = data.toList
+    def totalSleep: Long = data.sum
+  }
+
+  /** Helper function for creating a test exception with retry_delay */
+  private def createTestExceptionWithDetails(
+      msg: String,
+      code: Status.Code = Status.Code.INTERNAL,
+      retryDelay: FiniteDuration = FiniteDuration(0, "s")): 
StatusRuntimeException = {
+    // In grpc-java, RetryDelay should be specified as seconds: Long + nanos: 
Int
+    val seconds = retryDelay.toSeconds
+    val nanos = (retryDelay - FiniteDuration(seconds, "s")).toNanos.toInt
+    val retryDelayMsg = Duration
+      .newBuilder()
+      .setSeconds(seconds)
+      .setNanos(nanos)
+      .build()
+    val retryInfo = rpc.RetryInfo
+      .newBuilder()
+      .setRetryDelay(retryDelayMsg)
+      .build()
+    val status = rpc.Status
+      .newBuilder()
+      .setMessage(msg)
+      .setCode(code.value())
+      .addDetails(Any.pack(retryInfo))
+      .build()
+    StatusProto.toStatusRuntimeException(status)
+  }
+
+  /** helper function for comparing two sequences of sleep times */
+  private def assertLongSequencesAlmostEqual(
+      first: Seq[Long],
+      second: Seq[Long],
+      delta: Long): Unit = {
+    assert(first.length == second.length, "Lists have different lengths.")
+    for ((a, b) <- first.zip(second)) {
+      assert(math.abs(a - b) <= delta, s"Elements $a and $b differ by more 
than $delta.")
+    }
+  }
+
+  test("SPARK-44721: Retries run for a minimum period") {
+    // repeat test few times to avoid random flakes
+    for (_ <- 1 to 10) {
+      val st = new SleepTimeTracker()
+      val dummyFn = new DummyFn(new 
StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
+      val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), 
st.sleep)
+
+      assertThrows[RetriesExceeded] {
+        retryHandler.retry {
+          dummyFn.fn()
+        }
+      }
+
+      assert(st.totalSleep >= 10 * 60 * 1000) // waited at least 10 minutes
+    }
+  }
+
+  test("SPARK-44275: retry actually retries") {
+    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
+    val result = retryHandler.retry { dummyFn.fn() }
+
+    assert(result == 42)
+    assert(dummyFn.counter == 3)
+  }
+
+  test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
+    val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
+
+    assertThrows[StatusRuntimeException] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+    assert(dummyFn.counter == 1)
+  }
+
+  test("SPARK-44275: retry uses canRetry to filter exceptions") {
+    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
+    val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
+    val retryHandler = new GrpcRetryHandler(retryPolicy)
+
+    assertThrows[StatusRuntimeException] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+    assert(dummyFn.counter == 1)
+  }
+
+  test("SPARK-44275: retry does not exceed maxRetries") {
+    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
+    val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), 
name = "TestPolicy")
+    val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})
+
+    assertThrows[RetriesExceeded] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+    assert(dummyFn.counter == 2)
+  }
+
+  def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
+    RetryPolicy(
+      maxRetries = Some(maxRetries),
+      name = s"Policy for ${status.getCode}",
+      canRetry = {
+        case e: StatusRuntimeException => e.getStatus.getCode == status.getCode
+        case _ => false
+      })
+  }
+
+  test("Test multiple policies") {
+    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.UNAVAILABLE)
+    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
+
+    // Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
+
+    val errors = (List.fill(2)(Status.UNAVAILABLE) ++ 
List.fill(4)(Status.INTERNAL)).iterator
+
+    new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
+      val e = errors.nextOption()
+      if (e.isDefined) {
+        throw e.get.asRuntimeException()
+      }
+    })
+
+    assert(!errors.hasNext)
+  }
+
+  test("Test multiple policies exceed") {
+    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.INTERNAL)
+    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
+
+    val errors = List.fill(10)(Status.INTERNAL).iterator
+    var countAttempted = 0
+
+    assertThrows[RetriesExceeded](
+      new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
+        countAttempted += 1
+        val e = errors.nextOption()
+        if (e.isDefined) {
+          throw e.get.asRuntimeException()
+        }
+      }))
+
+    assert(countAttempted == 3)
+  }
+
+  test("DefaultPolicy retries exceptions with RetryInfo") {
+    // Error contains RetryInfo with retry_delay set to 0
+    val dummyFn =
+      new DummyFn(createTestExceptionWithDetails(msg = "Some error message"), 
numFails = 100)
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
+    assertThrows[RetriesExceeded] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+
+    // Should be retried by DefaultPolicy
+    val policy = retryPolicies.find(_.name == "DefaultPolicy").get
+    assert(dummyFn.counter == policy.maxRetries.get + 1)
+  }
+
+  test("retry_delay overrides maxBackoff") {
+    val st = new SleepTimeTracker()
+    val retryDelay = FiniteDuration(5, "min")
+    val dummyFn = new DummyFn(
+      createTestExceptionWithDetails(msg = "Some error message", retryDelay = 
retryDelay),
+      numFails = 100)
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
+
+    assertThrows[RetriesExceeded] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+
+    // Should be retried by DefaultPolicy
+    val policy = retryPolicies.find(_.name == "DefaultPolicy").get
+    // sleep times are higher than maxBackoff and are equal to retryDelay + 
jitter
+    st.times.foreach(t => assert(t > policy.maxBackoff.get.toMillis + 
policy.jitter.toMillis))
+    val expectedSleeps = List.fill(policy.maxRetries.get)(retryDelay.toMillis)
+    assertLongSequencesAlmostEqual(st.times, expectedSleeps, 
policy.jitter.toMillis)
+  }
+
+  test("maxServerRetryDelay limits retry_delay") {
+    val st = new SleepTimeTracker()
+    val retryDelay = FiniteDuration(5, "d")
+    val dummyFn = new DummyFn(
+      createTestExceptionWithDetails(msg = "Some error message", retryDelay = 
retryDelay),
+      numFails = 100)
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
+
+    assertThrows[RetriesExceeded] {
+      retryHandler.retry { dummyFn.fn() }
+    }
+
+    // Should be retried by DefaultPolicy
+    val policy = retryPolicies.find(_.name == "DefaultPolicy").get
+    val expectedSleeps = 
List.fill(policy.maxRetries.get)(policy.maxServerRetryDelay.get.toMillis)
+    assertLongSequencesAlmostEqual(st.times, expectedSleeps, 
policy.jitter.toMillis)
+  }
+
+  test("Policy uses to exponential backoff after retry_delay is unset") {
+    val st = new SleepTimeTracker()
+    val retryDelay = FiniteDuration(5, "min")
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
+    val errors = (
+      List.fill(2)(
+        createTestExceptionWithDetails(
+          msg = "Some error message",
+          retryDelay = retryDelay)) ++ List.fill(3)(
+        createTestExceptionWithDetails(
+          msg = "Some error message",
+          code = Status.Code.UNAVAILABLE))
+    ).iterator
+
+    retryHandler.retry({
+      if (errors.hasNext) {
+        throw errors.next()
+      }
+    })
+    assert(!errors.hasNext)
+
+    // Should be retried by DefaultPolicy
+    val policy = retryPolicies.find(_.name == "DefaultPolicy").get
+    val expectedSleeps = List.fill(2)(retryDelay.toMillis) ++ 
List.tabulate(3)(i =>
+      policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 
2).toLong)
+    assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = 
policy.jitter.toMillis)
+  }
+}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 9bb8f5889d33..a41ea344cbd4 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -339,130 +339,6 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
   }
 
-  private class DummyFn(e: => Throwable, numFails: Int = 3) {
-    var counter = 0
-    def fn(): Int = {
-      if (counter < numFails) {
-        counter += 1
-        throw e
-      } else {
-        42
-      }
-    }
-  }
-
-  test("SPARK-44721: Retries run for a minimum period") {
-    // repeat test few times to avoid random flakes
-    for (_ <- 1 to 10) {
-      var totalSleepMs: Long = 0
-
-      def sleep(t: Long): Unit = {
-        totalSleepMs += t
-      }
-
-      val dummyFn = new DummyFn(new 
StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
-      val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), 
sleep)
-
-      assertThrows[RetriesExceeded] {
-        retryHandler.retry {
-          dummyFn.fn()
-        }
-      }
-
-      assert(totalSleepMs >= 10 * 60 * 1000) // waited at least 10 minutes
-    }
-  }
-
-  test("SPARK-44275: retry actually retries") {
-    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicies = RetryPolicy.defaultPolicies()
-    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
-    val result = retryHandler.retry { dummyFn.fn() }
-
-    assert(result == 42)
-    assert(dummyFn.counter == 3)
-  }
-
-  test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
-    val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
-    val retryPolicies = RetryPolicy.defaultPolicies()
-    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
-
-    assertThrows[StatusRuntimeException] {
-      retryHandler.retry { dummyFn.fn() }
-    }
-    assert(dummyFn.counter == 1)
-  }
-
-  test("SPARK-44275: retry uses canRetry to filter exceptions") {
-    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
-    val retryHandler = new GrpcRetryHandler(retryPolicy)
-
-    assertThrows[StatusRuntimeException] {
-      retryHandler.retry { dummyFn.fn() }
-    }
-    assert(dummyFn.counter == 1)
-  }
-
-  test("SPARK-44275: retry does not exceed maxRetries") {
-    val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), 
name = "TestPolicy")
-    val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})
-
-    assertThrows[RetriesExceeded] {
-      retryHandler.retry { dummyFn.fn() }
-    }
-    assert(dummyFn.counter == 2)
-  }
-
-  def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
-    RetryPolicy(
-      maxRetries = Some(maxRetries),
-      name = s"Policy for ${status.getCode}",
-      canRetry = {
-        case e: StatusRuntimeException => e.getStatus.getCode == status.getCode
-        case _ => false
-      })
-  }
-
-  test("Test multiple policies") {
-    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.UNAVAILABLE)
-    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
-
-    // Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
-
-    val errors = (List.fill(2)(Status.UNAVAILABLE) ++ 
List.fill(4)(Status.INTERNAL)).iterator
-
-    new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
-      val e = errors.nextOption()
-      if (e.isDefined) {
-        throw e.get.asRuntimeException()
-      }
-    })
-
-    assert(!errors.hasNext)
-  }
-
-  test("Test multiple policies exceed") {
-    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.INTERNAL)
-    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
-
-    val errors = List.fill(10)(Status.INTERNAL).iterator
-    var countAttempted = 0
-
-    assertThrows[RetriesExceeded](
-      new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
-        countAttempted += 1
-        val e = errors.nextOption()
-        if (e.isDefined) {
-          throw e.get.asRuntimeException()
-        }
-      }))
-
-    assert(countAttempted == 7)
-  }
-
   test("ArtifactManager retries errors") {
     var attempt = 0
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 7e0a356b9e49..0a38d18773de 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -194,15 +194,17 @@ private[sql] object GrpcRetryHandler extends Logging {
         return
       }
 
-      for (policy <- policies if policy.canRetry(lastException)) {
-        val time = policy.nextAttempt()
-
+      // find a policy to wait with
+      val matchedPolicyOpt = policies.find(_.canRetry(lastException))
+      if (matchedPolicyOpt.isDefined) {
+        val matchedPolicy = matchedPolicyOpt.get
+        val time = matchedPolicy.nextAttempt(lastException)
         if (time.isDefined) {
           logWarning(
             log"Non-Fatal error during RPC execution: ${MDC(ERROR, 
lastException)}, " +
               log"retrying (wait=${MDC(RETRY_WAIT_TIME, time.get.toMillis)} 
ms, " +
               log"currentRetryNum=${MDC(NUM_RETRY, currentRetryNum)}, " +
-              log"policy=${MDC(POLICY, policy.getName)}).")
+              log"policy=${MDC(POLICY, matchedPolicy.getName)}).")
           sleep(time.get.toMillis)
           return
         }
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
index 8c8472d780db..5b5c4b517923 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
@@ -18,9 +18,14 @@
 package org.apache.spark.sql.connect.client
 
 import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.jdk.CollectionConverters._
 import scala.util.Random
 
+import com.google.rpc.RetryInfo
 import io.grpc.{Status, StatusRuntimeException}
+import io.grpc.protobuf.StatusProto
+
+import org.apache.spark.internal.Logging
 
 /**
  * [[RetryPolicy]] configure the retry mechanism in [[GrpcRetryHandler]]
@@ -33,8 +38,27 @@ import io.grpc.{Status, StatusRuntimeException}
  *   Maximal value of the exponential backoff (ms).
  * @param backoffMultiplier
  *   Multiplicative base of the exponential backoff.
+ * @param jitter
+ *   Sample a random value uniformly from the range [0, jitter] and add it to the backoff.
+ * @param minJitterThreshold
+ *   Minimal value of the backoff to add random jitter.
  * @param canRetry
  *   Function that determines whether a retry is to be performed in the event of an error.
+ * @param name
+ *   Name of the policy.
+ * @param recognizeServerRetryDelay
+ *   Per gRPC standard, the server can send error messages that contain `RetryInfo` message with
+ *   `retry_delay` field indicating that the client should wait for at least `retry_delay` amount
+ *   of time before retrying again, see:
+ *   https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
+ *
+ *   If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field in the backoff
+ *   computation. Server's `retry_delay` can override client's `maxBackoff`.
+ *
+ *   This flag does not change which errors are retried, only how the backoff is computed.
+ *   `DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`.
+ * @param maxServerRetryDelay
+ *   Limit for the server-provided `retry_delay`.
  */
 case class RetryPolicy(
     maxRetries: Option[Int] = None,
@@ -44,14 +68,16 @@ case class RetryPolicy(
     jitter: FiniteDuration = FiniteDuration(0, "s"),
     minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"),
     canRetry: Throwable => Boolean,
-    name: String) {
+    name: String,
+    recognizeServerRetryDelay: Boolean = false,
+    maxServerRetryDelay: Option[FiniteDuration] = None) {
 
   def getName: String = name
 
   def toState: RetryPolicy.RetryPolicyState = new 
RetryPolicy.RetryPolicyState(this)
 }
 
-object RetryPolicy {
+object RetryPolicy extends Logging {
   def defaultPolicy(): RetryPolicy = RetryPolicy(
     name = "DefaultPolicy",
     // Please synchronize changes here with Python side:
@@ -65,7 +91,9 @@ object RetryPolicy {
     backoffMultiplier = 4.0,
     jitter = FiniteDuration(500, "ms"),
     minJitterThreshold = FiniteDuration(2, "s"),
-    canRetry = defaultPolicyRetryException)
+    canRetry = defaultPolicyRetryException,
+    recognizeServerRetryDelay = true,
+    maxServerRetryDelay = Some(FiniteDuration(10, "min")))
 
   // list of policies to be used by this client
   def defaultPolicies(): Seq[RetryPolicy] = List(defaultPolicy())
@@ -77,7 +105,7 @@ object RetryPolicy {
     private var nextWait: Duration = policy.initialBackoff
 
     // return waiting time until next attempt, or None if has exceeded max 
retries
-    def nextAttempt(): Option[Duration] = {
+    def nextAttempt(e: Throwable): Option[Duration] = {
       if (policy.maxRetries.isDefined && numberAttempts >= 
policy.maxRetries.get) {
         return None
       }
@@ -90,6 +118,14 @@ object RetryPolicy {
         nextWait = nextWait min policy.maxBackoff.get
       }
 
+      if (policy.recognizeServerRetryDelay) {
+        extractRetryDelay(e).foreach { retryDelay =>
+          logDebug(s"The server has sent a retry delay of $retryDelay ms.")
+          val retryDelayLimited = retryDelay min 
policy.maxServerRetryDelay.getOrElse(retryDelay)
+          currentWait = currentWait max retryDelayLimited
+        }
+      }
+
       if (currentWait >= policy.minJitterThreshold) {
         currentWait += Random.nextDouble() * policy.jitter
       }
@@ -127,8 +163,33 @@ object RetryPolicy {
         if (statusCode == Status.Code.UNAVAILABLE) {
           return true
         }
+
+        // All error messages containing `RetryInfo` should be retried.
+        if (extractRetryInfo(e).isDefined) {
+          return true
+        }
+
         false
       case _ => false
     }
   }
+
+  /**
+   * Returns the first `RetryInfo` detail attached to a gRPC status error, if any.
+   *
+   * Only [[StatusRuntimeException]]s are inspected; any other throwable yields `None`. A detail
+   * that claims to be `RetryInfo` but cannot be unpacked is ignored instead of rethrown, so a
+   * malformed server payload cannot make `canRetry` or the backoff computation throw.
+   */
+  private def extractRetryInfo(e: Throwable): Option[RetryInfo] = {
+    e match {
+      case sre: StatusRuntimeException =>
+        // `fromThrowable` may return null when no status is attached, hence the Option wrap.
+        Option(StatusProto.fromThrowable(sre))
+          .flatMap(_.getDetailsList.asScala.find(_.is(classOf[RetryInfo])))
+          .flatMap(detail => scala.util.Try(detail.unpack(classOf[RetryInfo])).toOption)
+      case _ => None
+    }
+  }
+
+  /**
+   * Returns the server-provided `RetryInfo.retry_delay` of `e`, or `None` when `e` carries no
+   * `RetryInfo` or the `RetryInfo` has no `retry_delay` set.
+   */
+  private def extractRetryDelay(e: Throwable): Option[FiniteDuration] = {
+    extractRetryInfo(e)
+      .filter(_.hasRetryDelay)
+      .map { retryInfo =>
+        val delay = retryInfo.getRetryDelay
+        // A protobuf Duration may span ~10,000 years, but FiniteDuration is limited to
+        // +-(2^63 - 1) ns (~292 years) and throws beyond that. Saturate the seconds so an
+        // oversized server value cannot make the backoff computation throw; the caller also
+        // clamps the result to `maxServerRetryDelay` when one is configured.
+        val maxSeconds = Long.MaxValue / 1000000000L - 1
+        val seconds = math.max(-maxSeconds, math.min(delay.getSeconds, maxSeconds))
+        FiniteDuration(seconds, "s") + FiniteDuration(delay.getNanos, "ns")
+      }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to