This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f5d7aef1b184 [SPARK-46723][CONNECT][SCALA] Make addArtifact retryable f5d7aef1b184 is described below commit f5d7aef1b184d3ec972b96d482c02b6be573f407 Author: Alice Sayutina <alice.sayut...@databricks.com> AuthorDate: Tue Jan 16 09:07:29 2024 +0900 [SPARK-46723][CONNECT][SCALA] Make addArtifact retryable ### What changes were proposed in this pull request? Make the addArtifact API retry on errors. Note that this is a safe operation, since addArtifact is idempotent (https://github.com/apache/spark/pull/43314). ### Why are the changes needed? For the same reasons that we make other APIs retryable. ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Added a test. Also tested by hand against a custom Spark server. ### Was this patch authored or co-authored using generative AI tooling? No & never Closes #44740 from cdkrot/SPARK-46723-addartifact. Authored-by: Alice Sayutina <alice.sayut...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/util/SparkThreadUtils.scala | 12 ++++++---- .../connect/client/SparkConnectClientSuite.scala | 27 ++++++++++++++++++++++ .../spark/sql/connect/client/ArtifactManager.scala | 19 ++++++++++++++- .../connect/client/CustomSparkConnectStub.scala | 2 +- 4 files changed, 54 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala index a5e4cef1ec1a..8b2807a80dd1 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala @@ -42,10 +42,7 @@ private[spark] object SparkThreadUtils { @throws(classOf[SparkException]) def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { try { - // `awaitPermission` is not actually used anywhere so it's safe to pass in null 
here. - // See SPARK-13747. - val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - awaitable.result(atMost)(awaitPermission) + awaitResultNoSparkExceptionConversion(awaitable, atMost) } catch { case e: SparkFatalException => throw e.throwable @@ -56,5 +53,12 @@ private[spark] object SparkThreadUtils { throw new SparkException("Exception thrown in awaitResult: ", t) } } + + def awaitResultNoSparkExceptionConversion[T](awaitable: Awaitable[T], atMost: Duration): T = { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.result(atMost)(awaitPermission) + } // scalastyle:on awaitresult } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index d14caebe5b81..b0c4564130d3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -451,6 +451,33 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(countAttempted == 7) } + test("ArtifactManager retries errors") { + var attempt = 0 + + startDummyServer(0) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .interceptor(new ClientInterceptor { + override def interceptCall[ReqT, RespT]( + methodDescriptor: MethodDescriptor[ReqT, RespT], + callOptions: CallOptions, + channel: Channel): ClientCall[ReqT, RespT] = { + attempt += 1; + if (attempt <= 3) { + throw Status.UNAVAILABLE.withDescription("").asRuntimeException() + } + + channel.newCall(methodDescriptor, callOptions) + } + }) + .build() + + val session = 
SparkSession.builder().client(client).create() + val artifactFilePath = commonResourcePath.resolve("artifact-tests") + session.addArtifact(artifactFilePath.resolve("smallClassFile.class").toString) + } + test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable iterator") { startDummyServer(0) client = SparkConnectClient diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 36bc60c7d63a..6eb59bd37574 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -31,10 +31,12 @@ import scala.util.control.NonFatal import Artifact._ import com.google.protobuf.ByteString +import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.apache.commons.lang3.StringUtils +import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.AddArtifactsResponse import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary @@ -63,6 +65,7 @@ class ArtifactManager( private val CHUNK_SIZE: Int = 32 * 1024 private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder] + private[this] val stubState = stub.stubState /** * Register a [[ClassFinder]] for dynamically generated classes. 
@@ -228,6 +231,17 @@ class ArtifactManager( return } + try { + stubState.retryHandler.retry { + addArtifactsImpl(artifacts) + } + } catch { + case ex: StatusRuntimeException => + throw new SparkException(ex.toString, ex.getCause) + } + } + + private[client] def addArtifactsImpl(artifacts: Iterable[Artifact]): Unit = { val promise = Promise[Seq[ArtifactSummary]]() val responseHandler = new StreamObserver[proto.AddArtifactsResponse] { private val summaries = mutable.Buffer.empty[ArtifactSummary] @@ -284,7 +298,10 @@ class ArtifactManager( writeBatch() } stream.onCompleted() - SparkThreadUtils.awaitResult(promise.future, Duration.Inf) + // Don't convert to SparkException yet for the sake of retrying. + // retryPolicies are designed around underlying grpc StatusRuntimeException's. + // Convert to sparkException only if retrying fails. + SparkThreadUtils.awaitResultNoSparkExceptionConversion(promise.future, Duration.Inf) // TODO(SPARK-42658): Handle responses containing CRC failures. } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala index 382bc8706955..187c2842a0bc 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala @@ -23,7 +23,7 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse private[client] class CustomSparkConnectStub( channel: ManagedChannel, - stubState: SparkConnectStubState) { + val stubState: SparkConnectStubState) { private val stub = SparkConnectServiceGrpc.newStub(channel) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org