itholic commented on code in PR #42399:
URL: https://github.com/apache/spark/pull/42399#discussion_r1290821732


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala:
##########
@@ -148,37 +151,58 @@ private[client] object GrpcRetryHandler extends Logging {
 
   /**
    * Retries the given function with exponential backoff according to the 
client's retryPolicy.
+   *
    * @param retryPolicy
    *   The retry policy
+   * @param sleep
+   *   The function which sleeps (takes number of milliseconds to sleep)
    * @param fn
    *   The function to retry.
-   * @param currentRetryNum
-   *   Current number of retries.
    * @tparam T
    *   The return type of the function.
    * @return
    *   The result of the function.
    */
-  @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, 
currentRetryNum: Int = 0): T = {
-    if (currentRetryNum > retryPolicy.maxRetries) {
-      throw new IllegalArgumentException(
-        s"The number of retries ($currentRetryNum) must not exceed " +
-          s"the maximum number of retires (${retryPolicy.maxRetries}).")
-    }
-    try {
-      return fn
-    } catch {
-      case NonFatal(e)
-          if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException])
-            && currentRetryNum < retryPolicy.maxRetries =>
-        logWarning(
-          s"Non fatal error during RPC execution: $e, " +
-            s"retrying (currentRetryNum=$currentRetryNum)")
-        Thread.sleep(
-          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
-            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = 
Thread.sleep)(
+      fn: => T): T = {
+    var currentRetryNum = 0
+    var lastException: Throwable = null
+    var nextBackoff: Duration = retryPolicy.initialBackoff
+
+    while (currentRetryNum <= retryPolicy.maxRetries) {
+      if (currentRetryNum != 0) {
+        var currentBackoff = nextBackoff
+        if (currentBackoff >= retryPolicy.minJitterThreshold) {
+          currentBackoff += Random.nextDouble() * retryPolicy.jitter
+        }
+        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min 
retryPolicy.maxBackoff
+
+        sleep(currentBackoff.toMillis)
+      }
+
+      try {
+        return fn
+      } catch {
+        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < 
retryPolicy.maxRetries =>
+          currentRetryNum += 1
+          lastException = e
+
+          if (currentRetryNum <= retryPolicy.maxRetries) {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"retrying (currentRetryNum=$currentRetryNum)")
+          } else {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"exceeded retries (currentRetryNum=$currentRetryNum)")
+          }
+      }
     }
-    retry(retryPolicy)(fn, currentRetryNum + 1)
+
+    throw new SparkException(
+      errorClass = "EXCEED_RETRY",

Review Comment:
   I just had a brief offline discussion, and it seems that it doesn't matter 
if the error class name is duplicated: the two can be distinguished by the 
package name, since we manage JVM and Python errors in different ways. 
Sorry for the confusion 🙏 
   
   So, we can just keep using `EXCEED_RETRY` if you believe that it is the 
error class name that best expresses the error occurring here. However, I 
think the error message still needs to be modified: using the Python error 
message as it is does not seem appropriate for this situation.



##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala:
##########
@@ -148,37 +151,58 @@ private[client] object GrpcRetryHandler extends Logging {
 
   /**
    * Retries the given function with exponential backoff according to the 
client's retryPolicy.
+   *
    * @param retryPolicy
    *   The retry policy
+   * @param sleep
+   *   The function which sleeps (takes number of milliseconds to sleep)
    * @param fn
    *   The function to retry.
-   * @param currentRetryNum
-   *   Current number of retries.
    * @tparam T
    *   The return type of the function.
    * @return
    *   The result of the function.
    */
-  @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, 
currentRetryNum: Int = 0): T = {
-    if (currentRetryNum > retryPolicy.maxRetries) {
-      throw new IllegalArgumentException(
-        s"The number of retries ($currentRetryNum) must not exceed " +
-          s"the maximum number of retires (${retryPolicy.maxRetries}).")
-    }
-    try {
-      return fn
-    } catch {
-      case NonFatal(e)
-          if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException])
-            && currentRetryNum < retryPolicy.maxRetries =>
-        logWarning(
-          s"Non fatal error during RPC execution: $e, " +
-            s"retrying (currentRetryNum=$currentRetryNum)")
-        Thread.sleep(
-          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
-            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = 
Thread.sleep)(
+      fn: => T): T = {
+    var currentRetryNum = 0
+    var lastException: Throwable = null
+    var nextBackoff: Duration = retryPolicy.initialBackoff
+
+    while (currentRetryNum <= retryPolicy.maxRetries) {
+      if (currentRetryNum != 0) {
+        var currentBackoff = nextBackoff
+        if (currentBackoff >= retryPolicy.minJitterThreshold) {
+          currentBackoff += Random.nextDouble() * retryPolicy.jitter
+        }
+        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min 
retryPolicy.maxBackoff
+
+        sleep(currentBackoff.toMillis)
+      }
+
+      try {
+        return fn
+      } catch {
+        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < 
retryPolicy.maxRetries =>
+          currentRetryNum += 1
+          lastException = e
+
+          if (currentRetryNum <= retryPolicy.maxRetries) {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"retrying (currentRetryNum=$currentRetryNum)")
+          } else {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"exceeded retries (currentRetryNum=$currentRetryNum)")
+          }
+      }
     }
-    retry(retryPolicy)(fn, currentRetryNum + 1)
+
+    throw new SparkException(
+      errorClass = "EXCEED_RETRY",

Review Comment:
   I just had a brief offline discussion, and it seems that it doesn't matter 
if the error class name is duplicated: the two can be distinguished by the 
package name, since we manage JVM and Python errors in different ways. 
Sorry for the confusion 🙏 
   
   So, we can just keep using `EXCEED_RETRY` if you believe that it is the 
error class name that best expresses the error occurring here. However, I 
think the error message still needs to be modified: using the Python error 
message as it is does not seem appropriate for this situation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to