grundprinzip commented on code in PR #41829:
URL: https://github.com/apache/spark/pull/41829#discussion_r1252308844
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala:
##########
@@ -60,14 +64,66 @@ private[sql] class SparkConnectClient(
new ArtifactManager(userContext, sessionId, channel)
}
+ private val retryPolicy: SparkConnectClient.RetryPolicy =
configuration.retryPolicy
+
+ @tailrec private[client] final def retry[T](fn: => T, currentRetryNum: Int =
0): T = {
+ if (currentRetryNum > retryPolicy.maxRetries) {
+ throw new IllegalArgumentException(
+ s"The number of retries ($currentRetryNum) must not exceed " +
+ s"the maximum number of retires (${retryPolicy.maxRetries}).")
+ }
+ try {
+ return fn
+ } catch {
+ case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum <
retryPolicy.maxRetries =>
+ Thread.sleep(
+ (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
+ .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+ }
+ retry(fn, currentRetryNum + 1)
+ }
+
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
* A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
*/
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = {
artifactManager.uploadAllClassFileArtifacts()
- stub.analyzePlan(request)
+ retry {
+ stub.analyzePlan(request)
+ }
+ }
+
+ private class executeRetryIterator(
+ request: proto.ExecutePlanRequest,
+ origIterator: java.util.Iterator[proto.ExecutePlanResponse])
+ extends java.util.Iterator[proto.ExecutePlanResponse] {
+
+ private var hasNextCalled = false
+ private var iterator = origIterator
+
+ override def next(): proto.ExecutePlanResponse = {
+ iterator.next()
+ }
+
+ override def hasNext(): Boolean = {
+ if (!hasNextCalled) {
+ hasNextCalled = true
+ var firstTry = true
+ retry {
+ if (firstTry) {
+ firstTry = false
+ iterator.hasNext()
+ } else {
+ iterator = stub.executePlan(request)
+ iterator.hasNext()
+ }
+ }
+ } else {
+ iterator.hasNext()
+ }
+ }
Review Comment:
```suggestion
override def hasNext(): Boolean = {
if (!hasNextCalled) {
hasNextCalled = true
var firstTry = true
retry {
if (firstTry) {
firstTry = false
} else {
iterator = stub.executePlan(request)
}
}
}
iterator.hasNext()
}
```
There is another way to make it even a bit shorter: `hasNextCalled`
becomes an `AtomicBoolean`, and you can do the following.
```suggestion
override def hasNext(): Boolean = {
if (!hasNextCalled.compareAndSet(false, true)) {
var firstTry = true
retry {
if (firstTry) {
firstTry = false
} else {
iterator = stub.executePlan(request)
}
}
}
iterator.hasNext()
}
```
But that might be overkill.
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala:
##########
@@ -60,14 +64,66 @@ private[sql] class SparkConnectClient(
new ArtifactManager(userContext, sessionId, channel)
}
+ private val retryPolicy: SparkConnectClient.RetryPolicy =
configuration.retryPolicy
+
+ @tailrec private[client] final def retry[T](fn: => T, currentRetryNum: Int =
0): T = {
+ if (currentRetryNum > retryPolicy.maxRetries) {
+ throw new IllegalArgumentException(
+ s"The number of retries ($currentRetryNum) must not exceed " +
+ s"the maximum number of retires (${retryPolicy.maxRetries}).")
+ }
+ try {
+ return fn
+ } catch {
+ case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum <
retryPolicy.maxRetries =>
+ Thread.sleep(
+ (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
+ .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+ }
+ retry(fn, currentRetryNum + 1)
+ }
+
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
* A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
*/
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = {
artifactManager.uploadAllClassFileArtifacts()
Review Comment:
Shouldn't the `retry` also wrap the `uploadAllClassFileArtifacts()` call?
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala:
##########
@@ -60,14 +64,66 @@ private[sql] class SparkConnectClient(
new ArtifactManager(userContext, sessionId, channel)
}
+ private val retryPolicy: SparkConnectClient.RetryPolicy =
configuration.retryPolicy
+
+ @tailrec private[client] final def retry[T](fn: => T, currentRetryNum: Int =
0): T = {
+ if (currentRetryNum > retryPolicy.maxRetries) {
+ throw new IllegalArgumentException(
+ s"The number of retries ($currentRetryNum) must not exceed " +
+ s"the maximum number of retires (${retryPolicy.maxRetries}).")
+ }
+ try {
+ return fn
+ } catch {
+ case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum <
retryPolicy.maxRetries =>
+ Thread.sleep(
+ (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
+ .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+ }
+ retry(fn, currentRetryNum + 1)
+ }
+
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
* A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
*/
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = {
artifactManager.uploadAllClassFileArtifacts()
- stub.analyzePlan(request)
+ retry {
+ stub.analyzePlan(request)
+ }
+ }
+
+ private class executeRetryIterator(
Review Comment:
Please add a doc comment explaining what this class does.
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala:
##########
@@ -60,14 +64,66 @@ private[sql] class SparkConnectClient(
new ArtifactManager(userContext, sessionId, channel)
}
+ private val retryPolicy: SparkConnectClient.RetryPolicy =
configuration.retryPolicy
+
+ @tailrec private[client] final def retry[T](fn: => T, currentRetryNum: Int =
0): T = {
+ if (currentRetryNum > retryPolicy.maxRetries) {
+ throw new IllegalArgumentException(
+ s"The number of retries ($currentRetryNum) must not exceed " +
+ s"the maximum number of retires (${retryPolicy.maxRetries}).")
+ }
+ try {
+ return fn
+ } catch {
+ case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum <
retryPolicy.maxRetries =>
+ Thread.sleep(
+ (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
+ .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+ }
+ retry(fn, currentRetryNum + 1)
+ }
+
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
* A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
*/
def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse = {
artifactManager.uploadAllClassFileArtifacts()
- stub.analyzePlan(request)
+ retry {
+ stub.analyzePlan(request)
+ }
+ }
+
+ private class executeRetryIterator(
+ request: proto.ExecutePlanRequest,
+ origIterator: java.util.Iterator[proto.ExecutePlanResponse])
+ extends java.util.Iterator[proto.ExecutePlanResponse] {
+
+ private var hasNextCalled = false
+ private var iterator = origIterator
+
+ override def next(): proto.ExecutePlanResponse = {
+ iterator.next()
+ }
+
+ override def hasNext(): Boolean = {
+ if (!hasNextCalled) {
+ hasNextCalled = true
+ var firstTry = true
+ retry {
+ if (firstTry) {
+ firstTry = false
+ iterator.hasNext()
+ } else {
+ iterator = stub.executePlan(request)
Review Comment:
Actually, shouldn't this set `hasNextCalled` back to `false`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]