This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 45f25dce182 [SPARK-43331][CONNECT] Add Spark Connect SparkSession.interruptAll
45f25dce182 is described below
commit 45f25dce18253a352cd7af4d91d4dcda743f5612
Author: Juliusz Sompolski <[email protected]>
AuthorDate: Wed May 3 11:48:16 2023 -0400
[SPARK-43331][CONNECT] Add Spark Connect SparkSession.interruptAll
### What changes were proposed in this pull request?
Currently, queries that are run using Spark Connect cannot be interrupted.
Even when the RPC connection is broken, the Spark jobs on the server continue
running.
This PR proposes a
```
rpc Interrupt(InterruptRequest) returns (InterruptResponse) {}
```
server RPC API, that can be called from the client as
`SparkSession.interruptAll()` to interrupt all actively running Spark Jobs from
ExecutePlan executions. In most user scenarios, SparkSession is not used for
multiple executions concurrently, but is used sequentially, so `interruptAll()`
should serve a big chunk of user needs. It can also be used to clean up.
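The request/response round trip described above can be sketched in plain Python. This is a simplified model under stated assumptions, not the generated protobuf or gRPC code: `InterruptRequest`, `InterruptResponse` and `InterruptType` mirror the message names added to `base.proto`, while `FakeServer` and the free function `interrupt_all` are hypothetical stand-ins for the service endpoint and the client call.

```python
from dataclasses import dataclass
from enum import Enum


class InterruptType(Enum):
    # Mirrors INTERRUPT_TYPE_UNSPECIFIED / INTERRUPT_TYPE_ALL in base.proto.
    UNSPECIFIED = 0
    ALL = 1


@dataclass
class InterruptRequest:
    session_id: str
    user_id: str
    interrupt_type: InterruptType
    client_type: str = ""  # optional, intended for logging only


@dataclass
class InterruptResponse:
    session_id: str


class FakeServer:
    """Hypothetical stand-in for the Interrupt RPC endpoint (no gRPC involved)."""

    def __init__(self):
        self.interrupted_sessions = []

    def interrupt(self, request: InterruptRequest) -> InterruptResponse:
        # Dispatch on the interrupt type; only ALL is supported in this sketch.
        if request.interrupt_type is InterruptType.ALL:
            self.interrupted_sessions.append((request.user_id, request.session_id))
        else:
            raise ValueError(f"Unknown InterruptType {request.interrupt_type}!")
        # The response echoes the session id so the client can verify it.
        return InterruptResponse(session_id=request.session_id)


def interrupt_all(server: FakeServer, session_id: str, user_id: str) -> InterruptResponse:
    """Build an ALL request, send it, and check the echoed session id."""
    request = InterruptRequest(session_id, user_id, InterruptType.ALL, "sketch-client")
    response = server.interrupt(request)
    if response.session_id != session_id:
        raise RuntimeError("Received incorrect session identifier for request")
    return response
```

The call is unary (one request, one response), so the client only learns that the interrupt was delivered, not whether any particular query actually stopped.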
To keep track of executions, we introduce `ExecutionHolder` to hold the
execution state, and make `SessionHolder` keep track of the executions
currently running in the session. In this first PR, the interrupt only
interrupts running Spark Jobs. As such, it is to a degree best effort, because
it will not interrupt commands that don't run Spark Jobs, and it will not
interrupt anything if no Spark Job is running when the interrupt is
received by the server, and the command will con [...]
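The bookkeeping described above can be sketched as a small registry. This is a minimal Python model, assuming a lock-protected dict in place of the `ConcurrentHashMap` used in the diff; `cancel_job_group` is a hypothetical callback standing in for `SparkContext.cancelJobGroup`, and returning the list of failed operation ids is a choice made for this sketch (the Scala code logs a warning and returns nothing).

```python
import threading
import uuid


class ExecutePlanHolder:
    """Holds the state of one execution (named as in the diff)."""

    def __init__(self, operation_id, cancel_job_group):
        self.operation_id = operation_id
        # Hypothetical callback; stands in for cancelling the Spark job group.
        self._cancel_job_group = cancel_job_group

    def interrupt(self):
        self._cancel_job_group(self.operation_id)


class SessionHolder:
    """Tracks the executions currently running in one session."""

    def __init__(self, user_id, session_id):
        self.user_id = user_id
        self.session_id = session_id
        self._lock = threading.Lock()
        self._executions = {}

    def create_execute_plan_holder(self, cancel_job_group):
        operation_id = str(uuid.uuid4())
        holder = ExecutePlanHolder(operation_id, cancel_job_group)
        with self._lock:
            self._executions[operation_id] = holder
        return holder

    def remove_execute_plan_holder(self, operation_id):
        with self._lock:
            self._executions.pop(operation_id, None)

    def interrupt_all(self):
        # Snapshot under the lock, then interrupt outside it.
        with self._lock:
            running = list(self._executions.values())
        failed = []
        for execution in running:
            try:
                execution.interrupt()
            except Exception:
                # Swallow the error and move on to the next execution.
                failed.append(execution.operation_id)
        return failed
```

Registering the holder before the work starts and removing it in a `finally` block, as the handler in the diff does, keeps the registry from accumulating finished executions.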
Future work I plan to design and work on will involve:
* Interrupting any execution. This will involve moving execution from the
GRPC handler thread handling ExecutePlan to launching it in a separate thread
that can be interrupted. `ExecutionHolder`
* Interrupting executions selectively. This will involve exposing the
operationId to the user.
* (Refactor) Some cleanup is needed around using SessionHolder. Currently,
sessionId, userId, and various session data are passed around separately.
SessionHolder may be passed instead, and then also be extended with more useful
APIs for management of the session.
### Why are the changes needed?
Need to have APIs to be able to interrupt running queries in Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes.
Users of Spark Connect can now call `interruptAll()` on the client
`SparkSession` object to send an interrupt RPC to the server, which will
interrupt the running queries.
In follow-up PRs, this will be extended to the Python client, and to
interrupt more than just Spark Jobs.
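Because the interrupt is a no-op when no job is running, a caller may need to repeat it until the query actually stops, which is what the E2E tests in this commit do. The pattern can be sketched as follows, assuming a hypothetical `FakeSession` that blocks in `collect()` and honours `interrupt_all()` only while a query is running; no Spark server is involved and none of these names are part of the real client API.

```python
import threading


class QueryCancelled(Exception):
    pass


class FakeSession:
    """Hypothetical stand-in for a Spark Connect session; no server involved."""

    def __init__(self):
        self._cancel = threading.Event()
        self._running = threading.Event()

    def collect(self, timeout=5.0):
        self._cancel.clear()
        self._running.set()
        try:
            # Block like a long-running query until cancelled or timed out.
            if self._cancel.wait(timeout):
                raise QueryCancelled("cancelled")
            return "finished"
        finally:
            self._running.clear()

    def interrupt_all(self):
        # Best effort: a no-op unless a query is running right now.
        if self._running.is_set():
            self._cancel.set()


def run_with_background_interrupt(session, poll_interval=0.05):
    """Run a foreground query while a background thread keeps interrupting."""
    done = threading.Event()

    def interrupter():
        # Keep interrupting until the foreground query has ended.
        while not done.wait(poll_interval):
            session.interrupt_all()

    worker = threading.Thread(target=interrupter, daemon=True)
    worker.start()
    try:
        return session.collect()
    except QueryCancelled as exc:
        return str(exc)
    finally:
        done.set()
        worker.join()
```

A single interrupt sent before the query starts would be lost, which is why the sketch polls instead of calling `interrupt_all()` once.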
### How was this patch tested?
Added E2E tests to ClientE2ETestSuite for scala connect client.
Added unit tests to test_client for python connect client.
Closes #41005 from juliuszsompolski/connect-cancelall.
Authored-by: Juliusz Sompolski <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 13 +++
.../sql/connect/client/SparkConnectClient.scala | 11 +++
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 70 +++++++++++++++
.../src/main/protobuf/spark/connect/base.proto | 34 +++++++
.../sql/connect/service/ExecutePlanHolder.scala | 41 +++++++++
.../spark/sql/connect/service/SessionHolder.scala | 63 +++++++++++++
.../service/SparkConnectInterruptHandler.scala | 45 ++++++++++
.../sql/connect/service/SparkConnectService.scala | 24 +++--
.../service/SparkConnectStreamHandler.scala | 47 ++++++----
python/pyspark/errors/error_classes.py | 5 ++
python/pyspark/sql/connect/client.py | 42 +++++++++
python/pyspark/sql/connect/proto/base_pb2.py | 37 +++++++-
python/pyspark/sql/connect/proto/base_pb2.pyi | 100 +++++++++++++++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 45 ++++++++++
python/pyspark/sql/connect/session.py | 3 +
python/pyspark/sql/tests/connect/test_client.py | 14 +++
16 files changed, 568 insertions(+), 26 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 461b18ec9c1..48640878211 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -536,6 +536,19 @@ class SparkSession private[sql] (
planIdGenerator.set(0)
}
+ /**
+ * Interrupt all operations of this session currently running on the connected server.
+ *
+ * TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from
+ * ExecutePlan requests. If an operation is not running a Spark Job, it becomes a noop and the
+ * operation will continue afterwards, possibly with more Spark Jobs.
+ *
+ * @since 3.5.0
+ */
+ def interruptAll(): Unit = {
+ client.interruptAll()
+ }
+
/**
* Synonym for `close()`.
*
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 1d47f3e663f..ca00eb74a20 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -184,6 +184,17 @@ private[sql] class SparkConnectClient(
analyze(request)
}
+ private[sql] def interruptAll(): proto.InterruptResponse = {
+ val builder = proto.InterruptRequest.newBuilder()
+ val request = builder
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+ .setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
+ .build()
+ stub.interrupt(request)
+ }
+
def copy(): SparkConnectClient = {
new SparkConnectClient(userContext, channelBuilder, userAgent)
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index abeeaf7e483..8a01f828ef0 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -20,6 +20,9 @@ import java.io.{ByteArrayOutputStream, PrintStream}
import java.nio.file.Files
import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.util.{Failure, Success}
import io.grpc.StatusRuntimeException
import java.util.Properties
@@ -27,6 +30,7 @@ import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalactic.TolerantNumerics
+import org.scalatest.concurrent.Eventually._
import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -36,6 +40,7 @@ import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSpa
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.util.ThreadUtils
class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
@@ -867,6 +872,71 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
}
}
+
+ test("interrupt all - background queries, foreground interrupt") {
+ val session = spark
+ import session.implicits._
+ implicit val ec = ExecutionContext.global
+ val q1 = Future {
+ spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
+ }
+ val q2 = Future {
+ spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
+ }
+ var q1Interrupted = false
+ var q2Interrupted = false
+ var error: Option[String] = None
+ q1.onComplete {
+ case Success(_) =>
+ error = Some("q1 shouldn't have finished!")
+ case Failure(t) if t.getMessage.contains("cancelled") =>
+ q1Interrupted = true
+ case Failure(t) =>
+ error = Some("unexpected failure in q1: " + t.toString)
+ }
+ q2.onComplete {
+ case Success(_) =>
+ error = Some("q2 shouldn't have finished!")
+ case Failure(t) if t.getMessage.contains("cancelled") =>
+ q2Interrupted = true
+ case Failure(t) =>
+ error = Some("unexpected failure in q2: " + t.toString)
+ }
+ // 20 seconds is < 30 seconds the queries should be running,
+ // because it should be interrupted sooner
+ eventually(timeout(20.seconds), interval(1.seconds)) {
+ // keep interrupting every second, until both queries get interrupted.
+ spark.interruptAll()
+ assert(q1Interrupted)
+ assert(q2Interrupted)
+ assert(error.isEmpty)
+ }
+ }
+
+ test("interrupt all - foreground queries, background interrupt") {
+ val session = spark
+ import session.implicits._
+ implicit val ec = ExecutionContext.global
+
+ @volatile var finished = false
+ val interruptor = Future {
+ eventually(timeout(20.seconds), interval(1.seconds)) {
+ spark.interruptAll()
+ assert(finished)
+ }
+ finished
+ }
+ val e1 = intercept[io.grpc.StatusRuntimeException] {
+ spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
+ }
+ assert(e1.getMessage.contains("cancelled"))
+ val e2 = intercept[io.grpc.StatusRuntimeException] {
+ spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
+ }
+ assert(e2.getMessage.contains("cancelled"))
+ finished = true
+ assert(ThreadUtils.awaitResult(interruptor, 10.seconds) == true)
+ }
}
private[sql] case class MyType(id: Long, a: Double, b: Double)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 1f572178404..b304a8aea63 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -582,6 +582,37 @@ message ArtifactStatusesResponse {
map<string, ArtifactStatus> statuses = 1;
}
+message InterruptRequest {
+ // (Required)
+ //
+ // The session_id specifies a spark session for a user id (which is specified
+ // by user_context.user_id). The session_id is set by the client to be able to
+ // collate streaming responses from different queries within the dedicated session.
+ string session_id = 1;
+
+ // (Required) User context
+ UserContext user_context = 2;
+
+ // Provides optional information about the client sending the request. This field
+ // can be used for language or version specific information and is only intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 3;
+
+ // (Required) The type of interrupt to execute.
+ InterruptType interrupt_type = 4;
+
+ enum InterruptType {
+ INTERRUPT_TYPE_UNSPECIFIED = 0;
+
+ // Interrupt all running executions within session with provided session_id.
+ INTERRUPT_TYPE_ALL = 1;
+ }
+}
+
+message InterruptResponse {
+ string session_id = 1;
+}
+
// Main interface for the SparkConnect service.
service SparkConnectService {
@@ -602,5 +633,8 @@ service SparkConnectService {
// Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]
rpc ArtifactStatus(ArtifactStatusesRequest) returns (ArtifactStatusesResponse) {}
+
+ // Interrupts running executions
+ rpc Interrupt(InterruptRequest) returns (InterruptResponse) {}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
new file mode 100644
index 00000000000..a3c17b9826e
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import org.apache.spark.connect.proto
+
+/**
+ * Object used to hold the Spark Connect execution state.
+ */
+case class ExecutePlanHolder(
+ operationId: String,
+ sessionHolder: SessionHolder,
+ request: proto.ExecutePlanRequest) {
+
+ val jobGroupId =
+ s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}"
+
+ def interrupt(): Unit = {
+ // TODO/WIP: This only interrupts active Spark jobs that are actively running.
+ // This would then throw the error from ExecutePlan and terminate it.
+ // But if the query is not running a Spark job, but executing code on Spark driver, this
+ // would be a noop and the execution will keep running.
+ sessionHolder.session.sparkContext.cancelJobGroup(jobGroupId)
+ }
+
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
new file mode 100644
index 00000000000..613d7a38e9e
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+
+import collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Object used to hold the Spark Connect session state.
+ */
+case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
+ extends Logging {
+
+ val executePlanOperations: ConcurrentMap[String, ExecutePlanHolder] =
+ new ConcurrentHashMap[String, ExecutePlanHolder]()
+
+ private[connect] def createExecutePlanHolder(
+ request: proto.ExecutePlanRequest): ExecutePlanHolder = {
+
+ val operationId = UUID.randomUUID().toString
+ val executePlanHolder = ExecutePlanHolder(operationId, this, request)
+ assert(executePlanOperations.putIfAbsent(operationId, executePlanHolder) == null)
+ executePlanHolder
+ }
+
+ private[connect] def removeExecutePlanHolder(operationId: String): Unit = {
+ executePlanOperations.remove(operationId)
+ }
+
+ private[connect] def interruptAll(): Unit = {
+ executePlanOperations.asScala.values.foreach { execute =>
+ // Eat exception while trying to interrupt a given execution and move forward.
+ try {
+ execute.interrupt()
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Exception $e while trying to interrupt execution ${execute.operationId}")
+ }
+ }
+ }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
new file mode 100644
index 00000000000..b0923e277e4
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+
+class SparkConnectInterruptHandler(responseObserver: StreamObserver[proto.InterruptResponse])
+ extends Logging {
+
+ def handle(v: proto.InterruptRequest): Unit = {
+ val sessionHolder =
+ SparkConnectService
+ .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+
+ v.getInterruptType match {
+ case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL =>
+ sessionHolder.interruptAll()
+ case other =>
+ throw new UnsupportedOperationException(s"Unknown InterruptType $other!")
+ }
+
+ val builder = proto.InterruptResponse.newBuilder().setSessionId(v.getSessionId)
+
+ responseObserver.onNext(builder.build())
+ responseObserver.onCompleted()
+ }
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index bfe1512a49e..b444fc67ce1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -245,15 +245,23 @@ class SparkConnectService(debug: Boolean)
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId)
}
-}
-/**
- * Object used for referring to SparkSessions in the SessionCache.
- *
- * @param userId
- * @param session
- */
-case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
+ /**
+ * This is the entry point for calls interrupting running executions.
+ */
+ override def interrupt(
+ request: proto.InterruptRequest,
+ responseObserver: StreamObserver[proto.InterruptResponse]): Unit = {
+ try {
+ new SparkConnectInterruptHandler(responseObserver).handle(request)
+ } catch
+ handleError(
+ "interrupt",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
+}
/**
* Static instance of the SparkConnectService.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index c544f484381..062ef892979 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
@@ -47,18 +48,30 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
extends Logging {
def handle(v: ExecutePlanRequest): Unit = SparkConnectArtifactManager.withArtifactClassLoader {
- val session =
- SparkConnectService
- .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
- .session
- session.withActive {
+ val sessionHolder = SparkConnectService
+ .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+ val session = sessionHolder.session
- // Add debug information to the query execution so that the jobs are traceable.
- try {
- val debugString =
+ session.withActive {
+ val debugString =
+ try {
Utils.redact(
session.sessionState.conf.stringRedactionPattern,
ProtoUtils.abbreviate(v).toString)
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Fail to extract debug information", e)
+ "UNKNOWN"
+ }
+
+ val executeHolder = sessionHolder.createExecutePlanHolder(v)
+ session.sparkContext.setJobGroup(
+ executeHolder.jobGroupId,
+ s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}",
+ interruptOnCancel = true)
+
+ try {
+ // Add debug information to the query execution so that the jobs are traceable.
session.sparkContext.setLocalProperty(
"callSite.short",
s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
@@ -66,15 +79,19 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
"callSite.long",
StringUtils.abbreviate(debugString, 2048))
} catch {
- case e: Throwable =>
- logWarning("Fail to extract or attach the debug information", e)
+ case NonFatal(e) =>
+ logWarning("Fail to attach the debug information", e)
}
- v.getPlan.getOpTypeCase match {
- case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
- case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
- case _ =>
- throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
+ try {
+ v.getPlan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
+ case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+ case _ =>
+ throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
+ }
+ } finally {
+ sessionHolder.removeExecutePlanHolder(executeHolder.operationId)
}
}
}
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 2f559b6a538..3f52a14a607 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -566,6 +566,11 @@ ERROR_CLASSES_JSON = """
"Unknown explain mode: '<explain_mode>'. Accepted explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'."
]
},
+ "UNKNOWN_INTERRUPT_TYPE" : {
+ "message" : [
+ "Unknown interrupt type: '<interrupt_type>'. Accepted interrupt types are 'all'."
+ ]
+ },
"UNKNOWN_RESPONSE" : {
"message" : [
"Unknown response: <response>."
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 954ec0dfd04..ffcc4f768ae 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -1099,6 +1099,48 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
+ def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+ req = pb2.InterruptRequest()
+ req.session_id = self._session_id
+ req.client_type = self._builder.userAgent
+ if interrupt_type == "all":
+ req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+ else:
+ raise PySparkValueError(
+ error_class="UNKNOWN_INTERRUPT_TYPE",
+ message_parameters={
+ "interrupt_type": str(interrupt_type),
+ },
+ )
+ if self._user_id:
+ req.user_context.user_id = self._user_id
+ return req
+
+ def interrupt_all(self) -> None:
+ """
+ Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
+
+ Returns
+ -------
+ None
+ """
+ req = self._interrupt_request("all")
+ try:
+ for attempt in Retrying(
+ can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+ ):
+ with attempt:
+ resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+ if resp.session_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for request:"
+ f"{resp.session_id} != {self._session_id}"
+ )
+ return
+ raise SparkConnectException("Invalid state during retry exception handling.")
+ except Exception as error:
+ self._handle_error(error)
+
def _handle_error(self, error: Exception) -> NoReturn:
"""
Handle errors that occur during RPC calls.
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index b6d86b3efbd..997f59911da 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -38,7 +38,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
@@ -121,9 +121,12 @@ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS = _ARTIFACTSTATUSESRESPONSE.nested_type
_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY = _ARTIFACTSTATUSESRESPONSE.nested_types_by_name[
"StatusesEntry"
]
+_INTERRUPTREQUEST = DESCRIPTOR.message_types_by_name["InterruptRequest"]
+_INTERRUPTRESPONSE = DESCRIPTOR.message_types_by_name["InterruptResponse"]
_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE = _ANALYZEPLANREQUEST_EXPLAIN.enum_types_by_name[
"ExplainMode"
]
+_INTERRUPTREQUEST_INTERRUPTTYPE = _INTERRUPTREQUEST.enum_types_by_name["InterruptType"]
Plan = _reflection.GeneratedProtocolMessageType(
"Plan",
(_message.Message,),
@@ -747,6 +750,28 @@ _sym_db.RegisterMessage(ArtifactStatusesResponse)
_sym_db.RegisterMessage(ArtifactStatusesResponse.ArtifactStatus)
_sym_db.RegisterMessage(ArtifactStatusesResponse.StatusesEntry)
+InterruptRequest = _reflection.GeneratedProtocolMessageType(
+ "InterruptRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _INTERRUPTREQUEST,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.InterruptRequest)
+ },
+)
+_sym_db.RegisterMessage(InterruptRequest)
+
+InterruptResponse = _reflection.GeneratedProtocolMessageType(
+ "InterruptResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _INTERRUPTRESPONSE,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.InterruptResponse)
+ },
+)
+_sym_db.RegisterMessage(InterruptResponse)
+
_SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"]
if _descriptor._USE_C_DESCRIPTORS == False:
@@ -880,6 +905,12 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9194
_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9196
_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9311
- _SPARKCONNECTSERVICE._serialized_start = 9314
- _SPARKCONNECTSERVICE._serialized_end = 9780
+ _INTERRUPTREQUEST._serialized_start = 9314
+ _INTERRUPTREQUEST._serialized_end = 9639
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9552
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 9623
+ _INTERRUPTRESPONSE._serialized_start = 9641
+ _INTERRUPTRESPONSE._serialized_end = 9691
+ _SPARKCONNECTSERVICE._serialized_start = 9694
+ _SPARKCONNECTSERVICE._serialized_end = 10242
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index df1d3055a5a..fb1bf92c3c3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2149,3 +2149,103 @@ class ArtifactStatusesResponse(google.protobuf.message.Message):
) -> None: ...
global___ArtifactStatusesResponse = ArtifactStatusesResponse
+
+class InterruptRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _InterruptType:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _InterruptTypeEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ InterruptRequest._InterruptType.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest._InterruptType.ValueType # 0
+ INTERRUPT_TYPE_ALL: InterruptRequest._InterruptType.ValueType # 1
+ """Interrupt all running executions within session with provided session_id."""
+
+ class InterruptType(_InterruptType, metaclass=_InterruptTypeEnumTypeWrapper): ...
+ INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest.InterruptType.ValueType # 0
+ INTERRUPT_TYPE_ALL: InterruptRequest.InterruptType.ValueType # 1
+ """Interrupt all running executions within session with provided session_id."""
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ CLIENT_TYPE_FIELD_NUMBER: builtins.int
+ INTERRUPT_TYPE_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """(Required)
+
+ The session_id specifies a spark session for a user id (which is specified
+ by user_context.user_id). The session_id is set by the client to be able to
+ collate streaming responses from different queries within the dedicated session.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """(Required) User context"""
+ client_type: builtins.str
+ """Provides optional information about the client sending the request. This field
+ can be used for language or version specific information and is only intended for
+ logging purposes and will not be interpreted by the server.
+ """
+ interrupt_type: global___InterruptRequest.InterruptType.ValueType
+ """(Required) The type of interrupt to execute."""
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ user_context: global___UserContext | None = ...,
+ client_type: builtins.str | None = ...,
+ interrupt_type: global___InterruptRequest.InterruptType.ValueType = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "interrupt_type",
+ b"interrupt_type",
+ "session_id",
+ b"session_id",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
+ ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___InterruptRequest = InterruptRequest
+
+class InterruptResponse(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["session_id", b"session_id"]
+ ) -> None: ...
+
+global___InterruptResponse = InterruptResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index ecbe4f9c389..74eda7dfeac 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -55,6 +55,11 @@ class SparkConnectServiceStub(object):
request_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.FromString,
)
+ self.Interrupt = channel.unary_unary(
+ "/spark.connect.SparkConnectService/Interrupt",
+ request_serializer=spark_dot_connect_dot_base__pb2.InterruptRequest.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.InterruptResponse.FromString,
+ )
class SparkConnectServiceServicer(object):
@@ -95,6 +100,12 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def Interrupt(self, request, context):
+ """Interrupts running executions"""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -123,6 +134,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
request_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.SerializeToString,
),
+ "Interrupt": grpc.unary_unary_rpc_method_handler(
+ servicer.Interrupt,
+ request_deserializer=spark_dot_connect_dot_base__pb2.InterruptRequest.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.InterruptResponse.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
@@ -278,3 +294,32 @@ class SparkConnectService(object):
timeout,
metadata,
)
+
+ @staticmethod
+ def Interrupt(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/Interrupt",
+ spark_dot_connect_dot_base__pb2.InterruptRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.InterruptResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 5aecb0d3821..fde861d12b9 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,6 +489,9 @@ class SparkSession:
except Exception:
pass
+ def interrupt_all(self) -> None:
+ self.client.interrupt_all()
+
def stop(self) -> None:
# Stopping the session will only close the connection to the current session (and
# the life cycle of the session is maintained by the server),
diff --git a/python/pyspark/sql/tests/connect/test_client.py b/python/pyspark/sql/tests/connect/test_client.py
index d72a541a675..191a5204bf3 100644
--- a/python/pyspark/sql/tests/connect/test_client.py
+++ b/python/pyspark/sql/tests/connect/test_client.py
@@ -68,6 +68,14 @@ class SparkConnectClientTestCase(unittest.TestCase):
self.assertEqual(client._user_id, "abc")
+ def test_interrupt_all(self):
+ client = SparkConnectClient("sc://foo/;token=bar")
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ client.interrupt_all()
+ self.assertIsNotNone(mock.req, "Interrupt API was not called when expected")
+
class MockService:
# Simplest mock of the SparkConnectService.
@@ -97,6 +105,12 @@ class MockService:
resp.arrow_batch.data = buf.to_pybytes()
return [resp]
+ def Interrupt(self, req: proto.InterruptRequest, metadata):
+ self.req = req
+ resp = proto.InterruptResponse()
+ resp.session_id = self._session_id
+ return resp
+
if __name__ == "__main__":
unittest.main()
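
The test added above swaps the client's gRPC stub for a mock that records the
request it receives. A minimal, self-contained sketch of that pattern follows.
It does not import pyspark or grpc: `InterruptRequest`, `InterruptResponse`,
`RecordingStub`, and `SketchClient` are hypothetical stand-ins written for this
illustration, and the body of `SketchClient.interrupt_all` is an assumption,
since the real client implementation is not part of this excerpt.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


# Hypothetical stand-ins for the generated protobuf messages. The real
# classes live in pyspark.sql.connect.proto; these only mirror the field
# names visible in the stub file above.
@dataclass
class InterruptRequest:
    session_id: str
    interrupt_type: str = "INTERRUPT_TYPE_ALL"
    client_type: Optional[str] = None


@dataclass
class InterruptResponse:
    session_id: str


class RecordingStub:
    """Stands in for the gRPC stub and remembers the last request."""

    def __init__(self, session_id: str) -> None:
        self._session_id = session_id
        self.req: Optional[InterruptRequest] = None

    def Interrupt(
        self, req: InterruptRequest, metadata: Sequence[Any] = ()
    ) -> InterruptResponse:
        # Record the request so a test can assert the RPC was issued.
        self.req = req
        return InterruptResponse(session_id=self._session_id)


class SketchClient:
    """Hypothetical client; the real SparkConnectClient is not shown here."""

    def __init__(self, stub: RecordingStub, session_id: str) -> None:
        self._stub = stub
        self._session_id = session_id

    def interrupt_all(self) -> InterruptResponse:
        # Assumed behaviour: build a request for this session and send it
        # through whatever stub is currently installed.
        req = InterruptRequest(session_id=self._session_id)
        return self._stub.Interrupt(req, metadata=[])
```

Because the stub is an ordinary attribute, a test can install `RecordingStub`
without opening a channel, then check `stub.req` after calling
`interrupt_all()`, which is the same check `test_interrupt_all` performs.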