This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f5d7aef1b184 [SPARK-46723][CONNECT][SCALA] Make addArtifact retryable
f5d7aef1b184 is described below

commit f5d7aef1b184d3ec972b96d482c02b6be573f407
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Tue Jan 16 09:07:29 2024 +0900

    [SPARK-46723][CONNECT][SCALA] Make addArtifact retryable
    
    ### What changes were proposed in this pull request?
    
    Make the addArtifact API retry on errors.
    
    Note this is a safe operation, since addArtifact is idempotent (https://github.com/apache/spark/pull/43314).
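
    As a rough sketch of the retry pattern (a hypothetical minimal helper, not the actual `GrpcRetryHandler` used by the client; the real wiring goes through `stubState.retryHandler.retry`, shown in the diff below):

    ```scala
    // Hypothetical minimal retry helper (illustration only). It shows why
    // idempotency matters: the wrapped operation may execute more than once,
    // so repeating it must be safe.
    object RetrySketch {
      def retry[T](maxAttempts: Int, canRetry: Throwable => Boolean)(op: => T): T = {
        var attempt = 0
        while (true) {
          attempt += 1
          try {
            return op // success: hand back the result
          } catch {
            case e if attempt < maxAttempts && canRetry(e) =>
              Thread.sleep(50L * attempt) // simplified linear backoff
          }
        }
        throw new IllegalStateException("unreachable")
      }
    }
    ```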
    
    ### Why are the changes needed?
    
    For the same reasons we make other APIs retryable.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes: addArtifact now retries on transient errors instead of failing immediately.
    
    ### How was this patch tested?
    
    Added a test.
    
    Also tested by hand against a custom Spark server.
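
    The added suite test injects transient failures with a gRPC `ClientInterceptor` that rejects the first attempts with `UNAVAILABLE`, so only a retrying client can complete `addArtifact`. A condensed sketch of that interceptor (the class name `FlakyInterceptor` is illustrative; the suite defines it inline):

    ```scala
    import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor, Status}

    // Condensed from the added suite test: fail the first three RPC attempts
    // with UNAVAILABLE so that only a retrying client succeeds.
    class FlakyInterceptor extends ClientInterceptor {
      private var attempt = 0
      override def interceptCall[ReqT, RespT](
          method: MethodDescriptor[ReqT, RespT],
          options: CallOptions,
          channel: Channel): ClientCall[ReqT, RespT] = {
        attempt += 1
        if (attempt <= 3) {
          throw Status.UNAVAILABLE.withDescription("injected failure").asRuntimeException()
        }
        channel.newCall(method, options)
      }
    }
    ```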
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No & never
    
    Closes #44740 from cdkrot/SPARK-46723-addartifact.
    
    Authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/util/SparkThreadUtils.scala   | 12 ++++++----
 .../connect/client/SparkConnectClientSuite.scala   | 27 ++++++++++++++++++++++
 .../spark/sql/connect/client/ArtifactManager.scala | 19 ++++++++++++++-
 .../connect/client/CustomSparkConnectStub.scala    |  2 +-
 4 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
index a5e4cef1ec1a..8b2807a80dd1 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala
@@ -42,10 +42,7 @@ private[spark] object SparkThreadUtils {
   @throws(classOf[SparkException])
   def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
-      // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
-      // See SPARK-13747.
-      val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(atMost)(awaitPermission)
+      awaitResultNoSparkExceptionConversion(awaitable, atMost)
     } catch {
       case e: SparkFatalException =>
         throw e.throwable
@@ -56,5 +53,12 @@ private[spark] object SparkThreadUtils {
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
+
+  def awaitResultNoSparkExceptionConversion[T](awaitable: Awaitable[T], atMost: Duration): T = {
+    // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+    // See SPARK-13747.
+    val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+    awaitable.result(atMost)(awaitPermission)
+  }
   // scalastyle:on awaitresult
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index d14caebe5b81..b0c4564130d3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -451,6 +451,33 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(countAttempted == 7)
   }
 
+  test("ArtifactManager retries errors") {
+    var attempt = 0
+
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .interceptor(new ClientInterceptor {
+        override def interceptCall[ReqT, RespT](
+            methodDescriptor: MethodDescriptor[ReqT, RespT],
+            callOptions: CallOptions,
+            channel: Channel): ClientCall[ReqT, RespT] = {
+          attempt += 1
+          if (attempt <= 3) {
+            throw Status.UNAVAILABLE.withDescription("").asRuntimeException()
+          }
+
+          channel.newCall(methodDescriptor, callOptions)
+        }
+      })
+      .build()
+
+    val session = SparkSession.builder().client(client).create()
+    val artifactFilePath = commonResourcePath.resolve("artifact-tests")
+    session.addArtifact(artifactFilePath.resolve("smallClassFile.class").toString)
+  }
+
   test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable 
iterator") {
     startDummyServer(0)
     client = SparkConnectClient
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index 36bc60c7d63a..6eb59bd37574 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -31,10 +31,12 @@ import scala.util.control.NonFatal
 
 import Artifact._
 import com.google.protobuf.ByteString
+import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 import org.apache.commons.codec.digest.DigestUtils.sha256Hex
 import org.apache.commons.lang3.StringUtils
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.AddArtifactsResponse
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -63,6 +65,7 @@ class ArtifactManager(
   private val CHUNK_SIZE: Int = 32 * 1024
 
   private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]
+  private[this] val stubState = stub.stubState
 
   /**
    * Register a [[ClassFinder]] for dynamically generated classes.
@@ -228,6 +231,17 @@ class ArtifactManager(
       return
     }
 
+    try {
+      stubState.retryHandler.retry {
+        addArtifactsImpl(artifacts)
+      }
+    } catch {
+      case ex: StatusRuntimeException =>
+        throw new SparkException(ex.toString, ex.getCause)
+    }
+  }
+
+  private[client] def addArtifactsImpl(artifacts: Iterable[Artifact]): Unit = {
     val promise = Promise[Seq[ArtifactSummary]]()
     val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
       private val summaries = mutable.Buffer.empty[ArtifactSummary]
@@ -284,7 +298,10 @@ class ArtifactManager(
       writeBatch()
     }
     stream.onCompleted()
-    SparkThreadUtils.awaitResult(promise.future, Duration.Inf)
+    // Don't convert to SparkException yet for the sake of retrying.
+    // retryPolicies are designed around the underlying gRPC StatusRuntimeExceptions.
+    // Convert to SparkException only if retrying fails.
+    SparkThreadUtils.awaitResultNoSparkExceptionConversion(promise.future, Duration.Inf)
     // TODO(SPARK-42658): Handle responses containing CRC failures.
   }
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
index 382bc8706955..187c2842a0bc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala
@@ -23,7 +23,7 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse
 
 private[client] class CustomSparkConnectStub(
     channel: ManagedChannel,
-    stubState: SparkConnectStubState) {
+    val stubState: SparkConnectStubState) {
 
   private val stub = SparkConnectServiceGrpc.newStub(channel)
 

