This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ef240cdf6eaa [SPARK-45733][CONNECT][PYTHON] Support multiple retry
policies
ef240cdf6eaa is described below
commit ef240cdf6eaaa95f85aadc0f1272e991cc50bd35
Author: Alice Sayutina <[email protected]>
AuthorDate: Mon Nov 13 09:37:48 2023 +0900
[SPARK-45733][CONNECT][PYTHON] Support multiple retry policies
### What changes were proposed in this pull request?
Support multiple retry policies defined at the same time. Each policy
determines which error types it can retry and how the retries should be
spread out over time.
### Why are the changes needed?
Different error types should be treated differently. For instance,
networking connectivity errors and remote resources that are still being
initialized should be treated separately.
### Does this PR introduce _any_ user-facing change?
No (as long as the user doesn't rely on client internals).
### How was this patch tested?
Unit tests and some manual testing.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43591 from cdkrot/SPARK-45733.
Authored-by: Alice Sayutina <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 218 ++-------------
python/pyspark/sql/connect/client/reattach.py | 35 +--
python/pyspark/sql/connect/client/retries.py | 293 +++++++++++++++++++++
.../sql/tests/connect/client/test_client.py | 53 ++--
.../sql/tests/connect/test_connect_basic.py | 203 +++++++-------
5 files changed, 468 insertions(+), 334 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 7eafcc501f5f..b98de0f9ceea 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -19,7 +19,6 @@ __all__ = [
"SparkConnectClient",
]
-
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -27,12 +26,9 @@ check_dependencies(__name__)
import threading
import os
import platform
-import random
-import time
import urllib.parse
import uuid
import sys
-from types import TracebackType
from typing import (
Iterable,
Iterator,
@@ -45,9 +41,6 @@ from typing import (
Set,
NoReturn,
cast,
- Callable,
- Generator,
- Type,
TYPE_CHECKING,
Sequence,
)
@@ -66,10 +59,8 @@ from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.client.logging import logger
-from pyspark.sql.connect.client.reattach import (
- ExecutePlanResponseReattachableIterator,
- RetryException,
-)
+from pyspark.sql.connect.client.reattach import
ExecutePlanResponseReattachableIterator
+from pyspark.sql.connect.client.retries import RetryPolicy, Retrying,
DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto,
proto_to_storage_level
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
@@ -555,38 +546,6 @@ class SparkConnectClient(object):
Conceptually the remote spark session that communicates with the server
"""
- @classmethod
- def retry_exception(cls, e: Exception) -> bool:
- """
- Helper function that is used to identify if an exception thrown by the
server
- can be retried or not.
-
- Parameters
- ----------
- e : Exception
- The GRPC error as received from the server. Typed as Exception,
because other exception
- thrown during client processing can be passed here as well.
-
- Returns
- -------
- True if the exception can be retried, False otherwise.
-
- """
- if not isinstance(e, grpc.RpcError):
- return False
-
- if e.code() in [grpc.StatusCode.INTERNAL]:
- msg = str(e)
-
- # This error happens if another RPC preempts this RPC.
- if "INVALID_CURSOR.DISCONNECTED" in msg:
- return True
-
- if e.code() == grpc.StatusCode.UNAVAILABLE:
- return True
-
- return False
-
def __init__(
self,
connection: Union[str, ChannelBuilder],
@@ -634,7 +593,9 @@ class SparkConnectClient(object):
else ChannelBuilder(connection, channel_options)
)
self._user_id = None
- self._retry_policy = {
+ self._retry_policies: List[RetryPolicy] = []
+
+ default_policy_args = {
# Please synchronize changes here with Scala side
# GrpcRetryHandler.scala
#
@@ -648,7 +609,10 @@ class SparkConnectClient(object):
"min_jitter_threshold": 2000,
}
if retry_policy:
- self._retry_policy.update(retry_policy)
+ default_policy_args.update(retry_policy)
+
+ default_policy = DefaultPolicy(**default_policy_args)
+ self.set_retry_policies([default_policy])
if self._builder.session_id is None:
# Generate a unique session ID for this client. This UUID must be
unique to allow
@@ -680,9 +644,7 @@ class SparkConnectClient(object):
self._server_session_id: Optional[str] = None
def _retrying(self) -> "Retrying":
- return Retrying(
- can_retry=SparkConnectClient.retry_exception, **self._retry_policy
# type: ignore
- )
+ return Retrying(self._retry_policies)
def disable_reattachable_execute(self) -> "SparkConnectClient":
self._use_reattachable_execute = False
@@ -692,6 +654,20 @@ class SparkConnectClient(object):
self._use_reattachable_execute = True
return self
+    def set_retry_policies(self, policies: Iterable[RetryPolicy]) -> None:
+        """
+        Replace the retry policies used by this client.
+
+        The given policies are copied into a new list and are consulted in the
+        given order, e.g. ``set_retry_policies([DefaultPolicy(), CustomPolicy()])``.
+        """
+        self._retry_policies = [*policies]
+
+    def get_retry_policies(self) -> List[RetryPolicy]:
+        """
+        Return a copy of the retry policies currently in use.
+
+        Mutating the returned list does not change the client's policies.
+        """
+        return [*self._retry_policies]
+
def register_udf(
self,
function: Any,
@@ -1152,7 +1128,7 @@ class SparkConnectClient(object):
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
- req, self._stub, self._retry_policy,
self._builder.metadata()
+ req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
handle_response(b)
@@ -1262,7 +1238,7 @@ class SparkConnectClient(object):
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
- req, self._stub, self._retry_policy,
self._builder.metadata()
+ req, self._stub, self._retrying, self._builder.metadata()
)
for b in generator:
yield from handle_response(b)
@@ -1641,145 +1617,3 @@ class SparkConnectClient(object):
else:
# Update the server side session ID.
self._server_session_id = response.server_side_session_id
-
-
-class RetryState:
- """
- Simple state helper that captures the state between retries of the
exceptions. It
- keeps track of the last exception thrown and how many in total. When the
task
- finishes successfully done() returns True.
- """
-
- def __init__(self) -> None:
- self._exception: Optional[BaseException] = None
- self._done = False
- self._count = 0
-
- def set_exception(self, exc: BaseException) -> None:
- self._exception = exc
- self._count += 1
-
- def throw(self) -> None:
- raise self.exception()
-
- def exception(self) -> BaseException:
- if self._exception is None:
- raise RuntimeError("No exception is set")
- return self._exception
-
- def set_done(self) -> None:
- self._done = True
-
- def count(self) -> int:
- return self._count
-
- def done(self) -> bool:
- return self._done
-
-
-class AttemptManager:
- """
- Simple ContextManager that is used to capture the exception thrown inside
the context.
- """
-
- def __init__(self, check: Callable[..., bool], retry_state: RetryState) ->
None:
- self._retry_state = retry_state
- self._can_retry = check
-
- def __enter__(self) -> None:
- pass
-
- def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_val: Optional[BaseException],
- exc_tb: Optional[TracebackType],
- ) -> Optional[bool]:
- if isinstance(exc_val, BaseException):
- # Swallow the exception.
- if self._can_retry(exc_val) or isinstance(exc_val, RetryException):
- self._retry_state.set_exception(exc_val)
- return True
- # Bubble up the exception.
- return False
- else:
- self._retry_state.set_done()
- return None
-
- def is_first_try(self) -> bool:
- return self._retry_state._count == 0
-
-
-class Retrying:
- """
- This helper class is used as a generator together with a context manager to
- allow retrying exceptions in particular code blocks. The Retrying can be
configured
- with a lambda function that is can be filtered what kind of exceptions
should be
- retried.
-
- In addition, there are several parameters that are used to configure the
exponential
- backoff behavior.
-
- An example to use this class looks like this:
-
- .. code-block:: python
-
- for attempt in Retrying(can_retry=lambda x: isinstance(x,
TransientError)):
- with attempt:
- # do the work.
-
- """
-
- def __init__(
- self,
- max_retries: int,
- initial_backoff: int,
- max_backoff: int,
- backoff_multiplier: float,
- jitter: int,
- min_jitter_threshold: int,
- can_retry: Callable[..., bool] = lambda x: True,
- sleep: Callable[[float], None] = time.sleep,
- ) -> None:
- self._can_retry = can_retry
- self._max_retries = max_retries
- self._initial_backoff = initial_backoff
- self._max_backoff = max_backoff
- self._backoff_multiplier = backoff_multiplier
- self._jitter = jitter
- self._min_jitter_threshold = min_jitter_threshold
- self._sleep = sleep
-
- def __iter__(self) -> Generator[AttemptManager, None, None]:
- """
- Generator function to wrap the exception producing code block.
-
- Returns
- -------
- A generator that yields the current attempt.
- """
- retry_state = RetryState()
- next_backoff: float = self._initial_backoff
-
- if self._max_retries < 0:
- raise ValueError("Can't have negative number of retries")
-
- while not retry_state.done() and retry_state.count() <=
self._max_retries:
- # Do backoff
- if retry_state.count() > 0:
- # Randomize backoff for this iteration
- backoff = next_backoff
- next_backoff = min(self._max_backoff, next_backoff *
self._backoff_multiplier)
-
- if backoff >= self._min_jitter_threshold:
- backoff += random.uniform(0, self._jitter)
-
- logger.debug(
- f"Will retry call after {backoff} ms sleep (error:
{retry_state.exception()})"
- )
- self._sleep(backoff / 1000.0)
- yield AttemptManager(self._can_retry, retry_state)
-
- if not retry_state.done():
- # Exceeded number of retries, throw last exception we had
- retry_state.throw()
diff --git a/python/pyspark/sql/connect/client/reattach.py
b/python/pyspark/sql/connect/client/reattach.py
index 6addb5bd2c65..9fa0f2541337 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from pyspark.sql.connect.client.retries import Retrying, RetryException
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -22,7 +23,7 @@ from threading import RLock
import warnings
import uuid
from collections.abc import Generator
-from typing import Optional, Dict, Any, Iterator, Iterable, Tuple, Callable,
cast, Type, ClassVar
+from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast,
Type, ClassVar
from multiprocessing.pool import ThreadPool
import os
@@ -83,12 +84,12 @@ class ExecutePlanResponseReattachableIterator(Generator):
self,
request: pb2.ExecutePlanRequest,
stub: grpc_lib.SparkConnectServiceStub,
- retry_policy: Dict[str, Any],
+ retrying: Callable[[], Retrying],
metadata: Iterable[Tuple[str, str]],
):
ExecutePlanResponseReattachableIterator._initialize_pool_if_necessary()
self._request = request
- self._retry_policy = retry_policy
+ self._retrying = retrying
if request.operation_id:
self._operation_id = request.operation_id
else:
@@ -143,17 +144,12 @@ class ExecutePlanResponseReattachableIterator(Generator):
return ret
def _has_next(self) -> bool:
- from pyspark.sql.connect.client.core import SparkConnectClient
- from pyspark.sql.connect.client.core import Retrying
-
if self._result_complete:
# After response complete response
return False
else:
try:
- for attempt in Retrying(
- can_retry=SparkConnectClient.retry_exception,
**self._retry_policy
- ):
+ for attempt in self._retrying():
with attempt:
if self._current is None:
try:
@@ -199,16 +195,11 @@ class ExecutePlanResponseReattachableIterator(Generator):
if self._result_complete:
return
- from pyspark.sql.connect.client.core import SparkConnectClient
- from pyspark.sql.connect.client.core import Retrying
-
request = self._create_release_execute_request(until_response_id)
def target() -> None:
try:
- for attempt in Retrying(
- can_retry=SparkConnectClient.retry_exception,
**self._retry_policy
- ):
+ for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request,
metadata=self._metadata)
except Exception as e:
@@ -228,16 +219,11 @@ class ExecutePlanResponseReattachableIterator(Generator):
if self._result_complete:
return
- from pyspark.sql.connect.client.core import SparkConnectClient
- from pyspark.sql.connect.client.core import Retrying
-
request = self._create_release_execute_request(None)
def target() -> None:
try:
- for attempt in Retrying(
- can_retry=SparkConnectClient.retry_exception,
**self._retry_policy
- ):
+ for attempt in self._retrying():
with attempt:
self._stub.ReleaseExecute(request,
metadata=self._metadata)
except Exception as e:
@@ -331,10 +317,3 @@ class ExecutePlanResponseReattachableIterator(Generator):
def __del__(self) -> None:
return self.close()
-
-
-class RetryException(Exception):
- """
- An exception that can be thrown upstream when inside retry and which will
be retryable
- regardless of policy.
- """
diff --git a/python/pyspark/sql/connect/client/retries.py
b/python/pyspark/sql/connect/client/retries.py
new file mode 100644
index 000000000000..6aa959e09b5b
--- /dev/null
+++ b/python/pyspark/sql/connect/client/retries.py
@@ -0,0 +1,293 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import grpc
+import random
+import time
+import typing
+from typing import Optional, Callable, Generator, List, Type
+from types import TracebackType
+from pyspark.sql.connect.client.logging import logger
+
+"""
+This module contains the retry system. The system is designed to be
+highly customizable.
+
+A key aspect of retries is the RetryPolicy class, which describes a single policy.
+There can be more than one policy defined at the same time. Each policy
+determines which error types it can retry and how exactly they are retried.
+
+For instance, networking errors should likely be retried differently than
+a remote resource being unavailable.
+
+Given a sequence of policies, the retry logic consults them in order and uses
+the first one that accepts the error and still has retries left, keeping
+track of each policy's retry budget separately.
+"""
+
+
+class RetryPolicy:
+    """
+    Describes key aspects of a retry policy: which errors it applies to, how many
+    retries it allows, and how the waits between retries are configured.
+
+    Subclasses decide which exceptions the policy covers by overriding ``can_retry``
+    (the base implementation retries nothing). It's advised that different policies
+    are implemented as different subclasses.
+
+    Parameters
+    ----------
+    max_retries : int, optional
+        Number of retries allowed under this policy, or None for no limit.
+    initial_backoff : int
+        Wait in milliseconds before the first retry.
+    max_backoff : int, optional
+        Cap in milliseconds used by RetryPolicyState.next_attempt when computing
+        the next wait, or None.
+    backoff_multiplier : float
+        Multiplier used by RetryPolicyState.next_attempt to compute the next wait.
+    jitter : int
+        Upper bound in milliseconds of the random amount added to a wait.
+    min_jitter_threshold : int
+        Waits shorter than this many milliseconds are not jittered.
+
+    Raises
+    ------
+    ValueError
+        If ``max_retries`` is negative.
+    """
+
+    def __init__(
+        self,
+        max_retries: Optional[int] = None,
+        initial_backoff: int = 1000,
+        max_backoff: Optional[int] = None,
+        backoff_multiplier: float = 1.0,
+        jitter: int = 0,
+        min_jitter_threshold: int = 0,
+    ):
+        # A negative budget is a configuration error; reject it instead of
+        # silently treating it as "no retries".
+        if max_retries is not None and max_retries < 0:
+            raise ValueError("Can't have negative number of retries")
+        self.max_retries = max_retries
+        self.initial_backoff = initial_backoff
+        self.max_backoff = max_backoff
+        self.backoff_multiplier = backoff_multiplier
+        self.jitter = jitter
+        self.min_jitter_threshold = min_jitter_threshold
+        self._name = self.__class__.__name__
+
+    @property
+    def name(self) -> str:
+        """Policy name (the class name), used in retry log messages."""
+        return self._name
+
+    def can_retry(self, exception: BaseException) -> bool:
+        """Return True if this policy applies to ``exception``. Retries nothing by default."""
+        return False
+
+    def to_state(self) -> "RetryPolicyState":
+        """Create a fresh state object tracking this policy's budget and backoff."""
+        return RetryPolicyState(self)
+
+
+class RetryPolicyState:
+    """
+    Stateful counterpart of a RetryPolicy: tracks how many retries have been used
+    under the policy and the (un-jittered) wait to apply before the next one.
+    """
+
+    def __init__(self, policy: RetryPolicy):
+        self._policy = policy
+
+        # Will allow attempts [0, self._policy.max_retries)
+        self._attempt = 0
+        self._next_wait: float = self._policy.initial_backoff
+
+    @property
+    def policy(self) -> RetryPolicy:
+        return self._policy
+
+    @property
+    def name(self) -> str:
+        return self.policy.name
+
+    def can_retry(self, exception: BaseException) -> bool:
+        return self.policy.can_retry(exception)
+
+    def next_attempt(self) -> Optional[int]:
+        """
+        Consume one retry from this policy's budget.
+
+        Returns
+        -------
+        Randomized time (in milliseconds) to wait until this attempt
+        or None if this policy doesn't allow more retries.
+        """
+        if self.policy.max_retries is not None and self._attempt >= self.policy.max_retries:
+            # No more retries under this policy
+            return None
+
+        self._attempt += 1
+        wait_time = self._next_wait
+
+        # Calculate future backoff. The multiplier must apply even when no cap is
+        # configured; otherwise an uncapped policy would never grow its backoff.
+        next_wait = wait_time * self.policy.backoff_multiplier
+        if self.policy.max_backoff is not None:
+            next_wait = min(float(self.policy.max_backoff), next_wait)
+        self._next_wait = next_wait
+
+        # Jitter current backoff, after the future backoff was computed
+        if wait_time >= self.policy.min_jitter_threshold:
+            wait_time += random.uniform(0, self.policy.jitter)
+
+        # Round to whole number of milliseconds
+        return int(wait_time)
+
+
+class AttemptManager:
+    """
+    Context manager wrapping a single attempt of a Retrying loop.
+
+    On exit it reports the outcome to the owning Retrying instance. A clean exit
+    marks the operation as succeeded. An exception is either swallowed (when
+    Retrying accepts it as retriable) or left to propagate.
+    """
+
+    def __init__(self, retrying: "Retrying") -> None:
+        self._retrying = retrying
+
+    def __enter__(self) -> None:
+        return None
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exception: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        if not isinstance(exception, BaseException):
+            # The attempt completed without raising.
+            self._retrying.accept_succeeded()
+            return None
+        # True swallows a retriable exception; False lets it propagate.
+        return bool(self._retrying.accept_exception(exception))
+
+
+class Retrying:
+    """
+    Entry point into the retry logic.
+
+    The class accepts a list of retry policies and applies them in the given order.
+    When a retriable error occurs, the first policy that accepts it and still has
+    retries left determines how long to wait before the next attempt.
+
+    The usage of the class should be as follows:
+
+    .. code-block:: python
+
+        for attempt in Retrying(...):
+            with attempt:
+                # do something that can throw an exception
+
+    In case an error is considered retriable, it is retried based on the policies,
+    and RetriesExceeded (chained to the last error) is raised once no applicable
+    policy has retries left.
+
+    Exceptions not considered retriable are passed through transparently.
+
+    Every iteration over an instance starts with fresh policy budgets.
+    """
+
+    def __init__(
+        self,
+        policies: typing.Union[RetryPolicy, typing.Iterable[RetryPolicy]],
+        sleep: Callable[[float], None] = time.sleep,
+    ) -> None:
+        if isinstance(policies, RetryPolicy):
+            policies = [policies]
+        # Keep the policies themselves so that each iteration can rebuild fresh
+        # per-policy state (retry budgets and backoff).
+        self._retry_policies: List[RetryPolicy] = list(policies)
+        self._policies: List[RetryPolicyState] = [
+            policy.to_state() for policy in self._retry_policies
+        ]
+        self._sleep = sleep
+
+        self._exception: Optional[BaseException] = None
+        self._done = False
+
+    def can_retry(self, exception: BaseException) -> bool:
+        """Return True if any of the policies considers ``exception`` retriable."""
+        return any(policy.can_retry(exception) for policy in self._policies)
+
+    def accept_exception(self, exception: BaseException) -> bool:
+        """
+        Record ``exception`` as the latest failure if it is retriable.
+
+        Returns True when the exception should be swallowed and the attempt retried.
+        """
+        if self.can_retry(exception):
+            self._exception = exception
+            return True
+        return False
+
+    def accept_succeeded(self) -> None:
+        """Mark the operation as completed so that no further attempts are made."""
+        self._done = True
+
+    def _last_exception(self) -> BaseException:
+        if self._exception is None:
+            raise RuntimeError("No active exception")
+        return self._exception
+
+    def _wait(self) -> None:
+        """
+        Sleep according to the first applicable policy that still has retries left.
+
+        Raises
+        ------
+        RetriesExceeded
+            If every policy accepting the last exception has exhausted its budget.
+        """
+        exception = self._last_exception()
+
+        # Attempt to find a policy to wait with
+        for policy in self._policies:
+            if not policy.can_retry(exception):
+                continue
+
+            wait_time = policy.next_attempt()
+            if wait_time is not None:
+                logger.debug(
+                    f"Got error: {repr(exception)}. "
+                    + f"Will retry after {wait_time} ms (policy: {policy.name})"
+                )
+
+                self._sleep(wait_time / 1000)
+                return
+
+        # Exceeded retries
+        logger.debug(f"Given up on retrying. error: {repr(exception)}")
+        raise RetriesExceeded from exception
+
+    def __iter__(self) -> Generator[AttemptManager, None, None]:
+        """
+        Generator function to wrap the exception producing code block.
+
+        Returns
+        -------
+        A generator that yields the current attempt.
+        """
+        # Start from a clean slate on every iteration. Otherwise, iterating the same
+        # instance twice would see _done=True left over from the first loop and stop
+        # after one attempt, silently swallowing a retriable exception it raised.
+        self._policies = [policy.to_state() for policy in self._retry_policies]
+        self._exception = None
+        self._done = False
+
+        # First attempt is free, no need to do waiting.
+        yield AttemptManager(self)
+
+        while not self._done:
+            self._wait()
+            yield AttemptManager(self)
+
+
+class RetryException(Exception):
+    """
+    Exception that can be raised inside a retry block to request another attempt.
+
+    DefaultPolicy treats it as retryable unconditionally.
+    """
+
+
+class DefaultPolicy(RetryPolicy):
+    """
+    Policy installed by SparkConnectClient by default.
+
+    It retries RetryException, gRPC UNAVAILABLE errors, and gRPC INTERNAL errors
+    whose message contains INVALID_CURSOR.DISCONNECTED.
+    """
+
+    def __init__(self, **kwargs):  # type: ignore[no-untyped-def]
+        super().__init__(**kwargs)
+
+    def can_retry(self, e: BaseException) -> bool:
+        """
+        Decide whether an error raised while talking to the server can be retried.
+
+        Parameters
+        ----------
+        e : BaseException
+            The error to classify. Usually a gRPC error received from the server,
+            but any exception raised during client processing may be passed in.
+
+        Returns
+        -------
+        True if the exception can be retried, False otherwise.
+        """
+        if isinstance(e, RetryException):
+            # Explicit retry requests are always honored.
+            return True
+        if not isinstance(e, grpc.RpcError):
+            return False
+
+        status = e.code()
+        if status == grpc.StatusCode.UNAVAILABLE:
+            return True
+        # An INTERNAL error with this marker happens if another RPC preempts this RPC.
+        return status == grpc.StatusCode.INTERNAL and "INVALID_CURSOR.DISCONNECTED" in str(e)
+
+
+class RetriesExceeded(Exception):
+    """
+    Raised by Retrying when an error was retriable but every policy accepting it
+    has used up its retry budget. The last such error is chained as ``__cause__``.
+    """
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py
b/python/pyspark/sql/tests/connect/client/test_client.py
index fb137662f42f..580ebc3965bb 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -28,11 +28,13 @@ if should_test_connect:
import pandas as pd
import pyarrow as pa
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
- from pyspark.sql.connect.client.core import Retrying
- from pyspark.sql.connect.client.reattach import (
+ from pyspark.sql.connect.client.retries import (
+ Retrying,
+ DefaultPolicy,
RetryException,
- ExecutePlanResponseReattachableIterator,
+ RetriesExceeded,
)
+ from pyspark.sql.connect.client.reattach import
ExecutePlanResponseReattachableIterator
import pyspark.sql.connect.proto as proto
@@ -107,17 +109,25 @@ class SparkConnectClientTestCase(unittest.TestCase):
total_sleep += t
try:
- for attempt in Retrying(
- can_retry=SparkConnectClient.retry_exception, sleep=sleep,
**client._retry_policy
- ):
+ for attempt in Retrying(client._retry_policies, sleep=sleep):
with attempt:
raise RetryException()
- except RetryException:
+ except RetriesExceeded:
pass
# tolerated at least 10 mins of fails
self.assertGreaterEqual(total_sleep, 600)
+    def test_retry_client_unit(self):
+        # The policies handed to the client must be returned as-is, in order.
+        client = SparkConnectClient("sc://foo/;token=bar")
+        expected = [TestPolicy(), DefaultPolicy()]
+
+        client.set_retry_policies(expected)
+
+        self.assertEqual(client.get_retry_policies(), expected)
+
def test_channel_builder_with_session(self):
dummy = str(uuid.uuid4())
chan = ChannelBuilder(f"sc://foo/;session_id={dummy}")
@@ -125,18 +135,23 @@ class SparkConnectClientTestCase(unittest.TestCase):
self.assertEqual(client._session_id, chan.session_id)
+class TestPolicy(DefaultPolicy):
+    """DefaultPolicy with a small retry budget and short backoffs, for fast tests."""
+
+    def __init__(self):
+        settings = {
+            "max_retries": 3,
+            "backoff_multiplier": 4.0,
+            "initial_backoff": 10,
+            "max_backoff": 10,
+            "jitter": 10,
+            "min_jitter_threshold": 10,
+        }
+        super().__init__(**settings)
+
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientReattachTestCase(unittest.TestCase):
def setUp(self) -> None:
self.request = proto.ExecutePlanRequest()
- self.policy = {
- "max_retries": 3,
- "backoff_multiplier": 4.0,
- "initial_backoff": 10,
- "max_backoff": 10,
- "jitter": 10,
- "min_jitter_threshold": 10,
- }
+ self.retrying = lambda: Retrying(TestPolicy())
self.response = proto.ExecutePlanResponse(
response_id="1",
)
@@ -153,7 +168,7 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
def test_basic_flow(self):
stub = self._stub_with([self.response, self.finished])
- ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.policy, [])
+ ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.retrying, [])
for b in ite:
pass
@@ -171,7 +186,7 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
stub = self._stub_with([self.response, fatal])
with self.assertRaises(TestException):
- ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.policy, [])
+ ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.retrying, [])
for b in ite:
pass
@@ -190,7 +205,7 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
stub = self._stub_with(
[self.response, non_fatal], [self.response, self.response,
self.finished]
)
- ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.policy, [])
+ ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.retrying, [])
for b in ite:
pass
@@ -216,7 +231,7 @@ class SparkConnectClientReattachTestCase(unittest.TestCase):
stub = self._stub_with(
[self.response, non_fatal], [self.response, non_fatal,
self.response, self.finished]
)
- ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.policy, [])
+ ite = ExecutePlanResponseReattachableIterator(self.request, stub,
self.retrying, [])
for b in ite:
pass
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 7a224d68219b..e926eb835a80 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -34,6 +34,7 @@ from pyspark.errors import (
)
from pyspark.errors.exceptions.base import SessionNotSameException
from pyspark.sql import SparkSession as PySparkSession, Row
+from pyspark.sql.connect.client.retries import RetryPolicy, RetriesExceeded
from pyspark.sql.types import (
StructType,
StructField,
@@ -3484,128 +3485,140 @@ class
SparkConnectSessionWithOptionsTest(unittest.TestCase):
self.assertEqual(self.spark.conf.get("integer"), "1")
+class TestError(grpc.RpcError, Exception):
+    """Fake gRPC error that reports a fixed status code."""
+
+    def __init__(self, code: grpc.StatusCode):
+        self._code = code
+
+    def code(self):
+        return self._code
+
+
+class TestPolicy(RetryPolicy):
+    """
+    Policy that retries only TestError.
+
+    The initial backoff defaults to 10 ms so that tests don't spend time waiting.
+    """
+
+    def __init__(self, initial_backoff=10, **kwargs):
+        kwargs["initial_backoff"] = initial_backoff
+        super().__init__(**kwargs)
+
+    def can_retry(self, exception: BaseException):
+        return isinstance(exception, TestError)
+
+
+class TestPolicySpecificError(TestPolicy):
+    """Policy that retries only TestError instances whose code equals ``specific_code``."""
+
+    def __init__(self, specific_code: grpc.StatusCode, **kwargs):
+        super().__init__(**kwargs)
+        self.specific_code = specific_code
+
+    def can_retry(self, exception: BaseException):
+        # Retrying consults every policy for every exception, including ones without a
+        # code() method (e.g. a plain ValueError); those must not raise AttributeError.
+        return super().can_retry(exception) and exception.code() == self.specific_code
+
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ClientTests(unittest.TestCase):
- def test_retry_error_handling(self):
- # Helper class for wrapping the test.
- class TestError(grpc.RpcError, Exception):
- def __init__(self, code: grpc.StatusCode):
- self._code = code
-
- def code(self):
- return self._code
-
- def stub(retries, w, code):
- w["attempts"] += 1
- if w["attempts"] < retries:
- w["raised"] += 1
- raise TestError(code)
+class RetryTests(unittest.TestCase):
+    def setUp(self) -> None:
+        # Counts stub calls under "attempts" and raised errors under "raised".
+        self.call_wrap = defaultdict(int)
+
+    def stub(self, retries, code):
+        """Raise TestError(code) on every call before the ``retries``-th one."""
+        counters = self.call_wrap
+        counters["attempts"] += 1
+        if counters["attempts"] >= retries:
+            return
+        counters["raised"] += 1
+        raise TestError(code)
+
+ def test_simple(self):
# Check that max_retries 1 is only one retry so two attempts.
- call_wrap = defaultdict(int)
- for attempt in Retrying(
- can_retry=lambda x: True,
- max_retries=1,
- backoff_multiplier=1,
- initial_backoff=1,
- max_backoff=10,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ for attempt in Retrying(TestPolicy(max_retries=1)):
with attempt:
- stub(2, call_wrap, grpc.StatusCode.INTERNAL)
+ self.stub(2, grpc.StatusCode.INTERNAL)
- self.assertEqual(2, call_wrap["attempts"])
- self.assertEqual(1, call_wrap["raised"])
+ self.assertEqual(2, self.call_wrap["attempts"])
+ self.assertEqual(1, self.call_wrap["raised"])
+ def test_below_limit(self):
# Check that if we have less than 4 retries all is ok.
- call_wrap = defaultdict(int)
- for attempt in Retrying(
- can_retry=lambda x: True,
- max_retries=4,
- backoff_multiplier=1,
- initial_backoff=1,
- max_backoff=10,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ for attempt in Retrying(TestPolicy(max_retries=4)):
with attempt:
- stub(2, call_wrap, grpc.StatusCode.INTERNAL)
+ self.stub(2, grpc.StatusCode.INTERNAL)
- self.assertTrue(call_wrap["attempts"] < 4)
- self.assertEqual(call_wrap["raised"], 1)
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 1)
+ def test_exceed_retries(self):
# Exceed the retries.
- call_wrap = defaultdict(int)
- with self.assertRaises(TestError):
- for attempt in Retrying(
- can_retry=lambda x: True,
- max_retries=2,
- max_backoff=50,
- backoff_multiplier=1,
- initial_backoff=50,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying(TestPolicy(max_retries=2)):
with attempt:
- stub(5, call_wrap, grpc.StatusCode.INTERNAL)
+ self.stub(5, grpc.StatusCode.INTERNAL)
- self.assertTrue(call_wrap["attempts"] < 5)
- self.assertEqual(call_wrap["raised"], 3)
+ self.assertLess(self.call_wrap["attempts"], 5)
+ self.assertEqual(self.call_wrap["raised"], 3)
+    def test_throw_not_retriable_error(self):
+        # ValueError is not a TestError, so the policy must let it propagate unchanged.
+        policy = TestPolicy(max_retries=2)
+        with self.assertRaises(ValueError):
+            for attempt in Retrying(policy):
+                with attempt:
+                    raise ValueError
+
+ def test_specific_exception(self):
# Check that only specific exceptions are retried.
# Check that if we have less than 4 retries all is ok.
- call_wrap = defaultdict(int)
- for attempt in Retrying(
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
- max_retries=4,
- backoff_multiplier=1,
- initial_backoff=1,
- max_backoff=10,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ policy = TestPolicySpecificError(max_retries=4,
specific_code=grpc.StatusCode.UNAVAILABLE)
+
+ for attempt in Retrying(policy):
with attempt:
- stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
+ self.stub(2, grpc.StatusCode.UNAVAILABLE)
- self.assertTrue(call_wrap["attempts"] < 4)
- self.assertEqual(call_wrap["raised"], 1)
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 1)
+ def test_specific_exception_exceed_retries(self):
# Exceed the retries.
- call_wrap = defaultdict(int)
- with self.assertRaises(TestError):
- for attempt in Retrying(
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
- max_retries=2,
- max_backoff=50,
- backoff_multiplier=1,
- initial_backoff=50,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ policy = TestPolicySpecificError(max_retries=2,
specific_code=grpc.StatusCode.UNAVAILABLE)
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying(policy):
with attempt:
- stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
+ self.stub(5, grpc.StatusCode.UNAVAILABLE)
- self.assertTrue(call_wrap["attempts"] < 4)
- self.assertEqual(call_wrap["raised"], 3)
+ self.assertLess(self.call_wrap["attempts"], 4)
+ self.assertEqual(self.call_wrap["raised"], 3)
+ def test_rejected_by_policy(self):
# Test that another error is always thrown.
- call_wrap = defaultdict(int)
+ policy = TestPolicySpecificError(max_retries=4,
specific_code=grpc.StatusCode.UNAVAILABLE)
+
with self.assertRaises(TestError):
- for attempt in Retrying(
- can_retry=lambda x: x.code() == grpc.StatusCode.UNAVAILABLE,
- max_retries=4,
- backoff_multiplier=1,
- initial_backoff=1,
- max_backoff=10,
- jitter=0,
- min_jitter_threshold=0,
- ):
+ for attempt in Retrying(policy):
+ with attempt:
+ self.stub(5, grpc.StatusCode.INTERNAL)
+
+ self.assertEqual(self.call_wrap["attempts"], 1)
+ self.assertEqual(self.call_wrap["raised"], 1)
+
+    def test_multiple_policies(self):
+        # Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors, each under its own budget.
+        policy1 = TestPolicySpecificError(max_retries=2, specific_code=grpc.StatusCode.UNAVAILABLE)
+        policy2 = TestPolicySpecificError(max_retries=4, specific_code=grpc.StatusCode.INTERNAL)
+
+        error_supply = iter([grpc.StatusCode.UNAVAILABLE] * 2 + [grpc.StatusCode.INTERNAL] * 4)
+
+        for attempt in Retrying([policy1, policy2]):
+            with attempt:
+                error = next(error_supply, None)
+                if error is not None:
+                    raise TestError(error)
+
+        # Every scripted error was raised and retried before the loop succeeded.
+        self.assertIsNone(next(error_supply, None))
+
+ def test_multiple_policies_exceed(self):
+ policy1 = TestPolicySpecificError(max_retries=2,
specific_code=grpc.StatusCode.INTERNAL)
+ policy2 = TestPolicySpecificError(max_retries=4,
specific_code=grpc.StatusCode.INTERNAL)
+
+ with self.assertRaises(RetriesExceeded):
+ for attempt in Retrying([policy1, policy2]):
with attempt:
- stub(5, call_wrap, grpc.StatusCode.INTERNAL)
+ self.stub(10, grpc.StatusCode.INTERNAL)
- self.assertEqual(call_wrap["attempts"], 1)
- self.assertEqual(call_wrap["raised"], 1)
+ self.assertEqual(self.call_wrap["attempts"], 7)
+ self.assertEqual(self.call_wrap["raised"], 7)
@unittest.skipIf(not should_test_connect, connect_requirement_message)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]