This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c291e1165c1 [SPARK-46042][FOLLOWUP][CONNECT] Test and adapt to streaming RPC behavior change from grpc 1.56 to 1.59
9c291e1165c1 is described below

commit 9c291e1165c145104becf69ecafcdba2914c29f1
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Wed Nov 22 16:59:56 2023 -0800

    [SPARK-46042][FOLLOWUP][CONNECT] Test and adapt to streaming RPC behavior change from grpc 1.56 to 1.59
    
    ### What changes were proposed in this pull request?
    
    This is a followup to https://github.com/apache/spark/pull/43955
    
    In grpc 1.56, when calling a server streaming RPC like `client.execute`, the
    request would not be sent to the server until the first interaction with the
    resulting iterator (`next` or `hasNext`). In grpc 1.59, it appears that the
    request is sent to the server immediately. See
    https://github.com/grpc/grpc-java/issues/10697.
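    
    The change is observable with a raw blocking stub (a minimal sketch, not
    part of this patch; the channel setup mirrors the tests added below, and
    host "ABC" is just an unresolvable address):
    ```
    val channel = SparkConnectClient.Configuration(host = "ABC").createChannel()
    val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
    val request = proto.ExecutePlanRequest.newBuilder().build()
    // grpc 1.56: nothing has been sent yet; the RPC starts lazily.
    // grpc 1.59: the ExecutePlan request is sent to the server right here.
    val iter = stub.executePlan(request)
    // In both versions, connection errors surface only from the iterator.
    iter.hasNext()
    ```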
    
    I propose to embrace this new behavior. It was odd that calling
    `client.execute()` would not send the query to the server until the first
    `hasNext()`. All the public APIs except `toLocalIterator` consume the result
    immediately, so this change does not affect user-facing behavior, except for
    the `toLocalIterator` change described below.
    
    ### Why are the changes needed?
    
    Test and fix behavior after grpc upgrade.
    
    Verified that reverting grpc to 1.56 makes the new test fail, because the
    request is no longer submitted by merely calling `client.execute()`:
    ```
    [info] - Execute is sent eagerly to the server upon iterator creation *** FAILED *** (30 seconds, 445 milliseconds)
    [info]   The code passed to eventually never returned normally. Attempted 1941 times over 30.009993892 seconds. Last failure message: List() had length 0 instead of expected length 1. (SparkConnectServiceE2ESuite.scala:39)
    ```
    
    The new tests added in SparkConnectClientSuite verify that the error is
    thrown from the response iterator, not directly when the iterator is
    created. GrpcRetryHandler and ExecutePlanResponseReattachableIterator rely
    on that assumption; since it still holds, they need no changes.
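    
    For illustration, a hedged sketch of that retry pattern (not the actual
    GrpcRetryHandler code; `process`, `isRetryable` and `maxRetries` are
    hypothetical placeholders):
    ```
    import io.grpc.StatusRuntimeException
    // The stub call itself does not throw, even if the server is unreachable;
    // errors surface from the returned iterator, where the handler can retry.
    var attempt = 0
    var done = false
    while (!done) {
      val iter = rawBlockingStub.executePlan(initialRequest)
      try {
        while (iter.hasNext) process(iter.next()) // hypothetical consumer
        done = true
      } catch {
        case e: StatusRuntimeException if isRetryable(e) && attempt < maxRetries =>
          attempt += 1 // creating a new iterator above resends the request
      }
    }
    ```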
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    Calling `dataset.toLocalIterator` in Spark Connect used to not send the
    query to the server until the resulting iterator was first accessed with
    `hasNext` or `next`. Now the query is submitted upon the call to
    `toLocalIterator`.
    Note that the behavior of `toLocalIterator` in Spark Connect already
    differed from non-Spark Connect: without Spark Connect, the query is
    executed fully lazily, submitting every result task as a separate job on
    demand as the iterator progresses; in Spark Connect, once the query is
    submitted to the server, the execution is not lazy.
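    
    A minimal sketch of the visible timing difference (hypothetical snippet;
    `spark` is a Spark Connect session):
    ```
    val it = spark.range(10).toDF().toLocalIterator()
    // Before this change: the query was sent only on the first
    // it.hasNext() / it.next() call.
    // After this change: the query was already submitted to the server
    // by the toLocalIterator() call above.
    while (it.hasNext) println(it.next())
    ```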
    
    ### How was this patch tested?
    
    Added tests and tweaked existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    I am using GitHub Copilot in my IDE. It helps auto-complete some trivial
    boilerplate code.
    Generated-by: GitHub Copilot
    
    Closes #43962 from juliuszsompolski/SPARK-46042-followup.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../connect/client/SparkConnectClientSuite.scala   | 75 ++++++++++++++++++++++
 .../ExecutePlanResponseReattachableIterator.scala  | 19 +-----
 .../execution/ReattachableExecuteSuite.scala       | 21 +++---
 .../service/SparkConnectServiceE2ESuite.scala      | 50 +++++++--------
 4 files changed, 111 insertions(+), 54 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index e226484d87a0..698457ddb91d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -26,6 +26,9 @@ import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescr
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.connect.proto
@@ -482,6 +485,78 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     iter.foreach(_ => ())
     assert(reattachableIter.resultComplete)
   }
+
+  test("GRPC stub unary call throws error immediately") {
+    // Spark Connect error retry handling depends on the error being returned from the unary
+    // call immediately.
+    val channel = SparkConnectClient.Configuration(host = "ABC").createChannel()
+    val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+    // The request is invalid, but it shouldn't even reach the server.
+    val request = proto.AnalyzePlanRequest.newBuilder().build()
+
+    // calling unary call immediately throws connection exception
+    val ex = intercept[StatusRuntimeException] {
+      stub.analyzePlan(request)
+    }
+    assert(ex.getMessage.contains("UNAVAILABLE: Unable to resolve host ABC"))
+  }
+
+  test("GRPC stub server streaming call throws error on first next() / 
hasNext()") {
+    // Spark Connect error retry handling depends on the error being returned 
from the response
+    // iterator and not immediately upon iterator creation.
+    val channel = SparkConnectClient.Configuration(host = 
"ABC").createChannel()
+    val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+    // The request is invalid, but it shouldn't even reach the server.
+    val request = proto.ExecutePlanRequest.newBuilder().build()
+
+    // creating the iterator doesn't throw exception
+    val iter = stub.executePlan(request)
+    // error is thrown only when the iterator is open.
+    val ex = intercept[StatusRuntimeException] {
+      iter.hasNext()
+    }
+    assert(ex.getMessage.contains("UNAVAILABLE: Unable to resolve host ABC"))
+  }
+
+  test("GRPC stub client streaming call throws error on first client request 
sent") {
+    // Spark Connect error retry handling depends on the error being returned 
from the response
+    // iterator and not immediately upon iterator creation or request being 
sent.
+    val channel = SparkConnectClient.Configuration(host = 
"ABC").createChannel()
+    val stub = proto.SparkConnectServiceGrpc.newStub(channel)
+
+    var onNextResponse: Option[proto.AddArtifactsResponse] = None
+    var onErrorThrowable: Option[Throwable] = None
+    var onCompletedCalled: Boolean = false
+
+    val responseObserver = new StreamObserver[proto.AddArtifactsResponse] {
+      override def onNext(value: proto.AddArtifactsResponse): Unit = {
+        onNextResponse = Some(value)
+      }
+
+      override def onError(t: Throwable): Unit = {
+        onErrorThrowable = Some(t)
+      }
+
+      override def onCompleted(): Unit = {
+        onCompletedCalled = true
+      }
+    }
+
+    // calling client streaming call doesn't throw exception
+    val observer = stub.addArtifacts(responseObserver)
+
+    // but exception will get returned on the responseObserver.
+    Eventually.eventually(timeout(30.seconds)) {
+      assert(onNextResponse == None)
+      assert(onErrorThrowable.isDefined)
+      assert(onErrorThrowable.get.getMessage.contains("UNAVAILABLE: Unable to resolve host ABC"))
+      assert(onCompletedCalled == false)
+    }
+
+    // despite that, requests can be sent to the request observer without error being thrown.
+    observer.onNext(proto.AddArtifactsRequest.newBuilder().build())
+    observer.onCompleted()
+  }
 }
 
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index dff9a53991f8..5854a9225dbe 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -101,20 +101,7 @@ class ExecutePlanResponseReattachableIterator(
   // throw error on first iter.hasNext() or iter.next()
   // Visible for testing.
   private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
-    Some(makeLazyIter(rawBlockingStub.executePlan(initialRequest)))
-
-  // Creates a request that contains the query and returns a stream of `ExecutePlanResponse`.
-  // After upgrading gRPC from 1.56.0 to 1.59.3, it makes the first request when
-  // the stream is created, but here the code here assumes that no request is made before
-  // that, see also SPARK-46042
-  private def makeLazyIter(f: => java.util.Iterator[proto.ExecutePlanResponse])
-      : java.util.Iterator[proto.ExecutePlanResponse] = {
-    new java.util.Iterator[proto.ExecutePlanResponse] {
-      private lazy val internalIter = f
-      override def hasNext: Boolean = internalIter.hasNext
-      override def next(): proto.ExecutePlanResponse = internalIter.next
-    }
-  }
+    Some(rawBlockingStub.executePlan(initialRequest))
 
   // Server side session ID, used to detect if the server side session changed. This is set upon
   // receiving the first response from the server.
@@ -241,7 +228,7 @@ class ExecutePlanResponseReattachableIterator(
   private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = {
     try {
       if (iter.isEmpty) {
-        iter = Some(makeLazyIter(rawBlockingStub.reattachExecute(createReattachExecuteRequest())))
+        iter = Some(rawBlockingStub.reattachExecute(createReattachExecuteRequest()))
       }
       iterFun(iter.get)
     } catch {
@@ -254,7 +241,7 @@ class ExecutePlanResponseReattachableIterator(
             ex)
         }
         // Try a new ExecutePlan, and throw upstream for retry.
-        iter = Some(makeLazyIter(rawBlockingStub.executePlan(initialRequest)))
+        iter = Some(rawBlockingStub.executePlan(initialRequest))
         val error = new RetryException()
         error.addSuppressed(ex)
         throw error
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 784b978f447d..2a6b89620886 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -40,8 +40,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val reattachableIter = getReattachableIterator(iter)
       val initialInnerIter = reattachableIter.innerIterator
 
-      // open the iterator
-      iter.next()
+      iter.next() // open iterator, guarantees that the RPC reached the server
       // expire all RPCs on server
       SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
       assertEventuallyNoActiveRpcs()
@@ -59,7 +58,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
   test("raw interrupted RPC results in INVALID_CURSOR.DISCONNECTED error") {
     withRawBlockingStub { stub =>
       val iter = stub.executePlan(buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY)))
-      iter.next() // open the iterator
+      iter.next() // open iterator, guarantees that the RPC reached the server
       // interrupt all RPCs on server
       SparkConnectService.executionManager.interruptAllRPCs()
       assertEventuallyNoActiveRpcs()
@@ -76,11 +75,11 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val operationId = UUID.randomUUID().toString
       val iter = stub.executePlan(
         buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
-      iter.next() // open the iterator
+      iter.next() // open the iterator, guarantees that the RPC reached the server
 
       // send reattach
       val iter2 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
-      iter2.next() // open the iterator
+      iter2.next() // open the iterator, guarantees that the RPC reached the server
 
       // should result in INVALID_CURSOR.DISCONNECTED error on the original iterator
       val e = intercept[StatusRuntimeException] {
@@ -91,7 +90,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       // send another reattach
       val iter3 = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
       assert(iter3.hasNext)
-      iter3.next() // open the iterator
+      iter3.next() // open the iterator, guarantees that the RPC reached the server
 
       // should result in INVALID_CURSOR.DISCONNECTED error on the previous reattach iterator
       val e2 = intercept[StatusRuntimeException] {
@@ -108,7 +107,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val initialInnerIter = reattachableIter.innerIterator
       val operationId = getReattachableIterator(iter).operationId
 
-      // open the iterator
+      // open the iterator, guarantees that the RPC reached the server
       iter.next()
 
       // interrupt all RPCs on server
@@ -129,7 +128,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val initialInnerIter = reattachableIter.innerIterator
       val operationId = getReattachableIterator(iter).operationId
 
-      // open the iterator
+      // open the iterator, guarantees that the RPC reached the server
       val response = iter.next()
 
       // Send another Reattach request, it should preempt this request with an
@@ -152,7 +151,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val plan = buildPlan("select * from range(100000)")
       val iter = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
       val operationId = getReattachableIterator(iter).operationId
-      // open the iterator
+      // open the iterator, guarantees that the RPC reached the server
       iter.next()
       // disconnect and remove on server
       SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
@@ -190,7 +189,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       val initialInnerIter = reattachableIter.innerIterator
       val operationId = getReattachableIterator(iter).operationId
 
-      assert(iter.hasNext) // open iterator
+      assert(iter.hasNext) // open iterator, guarantees that the RPC reached the server
       val execution = getExecutionHolder
       assert(execution.responseObserver.releasedUntilIndex == 0)
 
@@ -249,7 +248,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
         .get(Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE)
         .toLong
 
-      iter.hasNext // open iterator
+      iter.hasNext // open iterator, guarantees that the RPC reached the server
       val execution = getExecutionHolder
 
       // after consuming enough from the iterator, server should automatically start releasing
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 7bd4bd742c95..c0b7eaf5823d 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -32,16 +32,25 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
   // were all already in the buffer.
   val BIG_ENOUGH_QUERY = "select * from range(1000000)"
 
+  test("Execute is sent eagerly to the server upon iterator creation") {
+    // This behavior changed with grpc upgrade from 1.56.0 to 1.59.0.
+    // Testing to be aware of future changes.
+    withClient { client =>
+      val query = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      // just creating the iterator triggers query to be sent to server.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
+      }
+      assert(query.hasNext)
+    }
+  }
+
   test("ReleaseSession releases all queries and does not allow more requests 
in the session") {
     withClient { client =>
       val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
       val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
-      val query3 = client.execute(buildPlan("select 1"))
-      // just creating the iterator is lazy, trigger query1 and query2 to be sent.
-      query1.hasNext
-      query2.hasNext
       Eventually.eventually(timeout(eventuallyTimeout)) {
-        SparkConnectService.executionManager.listExecuteHolders.length == 2
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 2)
       }
 
       // Close session
@@ -51,8 +60,7 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
 
       // Check that queries get cancelled
       Eventually.eventually(timeout(eventuallyTimeout)) {
-        SparkConnectService.executionManager.listExecuteHolders.length == 0
-        // SparkConnectService.sessionManager.
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 0)
       }
 
       // query1 and query2 could get either an:
@@ -75,13 +83,6 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
         query2Error.getMessage.contains("OPERATION_CANCELED") ||
           
query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED"))
 
-      // query3 has not been submitted before, so it should now fail with SESSION_CLOSED
-      // TODO(SPARK-46042) Reenable a `releaseSession` test case in SparkConnectServiceE2ESuite
-      val query3Error = intercept[SparkException] {
-        query3.hasNext
-      }
-      assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
-
       // No other requests should be allowed in the session, failing with SESSION_CLOSED
       val requestError = intercept[SparkException] {
         client.interruptAll()
@@ -99,18 +100,15 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
       withClient(sessionId = sessionIdB, userId = userIdB) { clientB =>
         val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY))
         val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY))
-        // just creating the iterator is lazy, trigger query1 and query2 to be sent.
-        queryA.hasNext
-        queryB.hasNext
         Eventually.eventually(timeout(eventuallyTimeout)) {
-          SparkConnectService.executionManager.listExecuteHolders.length == 2
+          assert(SparkConnectService.executionManager.listExecuteHolders.length == 2)
         }
         // Close session A
         clientA.releaseSession()
 
         // A's query gets kicked out.
         Eventually.eventually(timeout(eventuallyTimeout)) {
-          SparkConnectService.executionManager.listExecuteHolders.length == 1
+          assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
         }
         val queryAError = intercept[SparkException] {
           while (queryA.hasNext) queryA.next()
@@ -151,7 +149,7 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
     withClient(sessionId = sessionId, userId = userId) { client =>
       // this will create the session, and then ReleaseSession at the end of withClient.
       val query = client.execute(buildPlan("SELECT 1"))
-      query.hasNext // trigger execution
+      query.hasNext // guarantees the request was received by server.
       client.releaseSession()
     }
     withClient(sessionId = sessionId, userId = userId) { client =>
@@ -169,17 +167,17 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
     val userId = "Y"
     withClient(sessionId = sessionId, userId = userId) { client =>
       val query = client.execute(buildPlan("SELECT 1"))
-      query.hasNext // trigger execution
+      query.hasNext // guarantees the request was received by server.
       client.releaseSession()
     }
     withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client =>
       val query = client.execute(buildPlan("SELECT 1"))
-      query.hasNext // trigger execution
+      query.hasNext // guarantees the request was received by server.
       client.releaseSession()
     }
     withClient(sessionId = sessionId, userId = "YY") { client =>
       val query = client.execute(buildPlan("SELECT 1"))
-      query.hasNext // trigger execution
+      query.hasNext // guarantees the request was received by server.
       client.releaseSession()
     }
   }
@@ -188,10 +186,9 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
     withRawBlockingStub { stub =>
       val iter =
         stub.executePlan(buildExecutePlanRequest(buildPlan("select * from 
range(1000000)")))
-      iter.hasNext
       val execution = eventuallyGetExecutionHolder
       Eventually.eventually(timeout(30.seconds)) {
-        execution.eventsManager.status == ExecuteStatus.Finished
+        assert(execution.eventsManager.status == ExecuteStatus.Finished)
       }
     }
   }
@@ -199,10 +196,9 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
   test("SPARK-45133 local relation should reach FINISHED state when results 
are not consumed") {
     withClient { client =>
       val iter = client.execute(buildLocalRelation((1 to 1000000).map(i => (i, i + 1))))
-      iter.hasNext
       val execution = eventuallyGetExecutionHolder
       Eventually.eventually(timeout(30.seconds)) {
-        execution.eventsManager.status == ExecuteStatus.Finished
+        assert(execution.eventsManager.status == ExecuteStatus.Finished)
       }
     }
   }

