juliuszsompolski commented on code in PR #43757:
URL: https://github.com/apache/spark/pull/43757#discussion_r1392602943


##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala:
##########
@@ -151,125 +150,72 @@ private[sql] class GrpcRetryHandler(
 private[sql] object GrpcRetryHandler extends Logging {
 
   /**
-   * Retries the given function with exponential backoff according to the client's retryPolicy.
-   *
-   * @param retryPolicy
-   *   The retry policy
+   * Class managing the state of the retrying logic during a single retryable block.
+   * @param retryPolicies
+   *   list of policies to apply (in order)
    * @param sleep
-   *   The function which sleeps (takes number of milliseconds to sleep)
+   *   typically Thread.sleep
    * @param fn
-   *   The function to retry.
+   *   the function to compute
    * @tparam T
-   *   The return type of the function.
-   * @return
-   *   The result of the function.
+   *   result of function fn
    */
-  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = Thread.sleep)(
-      fn: => T): T = {
-    var currentRetryNum = 0
-    var exceptionList: Seq[Throwable] = Seq.empty
-    var nextBackoff: Duration = retryPolicy.initialBackoff
-
-    if (retryPolicy.maxRetries < 0) {
-      throw new IllegalArgumentException("Can't have negative number of retries")
-    }
-
-    while (currentRetryNum <= retryPolicy.maxRetries) {
-      if (currentRetryNum != 0) {
-        var currentBackoff = nextBackoff
-        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min retryPolicy.maxBackoff
+  class Retrying[T](retryPolicies: Seq[RetryPolicy], sleep: Long => Unit, fn: => T) {
+    private var currentRetryNum: Int = 0
+    private var exceptionList: Seq[Throwable] = Seq.empty
+    private val policies: Seq[RetryPolicy.RetryPolicyState] = retryPolicies.map(_.toState)
 
-        if (currentBackoff >= retryPolicy.minJitterThreshold) {
-          currentBackoff += Random.nextDouble() * retryPolicy.jitter
-        }
-
-        sleep(currentBackoff.toMillis)
-      }
+    def canRetry(throwable: Throwable): Boolean = {
+      policies.exists(p => p.canRetry(throwable))
+    }
 
+    def makeAttempt(): Option[T] = {
       try {
-        return fn
+        Some(fn)
       } catch {
-        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries =>
+        case NonFatal(e) if canRetry(e) =>
           currentRetryNum += 1
           exceptionList = e +: exceptionList
-
-          if (currentRetryNum <= retryPolicy.maxRetries) {
-            logWarning(
-              s"Non-Fatal error during RPC execution: $e, " +
-                s"retrying (currentRetryNum=$currentRetryNum)")
-          } else {
-            logWarning(
-              s"Non-Fatal error during RPC execution: $e, " +
-                s"exceeded retries (currentRetryNum=$currentRetryNum)")
-          }
+          None
       }
     }
 
-    val exception = exceptionList.head
-    exceptionList.tail.foreach(exception.addSuppressed(_))
-    throw exception
-  }
+    def waitAfterAttempt(): Unit = {
+      // find policy which will accept this exception
+      val lastException = exceptionList.head

Review Comment:
   Could you actually match for RetryException here and just retry it immediately, without waiting, instead of making it part of the default policy?
   We want RetryException to always be retried, even if someone overrides the default policy, so it should be handled outside any policy.
   
   The way RetryException is used right now is as a convenient way to throw and let the control flow handle the retry. It's OK if RetryException does not count towards the number of retried exceptions and there is no wait before retrying it. We could also add this to https://github.com/apache/spark/pull/43800/ for the Python side.
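   For illustration, a minimal sketch of how `waitAfterAttempt` could special-case RetryException outside the policies (the names `policies`, `exceptionList` and `sleep` come from the `Retrying` class in the diff above; everything else is a suggestion, not this PR's code):
   ```scala
   def waitAfterAttempt(): Unit = {
     val lastException = exceptionList.head
     lastException match {
       case _: RetryPolicy.RetryException =>
         // Always retried, outside any policy: no wait, and it does not consume
         // an attempt from any policy's retry budget.
         ()
       case _ =>
         // Otherwise, find the first policy that accepts this exception and let
         // it decide how long to wait before the next attempt.
         policies.find(_.canRetry(lastException)).flatMap(_.nextAttempt()) match {
           case Some(wait) => sleep(wait.toMillis)
           case None =>
             // No policy accepts the exception, or its retry budget is exhausted:
             // rethrow, attaching earlier failures as suppressed exceptions.
             exceptionList.tail.foreach(lastException.addSuppressed(_))
             throw lastException
         }
     }
   }
   ```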



##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala:
##########
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client
+
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.util.Random
+
+import io.grpc.{Status, StatusRuntimeException}
+
+/**
+ * [[RetryPolicy]] configures the retry mechanism in [[GrpcRetryHandler]]
+ *
+ * @param maxRetries
+ * Maximum number of retries.
+ * @param initialBackoff
+ * Start value of the exponential backoff (ms).
+ * @param maxBackoff
+ * Maximal value of the exponential backoff (ms).
+ * @param backoffMultiplier
+ * Multiplicative base of the exponential backoff.
+ * @param canRetry
+ * Function that determines whether a retry is to be performed in the event of an error.
+ */
+case class RetryPolicy(
+  maxRetries: Option[Int] = None,
+  initialBackoff: FiniteDuration = FiniteDuration(1000, "ms"),
+  maxBackoff: Option[FiniteDuration] = None,
+  backoffMultiplier: Double = 1.0,
+  jitter: FiniteDuration = FiniteDuration(0, "s"),
+  minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"),
+  canRetry: Throwable => Boolean,
+  name: String) {
+
+  def getName: String = name
+
+  def toState: RetryPolicy.RetryPolicyState = new RetryPolicy.RetryPolicyState(this)
+}
+
+object RetryPolicy {
+  def defaultPolicy(): RetryPolicy = RetryPolicy(
+    name = "DefaultPolicy",
+    // Please synchronize changes here with Python side:
+    // pyspark/sql/connect/client/core.py

Review Comment:
   changes with https://github.com/apache/spark/pull/43800



##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala:
##########
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client
+
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.util.Random
+
+import io.grpc.{Status, StatusRuntimeException}
+
+/**
+ * [[RetryPolicy]] configures the retry mechanism in [[GrpcRetryHandler]]
+ *
+ * @param maxRetries
+ * Maximum number of retries.
+ * @param initialBackoff
+ * Start value of the exponential backoff (ms).
+ * @param maxBackoff
+ * Maximal value of the exponential backoff (ms).
+ * @param backoffMultiplier
+ * Multiplicative base of the exponential backoff.
+ * @param canRetry
+ * Function that determines whether a retry is to be performed in the event of an error.
+ */
+case class RetryPolicy(
+  maxRetries: Option[Int] = None,
+  initialBackoff: FiniteDuration = FiniteDuration(1000, "ms"),
+  maxBackoff: Option[FiniteDuration] = None,
+  backoffMultiplier: Double = 1.0,
+  jitter: FiniteDuration = FiniteDuration(0, "s"),
+  minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"),
+  canRetry: Throwable => Boolean,
+  name: String) {
+
+  def getName: String = name
+
+  def toState: RetryPolicy.RetryPolicyState = new RetryPolicy.RetryPolicyState(this)
+}
+
+object RetryPolicy {
+  def defaultPolicy(): RetryPolicy = RetryPolicy(
+    name = "DefaultPolicy",
+    // Please synchronize changes here with Python side:
+    // pyspark/sql/connect/client/core.py
+    //
+    // Note: these constants are selected so that the maximum tolerated wait is guaranteed
+    // to be at least 10 minutes
+    maxRetries = Some(15),
+    initialBackoff = FiniteDuration(50, "ms"),
+    maxBackoff = Some(FiniteDuration(1, "min")),
+    backoffMultiplier = 4.0,
+    jitter = FiniteDuration(500, "ms"),
+    minJitterThreshold = FiniteDuration(2, "s"),
+    canRetry = defaultPolicyRetryException)
+
+  // list of policies to be used by this client
+  def defaultPolicies(): Seq[RetryPolicy] = List(defaultPolicy())
+
+  // represents a state of the specific policy
+  // (how many retries have happened and how much to wait until next one)
+  class RetryPolicyState(val policy: RetryPolicy) {
+    private var numberAttempts = 0
+    private var nextWait: Duration = policy.initialBackoff
+
+    // return waiting time until next attempt, or None if max retries have been exceeded
+    def nextAttempt(): Option[Duration] = {
+      if (policy.maxRetries.isDefined && numberAttempts >= policy.maxRetries.get) {
+        return None
+      }
+
+      numberAttempts += 1
+
+      var currentWait = nextWait
+      nextWait = nextWait * policy.backoffMultiplier
+      if (policy.maxBackoff.isDefined) {
+        nextWait = nextWait min policy.maxBackoff.get
+      }
+
+      if (currentWait >= policy.minJitterThreshold) {
+        currentWait += Random.nextDouble() * policy.jitter
+      }
+
+      Some(currentWait)
+    }
+
+    def canRetry(throwable: Throwable): Boolean = policy.canRetry(throwable)
+
+    def getName: String = policy.getName
+  }
+
+  /**
+   * Default canRetry in [[RetryPolicy]].
+   *
+   * @param e
+   * The exception to check.
+   * @return
+   * true if the exception is a [[StatusRuntimeException]] with code UNAVAILABLE.
+   */
+  private[client] def defaultPolicyRetryException(e: Throwable): Boolean = {
+    e match {
+      case _: RetryPolicy.RetryException => true
+      case e: StatusRuntimeException =>
+        val statusCode: Status.Code = e.getStatus.getCode
+
+        if (statusCode == Status.Code.INTERNAL) {
+          val msg: String = e.toString
+
+          // This error happens if another RPC preempts this RPC.
+          if (msg.contains("INVALID_CURSOR.DISCONNECTED")) {
+            return true
+          }
+        }
+
+        if (statusCode == Status.Code.UNAVAILABLE) {
+          return true
+        }
+        false
+      case _ => false
+    }
+  }
+
+  /**
+   * An exception that can be thrown upstream when inside retry and which will always be retryable
+   */
+  class RetryException extends Throwable
+}

Review Comment:
   CI will be unhappy about the newline at the end of the file.



##########
connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala:
##########
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client
+
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.util.Random
+
+import io.grpc.{Status, StatusRuntimeException}
+
+/**
+ * [[RetryPolicy]] configures the retry mechanism in [[GrpcRetryHandler]]
+ *
+ * @param maxRetries
+ * Maximum number of retries.
+ * @param initialBackoff
+ * Start value of the exponential backoff (ms).
+ * @param maxBackoff
+ * Maximal value of the exponential backoff (ms).
+ * @param backoffMultiplier
+ * Multiplicative base of the exponential backoff.
+ * @param canRetry
+ * Function that determines whether a retry is to be performed in the event of an error.
+ */
+case class RetryPolicy(
+  maxRetries: Option[Int] = None,
+  initialBackoff: FiniteDuration = FiniteDuration(1000, "ms"),
+  maxBackoff: Option[FiniteDuration] = None,
+  backoffMultiplier: Double = 1.0,
+  jitter: FiniteDuration = FiniteDuration(0, "s"),
+  minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"),
+  canRetry: Throwable => Boolean,
+  name: String) {
+
+  def getName: String = name
+
+  def toState: RetryPolicy.RetryPolicyState = new RetryPolicy.RetryPolicyState(this)
+}

Review Comment:
   When it's a case class, you cannot extend it to override `toState` like you mentioned.
   Maybe either:
   -> make it a regular class, so it can be overridden, or
   -> keep it a case class and add `toState: () => RetryPolicy.RetryPolicyState = (() => new RetryPolicy.RetryPolicyState(this))` as a case class argument, so that different policies can pass a different constructor, potentially constructing a subclass of RetryPolicyState.
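   A rough sketch of the second option (illustrative only, not this PR's code). One caveat: a constructor parameter's default value cannot reference the instance being constructed, so the factory would take the policy as an argument rather than close over `this`. `toStateFn` and `MyRetryPolicyState` are made-up names, and most fields are omitted for brevity:
   ```scala
   case class RetryPolicy(
       canRetry: Throwable => Boolean,
       name: String,
       // Hypothetical parameter: lets each policy choose which (subclass of)
       // RetryPolicyState tracks its retry state.
       toStateFn: RetryPolicy => RetryPolicy.RetryPolicyState =
         p => new RetryPolicy.RetryPolicyState(p)) {

     def getName: String = name

     def toState: RetryPolicy.RetryPolicyState = toStateFn(this)
   }

   // Usage: a policy backed by a custom state class (MyRetryPolicyState is
   // hypothetical) could then be created as
   //   RetryPolicy.defaultPolicy().copy(toStateFn = p => new MyRetryPolicyState(p))
   ```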


