juliuszsompolski commented on code in PR #42399:
URL: https://github.com/apache/spark/pull/42399#discussion_r1291342200


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala:
##########
@@ -148,37 +150,62 @@ private[client] object GrpcRetryHandler extends Logging {
 
   /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
+   *
    * @param retryPolicy
    *   The retry policy
+   * @param sleep
+   *   The function which sleeps (takes number of milliseconds to sleep)
    * @param fn
    *   The function to retry.
-   * @param currentRetryNum
-   *   Current number of retries.
    * @tparam T
    *   The return type of the function.
    * @return
    *   The result of the function.
    */
-  @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, currentRetryNum: Int = 0): T = {
-    if (currentRetryNum > retryPolicy.maxRetries) {
-      throw new IllegalArgumentException(
-        s"The number of retries ($currentRetryNum) must not exceed " +
-          s"the maximum number of retires (${retryPolicy.maxRetries}).")
+  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = Thread.sleep)(
+      fn: => T): T = {
+    var currentRetryNum = 0
+    var exceptionList: Seq[Throwable] = Seq.empty
+    var nextBackoff: Duration = retryPolicy.initialBackoff
+
+    if (retryPolicy.maxRetries < 0) {
+      throw new IllegalArgumentException("Can't have negative number of retries")
     }
-    try {
-      return fn
-    } catch {
-      case NonFatal(e)
-          if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException])
-            && currentRetryNum < retryPolicy.maxRetries =>
-        logWarning(
-          s"Non fatal error during RPC execution: $e, " +
-            s"retrying (currentRetryNum=$currentRetryNum)")
-        Thread.sleep(
-          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
-            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+
+    while (currentRetryNum <= retryPolicy.maxRetries) {
+      if (currentRetryNum != 0) {

Review Comment:
   optional nit: if this were at the end of the loop instead of the start, could that avoid the `if (currentRetryNum != 0)` check?
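To make the nit concrete, here is a minimal sketch (in Python, with hypothetical names; the code under review is Scala) of the same loop with the backoff moved to the end of the body. A successful call returns, and a non-retryable error raises, before the sleep is reached, so the first-iteration guard disappears:

```python
import random
import time

def retry_with_trailing_backoff(fn, can_retry, max_retries=3,
                                initial_backoff_ms=50.0, backoff_multiplier=4.0,
                                max_backoff_ms=60000.0, jitter_ms=500.0,
                                min_jitter_threshold_ms=2000.0,
                                sleep=time.sleep):
    """Backoff at the end of the loop body: it only runs when a retryable
    failure was just recorded and another attempt will follow."""
    next_backoff = initial_backoff_ms
    current_retry_num = 0
    while current_retry_num <= max_retries:
        try:
            return fn()
        except Exception as e:
            if not (can_retry(e) and current_retry_num < max_retries):
                raise
            current_retry_num += 1
        # Reached only between attempts, so no "skip on first pass" check.
        backoff = next_backoff
        next_backoff = min(next_backoff * backoff_multiplier, max_backoff_ms)
        if backoff >= min_jitter_threshold_ms:
            backoff += random.uniform(0, jitter_ms)
        sleep(backoff / 1000.0)
```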



##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala:
##########
@@ -148,37 +150,62 @@ private[client] object GrpcRetryHandler extends Logging {
 
   /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
+   *
    * @param retryPolicy
    *   The retry policy
+   * @param sleep
+   *   The function which sleeps (takes number of milliseconds to sleep)
    * @param fn
    *   The function to retry.
-   * @param currentRetryNum
-   *   Current number of retries.
    * @tparam T
    *   The return type of the function.
    * @return
    *   The result of the function.
    */
-  @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, currentRetryNum: Int = 0): T = {
-    if (currentRetryNum > retryPolicy.maxRetries) {
-      throw new IllegalArgumentException(
-        s"The number of retries ($currentRetryNum) must not exceed " +
-          s"the maximum number of retires (${retryPolicy.maxRetries}).")
+  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = Thread.sleep)(
+      fn: => T): T = {
+    var currentRetryNum = 0
+    var exceptionList: Seq[Throwable] = Seq.empty
+    var nextBackoff: Duration = retryPolicy.initialBackoff
+
+    if (retryPolicy.maxRetries < 0) {
+      throw new IllegalArgumentException("Can't have negative number of retries")
     }
-    try {
-      return fn
-    } catch {
-      case NonFatal(e)
-          if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException])
-            && currentRetryNum < retryPolicy.maxRetries =>
-        logWarning(
-          s"Non fatal error during RPC execution: $e, " +
-            s"retrying (currentRetryNum=$currentRetryNum)")
-        Thread.sleep(
-          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
-            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+
+    while (currentRetryNum <= retryPolicy.maxRetries) {
+      if (currentRetryNum != 0) {
+        var currentBackoff = nextBackoff
+        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min retryPolicy.maxBackoff
+
+        if (currentBackoff >= retryPolicy.minJitterThreshold) {
+          currentBackoff += Random.nextDouble() * retryPolicy.jitter
+        }
+
+        sleep(currentBackoff.toMillis)
+      }
+
+      try {
+        return fn
+      } catch {
+        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries =>
+          currentRetryNum += 1
+          exceptionList = e +: exceptionList
+
+          if (currentRetryNum <= retryPolicy.maxRetries) {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"retrying (currentRetryNum=$currentRetryNum)")
+          } else {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"exceeded retries (currentRetryNum=$currentRetryNum)")
+          }
+      }
     }
-    retry(retryPolicy)(fn, currentRetryNum + 1)
+
+    val exception = exceptionList.head
+    exceptionList.tail.foreach(exception.addSuppressed(_))
+    throw exception

Review Comment:
   this is nice.
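The aggregation pattern being praised (most recent failure raised, earlier ones attached) can be sketched in Python; the `suppressed` attribute here is a hypothetical stand-in for Java/Scala's `Throwable.addSuppressed`, which Python does not have built in:

```python
def raise_aggregated(exception_list):
    """exception_list mirrors the Scala code: failures were prepended,
    so index 0 is the most recent one and becomes the raised error."""
    primary = exception_list[0]
    # Hypothetical attribute standing in for addSuppressed; Python has
    # no built-in suppressed-exception list on BaseException.
    primary.suppressed = list(exception_list[1:])
    raise primary
```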



##########
python/pyspark/sql/connect/client/core.py:
##########
@@ -1630,35 +1643,25 @@ def __iter__(self) -> Generator[AttemptManager, None, None]:
         A generator that yields the current attempt.
         """
         retry_state = RetryState()
-        while True:
-            # Check if the operation was completed successfully.
-            if retry_state.done():
-                break
-
-            # If the number of retries have exceeded the maximum allowed retries.
-            if retry_state.count() > self._max_retries:
-                e = retry_state.exception()
-                if e is not None:
-                    raise e
-                else:
-                    raise PySparkRuntimeError(
-                        error_class="EXCEED_RETRY",
-                        message_parameters={},
-                    )
+        next_backoff: float = self._initial_backoff
+
+        if self._max_retries < 0:
+            raise ValueError("Can't have negative number of retries")
 
+        while not retry_state.done() and retry_state.count() <= self._max_retries:
             # Do backoff
             if retry_state.count() > 0:
-                backoff = random.randrange(
-                    0,
-                    int(
-                        min(
-                            self._initial_backoff * self._backoff_multiplier ** retry_state.count(),
-                            self._max_backoff,
-                        )
-                    ),
-                )
-                logger.debug(f"Retrying call after {backoff} ms sleep")
-                # Pythons sleep takes seconds as arguments.
-                time.sleep(backoff / 1000.0)
+                # Randomize backoff for this iteration
+                backoff = next_backoff
+                next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier)
+
+                if backoff >= self._min_jitter_threshold:
+                    backoff += random.uniform(0, self._jitter)
 
+                logger.debug(f"Retrying call after {backoff} ms sleep")
+                self._sleep(backoff / 1000.0)
             yield AttemptManager(self._can_retry, retry_state)
+
+        if not retry_state.done():
+            # Exceeded number of retries, throw last exception we had
+            raise retry_state.exception()

Review Comment:
   defer to @HyukjinKwon and @itholic whether we want
   ```
                   e = retry_state.exception()
                   if e is not None:
                       raise e
                   else:
                       raise PySparkRuntimeError(
                           error_class="EXCEED_RETRY",
                           message_parameters={},
                       )
   ```
   here.
   If we don't use it here, we can also just remove it from error_classes.py, because this was the only place that used it.
   Even if we do throw an exception here, isn't there some more generic "pythonic" internal error we could throw without needing its own custom class? In Scala I would throw IllegalStateException here.
   
   In Scala, the code is simpler and everything happens inside one retry function, so there is no conceivable way the exception could be missing (exceptionList is a local variable that stays in scope the whole time). In Python, the code is more complicated, with multiple objects like AttemptManager and RetryState involved, so maybe an internal error exception is justified in case something goes wrong.
   
   I defer to @HyukjinKwon and @itholic.
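For concreteness, a hedged sketch of the fallback under discussion, using a plain RuntimeError instead of a dedicated error class; the RetryState stub below only imitates the `exception()` accessor assumed from the diff:

```python
class RetryState:
    """Minimal stand-in for the real RetryState in core.py."""
    def __init__(self):
        self._exception = None

    def set_exception(self, exc):
        self._exception = exc

    def exception(self):
        return self._exception

def raise_final(retry_state):
    # Prefer the recorded exception; fall back to a generic internal
    # error (the Python analogue of Scala's IllegalStateException) in
    # case the multi-object bookkeeping lost it.
    e = retry_state.exception()
    if e is not None:
        raise e
    raise RuntimeError("retries exceeded, but no exception was recorded")
```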



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

