This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dda37841189 [SPARK-44422][CONNECT] Spark Connect fine grained interrupt
dda37841189 is described below
commit dda37841189a753d7a31d22e091b51903f6cd624
Author: Juliusz Sompolski <[email protected]>
AuthorDate: Fri Jul 21 14:23:47 2023 +0900
[SPARK-44422][CONNECT] Spark Connect fine grained interrupt
### What changes were proposed in this pull request?
Currently, Spark Connect only allows cancelling all operations in a session
by using SparkSession.interruptAll().
In this PR we are adding a mechanism to interrupt by tag (similar to
SparkContext.cancelJobsWithTag), and to interrupt individual operations.
Also, add the new tags to SparkListenerConnectOperationStarted.
### Why are the changes needed?
Better control of query cancellation in Spark Connect
### Does this PR introduce _any_ user-facing change?
Yes. New APIs in Spark Connect Scala client:
```
SparkSession.addTag
SparkSession.removeTag
SparkSession.getTags
SparkSession.clearTags
SparkSession.interruptTag
SparkSession.interruptOperation
```
and also `SparkResult.operationId`, to be able to get the id for
`SparkSession.interruptOperation`.
Python client APIs will be added in a followup PR.
### How was this patch tested?
Added tests in SparkSessionE2ESuite.
Closes #42009 from juliuszsompolski/sc-fine-grained-cancel.
Authored-by: Juliusz Sompolski <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/resources/error/error-classes.json | 18 +++
.../scala/org/apache/spark/sql/SparkSession.scala | 80 ++++++++++-
.../sql/connect/client/SparkConnectClient.scala | 58 ++++++++
.../spark/sql/connect/client/SparkResult.scala | 27 ++++
.../apache/spark/sql/SparkSessionE2ESuite.scala | 154 ++++++++++++++++++++-
.../CheckConnectJvmClientCompatibility.scala | 12 ++
.../src/main/protobuf/spark/connect/base.proto | 37 ++++-
.../spark/sql/connect/common/ProtoUtils.scala | 24 ++++
.../execution/ExecuteResponseObserver.scala | 25 +++-
.../connect/execution/ExecuteThreadRunner.scala | 25 +++-
.../sql/connect/service/ExecuteEventsManager.scala | 8 +-
.../spark/sql/connect/service/ExecuteHolder.scala | 36 ++++-
.../spark/sql/connect/service/SessionHolder.scala | 75 ++++++++--
.../service/SparkConnectInterruptHandler.scala | 24 +++-
.../service/ExecuteEventsManagerSuite.scala | 1 +
.../main/scala/org/apache/spark/SparkContext.scala | 1 +
...-error-conditions-invalid-handle-error-class.md | 36 +++++
docs/sql-error-conditions.md | 8 ++
python/pyspark/sql/connect/proto/base_pb2.py | 144 +++++++++----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 93 ++++++++++++-
20 files changed, 778 insertions(+), 108 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 7913a9b9241..73b1cff7c4e 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1442,6 +1442,24 @@
],
"sqlState" : "22023"
},
+ "INVALID_HANDLE" : {
+ "message" : [
+ "The handle <handle> is invalid."
+ ],
+ "subClass" : {
+ "ALREADY_EXISTS" : {
+ "message" : [
+ "Handle already exists."
+ ]
+ },
+ "FORMAT" : {
+ "message" : [
+ "Handle has invalid format. Handle must be a UUID string of the format
'00112233-4455-6677-8899-aabbccddeeff'"
+ ]
+ }
+ },
+ "sqlState" : "HY000"
+ },
"INVALID_HIVE_COLUMN_NAME" : {
"message" : [
"Cannot create the table <tableName> having the nested column
<columnName> whose name contains invalid characters <invalidChars> in Hive
metastore."
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index fb9959c9942..b37e3884038 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -613,14 +613,40 @@ class SparkSession private[sql] (
/**
* Interrupt all operations of this session currently running on the
connected server.
*
- * TODO/WIP: Currently it will interrupt the Spark Jobs running on the
server, triggered from
- * ExecutePlan requests. If an operation is not running a Spark Job, it
becomes an noop and the
- * operation will continue afterwards, possibly with more Spark Jobs.
+ * @return
+ * sequence of operationIds of interrupted operations. Note: there is
still a possibility of
+ * operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
- def interruptAll(): Unit = {
- client.interruptAll()
+ def interruptAll(): Seq[String] = {
+ client.interruptAll().getInterruptedIdsList.asScala.toSeq
+ }
+
+ /**
+ * Interrupt all operations of this session with the given operation tag.
+ *
+ * @return
+ * sequence of operationIds of interrupted operations. Note: there is
still a possibility of
+ * operation finishing just as it is interrupted.
+ *
+ * @since 3.5.0
+ */
+ def interruptTag(tag: String): Seq[String] = {
+ client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
+ }
+
+ /**
+ * Interrupt an operation of this session with the given operationId.
+ *
+ * @return
+ * sequence of operationIds of interrupted operations. Note: there is
still a possibility of
+ * operation finishing just as it is interrupted.
+ *
+ * @since 3.5.0
+ */
+ def interruptOperation(operationId: String): Seq[String] = {
+ client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
/**
@@ -641,6 +667,50 @@ class SparkSession private[sql] (
allocator.close()
SparkSession.onSessionClose(this)
}
+
+ /**
+ * Add a tag to be assigned to all the operations started by this thread in
this session.
+ *
+ * @param tag
+ * The tag to be added. Cannot contain ',' (comma) character or be an
empty string.
+ *
+ * @since 3.5.0
+ */
+ def addTag(tag: String): Unit = {
+ client.addTag(tag)
+ }
+
+ /**
+ * Remove a tag previously added to be assigned to all the operations
started by this thread in
+ * this session. Noop if such a tag was not added earlier.
+ *
+ * @param tag
+ * The tag to be removed. Cannot contain ',' (comma) character or be an
empty string.
+ *
+ * @since 3.5.0
+ */
+ def removeTag(tag: String): Unit = {
+ client.removeTag(tag)
+ }
+
+ /**
+ * Get the tags that are currently set to be assigned to all the operations
started by this
+ * thread.
+ *
+ * @since 3.5.0
+ */
+ def getTags(): Set[String] = {
+ client.getTags()
+ }
+
+ /**
+ * Clear the current thread's operation tags.
+ *
+ * @since 3.5.0
+ */
+ def clearTags(): Unit = {
+ client.clearTags()
+ }
}
// The minimal builder needed to create a spark session.
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index b41ae5555bf..d03d27a6f53 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -21,11 +21,15 @@ import java.net.URI
import java.util.UUID
import java.util.concurrent.Executor
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
import com.google.protobuf.ByteString
import io.grpc._
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
+import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.common.config.ConnectCommon
/**
@@ -76,6 +80,7 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
+ .addAllTags(tags.get.toSeq.asJava)
.build()
bstub.executePlan(request)
}
@@ -195,6 +200,59 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}
+ private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
+ val builder = proto.InterruptRequest.newBuilder()
+ val request = builder
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
+ .setOperationTag(tag)
+ .build()
+ bstub.interrupt(request)
+ }
+
+ private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
+ val builder = proto.InterruptRequest.newBuilder()
+ val request = builder
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
+ .setOperationId(id)
+ .build()
+ bstub.interrupt(request)
+ }
+
+ private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
+ override def childValue(parent: mutable.Set[String]): mutable.Set[String]
= {
+ // Note: make a clone such that changes in the parent tags aren't
reflected in
+ // those of the children threads.
+ parent.clone()
+ }
+ override protected def initialValue(): mutable.Set[String] = new
mutable.HashSet[String]()
+ }
+
+ private[sql] def addTag(tag: String): Unit = {
+ // validation is also done server side, but this will give error earlier.
+ ProtoUtils.throwIfInvalidTag(tag)
+ tags.get += tag
+ }
+
+ private[sql] def removeTag(tag: String): Unit = {
+ // validation is also done server side, but this will give error earlier.
+ ProtoUtils.throwIfInvalidTag(tag)
+ tags.get.remove(tag)
+ }
+
+ private[sql] def getTags(): Set[String] = {
+ tags.get.toSet
+ }
+
+ private[sql] def clearTags(): Unit = {
+ tags.get.clear()
+ }
+
def copy(): SparkConnectClient = configuration.toSparkConnectClient
/**
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 1cdc2035de6..eed8bd3f37d 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -40,6 +40,7 @@ private[sql] class SparkResult[T](
extends AutoCloseable
with Cleanable { self =>
+ private[this] var opId: String = _
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
private[this] var arrowSchema: pojo.Schema = _
@@ -72,6 +73,7 @@ private[sql] class SparkResult[T](
}
private def processResponses(
+ stopOnOperationId: Boolean = false,
stopOnSchema: Boolean = false,
stopOnArrowSchema: Boolean = false,
stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
@@ -79,6 +81,20 @@ private[sql] class SparkResult[T](
var stop = false
while (!stop && responses.hasNext) {
val response = responses.next()
+
+ // Save and validate operationId
+ if (opId == null) {
+ opId = response.getOperationId
+ }
+ if (opId != response.getOperationId) {
+ // backwards compatibility:
+ // response from an old server without operationId field would have
getOperationId == "".
+ throw new IllegalStateException(
+ "Received response with wrong operationId. " +
+ s"Expected '$opId' but received '${response.getOperationId}'.")
+ }
+ stop |= stopOnOperationId
+
if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
@@ -148,6 +164,17 @@ private[sql] class SparkResult[T](
structType
}
+ /**
+ * @return
+ * the operationId of the result.
+ */
+ def operationId: String = {
+ if (opId == null) {
+ processResponses(stopOnOperationId = true)
+ }
+ opId
+ }
+
/**
* Create an Array with the contents of the result.
*/
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 70eeb6c2c41..5afafaaa6b9 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql
+import scala.collection.mutable
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
import scala.concurrent.duration._
import scala.util.{Failure, Success}
@@ -64,13 +65,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
}
// 20 seconds is < 30 seconds the queries should be running,
// because it should be interrupted sooner
+ val interrupted = mutable.ListBuffer[String]()
eventually(timeout(20.seconds), interval(1.seconds)) {
// keep interrupting every second, until both queries get interrupted.
- spark.interruptAll()
+ val ids = spark.interruptAll()
+ interrupted ++= ids
assert(error.isEmpty, s"Error not empty: $error")
assert(q1Interrupted)
assert(q2Interrupted)
}
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}
test("interrupt all - foreground queries, background interrupt") {
@@ -79,9 +83,12 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
implicit val ec: ExecutionContextExecutor = ExecutionContext.global
@volatile var finished = false
+ val interrupted = mutable.ListBuffer[String]()
+
val interruptor = Future {
eventually(timeout(20.seconds), interval(1.seconds)) {
- spark.interruptAll()
+ val ids = spark.interruptAll()
+ interrupted ++= ids
assert(finished)
}
finished
@@ -96,5 +103,148 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected
exception: $e2")
finished = true
assert(ThreadUtils.awaitResult(interruptor, 10.seconds))
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+
+ test("interrupt tag") {
+ val session = spark
+ import session.implicits._
+
+ // global ExecutionContext has only 2 threads in Apache Spark CI
+ // create own thread pool for four Futures used in this test
+ val numThreads = 4
+ val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool",
numThreads)
+ val executionContext = ExecutionContext.fromExecutorService(fpool)
+
+ val q1 = Future {
+ assert(spark.getTags() == Set())
+ spark.addTag("two")
+ assert(spark.getTags() == Set("two"))
+ spark.clearTags() // check that clearing all tags works
+ assert(spark.getTags() == Set())
+ spark.addTag("one")
+ assert(spark.getTags() == Set("one"))
+ try {
+ spark
+ .range(10)
+ .map(n => {
+ Thread.sleep(30000); n
+ })
+ .collect()
+ } finally {
+ spark.clearTags() // clear for the case of thread reuse by another
Future
+ }
+ }(executionContext)
+ val q2 = Future {
+ assert(spark.getTags() == Set())
+ spark.addTag("one")
+ spark.addTag("two")
+ spark.addTag("one")
+ spark.addTag("two") // duplicates shouldn't matter
+ assert(spark.getTags() == Set("one", "two"))
+ try {
+ spark
+ .range(10)
+ .map(n => {
+ Thread.sleep(30000); n
+ })
+ .collect()
+ } finally {
+ spark.clearTags() // clear for the case of thread reuse by another
Future
+ }
+ }(executionContext)
+ val q3 = Future {
+ assert(spark.getTags() == Set())
+ spark.addTag("foo")
+ spark.removeTag("foo")
+ assert(spark.getTags() == Set()) // check that remove works removing the
last tag
+ spark.addTag("two")
+ assert(spark.getTags() == Set("two"))
+ try {
+ spark
+ .range(10)
+ .map(n => {
+ Thread.sleep(30000); n
+ })
+ .collect()
+ } finally {
+ spark.clearTags() // clear for the case of thread reuse by another
Future
+ }
+ }(executionContext)
+ val q4 = Future {
+ assert(spark.getTags() == Set())
+ spark.addTag("one")
+ spark.addTag("two")
+ spark.addTag("two")
+ assert(spark.getTags() == Set("one", "two"))
+ spark.removeTag("two") // check that remove works, despite duplicate add
+ assert(spark.getTags() == Set("one"))
+ try {
+ spark
+ .range(10)
+ .map(n => {
+ Thread.sleep(30000); n
+ })
+ .collect()
+ } finally {
+ spark.clearTags() // clear for the case of thread reuse by another
Future
+ }
+ }(executionContext)
+ val interrupted = mutable.ListBuffer[String]()
+
+ // q2 and q3 should be cancelled
+ interrupted.clear()
+ eventually(timeout(20.seconds), interval(1.seconds)) {
+ val ids = spark.interruptTag("two")
+ interrupted ++= ids
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+ val e2 = intercept[SparkException] {
+ ThreadUtils.awaitResult(q2, 1.minute)
+ }
+ assert(e2.getCause.getMessage contains "OPERATION_CANCELED")
+ val e3 = intercept[SparkException] {
+ ThreadUtils.awaitResult(q3, 1.minute)
+ }
+ assert(e3.getCause.getMessage contains "OPERATION_CANCELED")
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+
+ // q1 and q4 should be cancelled
+ interrupted.clear()
+ eventually(timeout(20.seconds), interval(1.seconds)) {
+ val ids = spark.interruptTag("one")
+ interrupted ++= ids
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+ val e1 = intercept[SparkException] {
+ ThreadUtils.awaitResult(q1, 1.minute)
+ }
+ assert(e1.getCause.getMessage contains "OPERATION_CANCELED")
+ val e4 = intercept[SparkException] {
+ ThreadUtils.awaitResult(q4, 1.minute)
+ }
+ assert(e4.getCause.getMessage contains "OPERATION_CANCELED")
+ assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
+ }
+
+ test("interrupt operation") {
+ val session = spark
+ import session.implicits._
+
+ val result = spark
+ .range(10)
+ .map(n => {
+ Thread.sleep(5000); n
+ })
+ .collectResult()
+ // cancel
+ val operationId = result.operationId
+ val canceledId = spark.interruptOperation(operationId)
+ assert(canceledId == Seq(operationId))
+ // and check that it got canceled
+ val e = intercept[SparkException] {
+ result.toArray
+ }
+ assert(e.getMessage contains "OPERATION_CANCELED")
}
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index e7f01d6140d..deb2ff631fd 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -365,6 +365,18 @@ object CheckConnectJvmClientCompatibility {
// public
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptAll"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.interruptTag"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.interruptOperation"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.addTag"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.removeTag"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.getTags"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.SparkSession.clearTags"),
// SparkSession#Builder
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.remote"),
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index e869712858a..d935ae65328 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -281,6 +281,12 @@ message ExecutePlanRequest {
// server side.
UserContext user_context = 2;
+ // (Optional)
+ // Provide an id for this request. If not provided, it will be generated by
the server.
+ // It is returned in every ExecutePlanResponse.operation_id of the
ExecutePlan response stream.
+ // The id must be a UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ optional string operation_id = 6;
+
// (Required) The logical plan to be executed / analyzed.
Plan plan = 3;
@@ -299,6 +305,11 @@ message ExecutePlanRequest {
google.protobuf.Any extension = 999;
}
}
+
+ // Tags to tag the given execution with.
+ // Tags cannot contain ',' character and cannot be empty strings.
+ // Used by Interrupt with interrupt.tag.
+ repeated string tags = 7;
}
// The response of a query, can be one or more for each request. Responses
belonging to the
@@ -306,6 +317,12 @@ message ExecutePlanRequest {
message ExecutePlanResponse {
string session_id = 1;
+ // Identifies the ExecutePlan execution.
+ // If set by the client in ExecutePlanRequest.operationId, that value is
returned.
+ // Otherwise generated by the server.
+ // It is a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
+ string operation_id = 12;
+
// Union type for the different response messages.
oneof response_type {
ArrowBatch arrow_batch = 2;
@@ -616,13 +633,31 @@ message InterruptRequest {
enum InterruptType {
INTERRUPT_TYPE_UNSPECIFIED = 0;
- // Interrupt all running executions within session with provided
session_id.
+ // Interrupt all running executions within the session with the provided
session_id.
INTERRUPT_TYPE_ALL = 1;
+
+ // Interrupt all running executions within the session with the provided
operation_tag.
+ INTERRUPT_TYPE_TAG = 2;
+
+ // Interrupt the running execution within the session with the provided
operation_id.
+ INTERRUPT_TYPE_OPERATION_ID = 3;
+ }
+
+ oneof interrupt {
+ // if interrupt_type == INTERRUPT_TYPE_TAG, interrupt operation with this
tag.
+ string operation_tag = 5;
+
+ // if interrupt_type == INTERRUPT_TYPE_OPERATION_ID, interrupt operation
with this operation_id.
+ string operation_id = 6;
}
}
message InterruptResponse {
+ // Session id in which the interrupt was running.
string session_id = 1;
+
+ // Operation ids of the executions which were interrupted.
+ repeated string interrupted_ids = 2;
}
// Main interface for the SparkConnect service.
diff --git
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
index e0c7d267c60..e2934b56744 100644
---
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
+++
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoUtils.scala
@@ -81,4 +81,28 @@ private[connect] object ProtoUtils {
private def createString(prefix: String, size: Int): String = {
s"$prefix[truncated(size=${format.format(size)})]"
}
+
+ // Because Spark Connect operation tags are also set as SparkContext Job
tags, they cannot contain
+ // SparkContext.SPARK_JOB_TAGS_SEP
+ private val SPARK_JOB_TAGS_SEP = ',' // SparkContext.SPARK_JOB_TAGS_SEP
+
+ /**
+ * Validate if a tag for ExecutePlanRequest.tags is valid. Throw
IllegalArgumentException if
+ * not.
+ */
+ def throwIfInvalidTag(tag: String): Unit = {
+ // Same format rules apply to Spark Connect execution tags as to
SparkContext job tags,
+ // because the Spark Connect job tag is also used as part of SparkContext
job tag.
+ // See SparkContext.throwIfInvalidTag and ExecuteHolder.tagToSparkJobTag
+ if (tag == null) {
+ throw new IllegalArgumentException("Spark Connect tag cannot be null.")
+ }
+ if (tag.contains(SPARK_JOB_TAGS_SEP)) {
+ throw new IllegalArgumentException(
+ s"Spark Connect tag cannot contain '$SPARK_JOB_TAGS_SEP'.")
+ }
+ if (tag.isEmpty) {
+ throw new IllegalArgumentException("Spark Connect tag cannot be an empty
string.")
+ }
+ }
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
index 5aecbdfce16..ae89c150a68 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala
@@ -21,7 +21,9 @@ import scala.collection.mutable
import io.grpc.stub.StreamObserver
+import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.service.ExecuteHolder
/**
* This StreamObserver is running on the execution thread. Execution pushes
responses to it, it
@@ -40,7 +42,9 @@ import org.apache.spark.internal.Logging
* @see
* attachConsumer
*/
-private[connect] class ExecuteResponseObserver[T]() extends StreamObserver[T]
with Logging {
+private[connect] class ExecuteResponseObserver[T](val executeHolder:
ExecuteHolder)
+ extends StreamObserver[T]
+ with Logging {
/**
* Cached responses produced by the execution. Map from response index ->
response. Response
@@ -77,7 +81,9 @@ private[connect] class ExecuteResponseObserver[T]() extends
StreamObserver[T] wi
throw new IllegalStateException("Stream onNext can't be called after
stream completed")
}
lastProducedIndex += 1
- responses += ((lastProducedIndex, CachedStreamResponse[T](r,
lastProducedIndex)))
+ val processedResponse = setCommonResponseFields(r)
+ responses +=
+ ((lastProducedIndex, CachedStreamResponse[T](processedResponse,
lastProducedIndex)))
logDebug(s"Saved response with index=$lastProducedIndex")
notifyAll()
}
@@ -158,4 +164,19 @@ private[connect] class ExecuteResponseObserver[T]()
extends StreamObserver[T] wi
i -= 1
}
}
+
+ /**
+ * Populate response fields that are common and should be set in every
response.
+ */
+ private def setCommonResponseFields(response: T): T = {
+ response match {
+ case executePlanResponse: proto.ExecutePlanResponse =>
+ executePlanResponse
+ .toBuilder()
+ .setSessionId(executeHolder.sessionHolder.sessionId)
+ .setOperationId(executeHolder.operationId)
+ .build()
+ .asInstanceOf[T]
+ }
+ }
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 6c2ffa46547..6758df0d7e6 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -54,11 +54,20 @@ private[connect] class ExecuteThreadRunner(executeHolder:
ExecuteHolder) extends
executionThread.join()
}
- /** Interrupt the executing thread. */
- def interrupt(): Unit = {
+ /**
+ * Interrupt the executing thread.
+ * @return
+ * true if it was not interrupted before, false if it was already
interrupted.
+ */
+ def interrupt(): Boolean = {
synchronized {
- interrupted = true
- executionThread.interrupt()
+ if (!interrupted) {
+ interrupted = true
+ executionThread.interrupt()
+ true
+ } else {
+ false
+ }
}
}
@@ -85,6 +94,10 @@ private[connect] class ExecuteThreadRunner(executeHolder:
ExecuteHolder) extends
}
} finally {
executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag)
+ executeHolder.sparkSessionTags.foreach { tag =>
+ executeHolder.sessionHolder.session.sparkContext
+ .removeJobTag(executeHolder.tagToSparkJobTag(tag))
+ }
}
} catch {
ErrorUtils.handleError(
@@ -113,6 +126,10 @@ private[connect] class ExecuteThreadRunner(executeHolder:
ExecuteHolder) extends
// Set tag for query cancellation
session.sparkContext.addJobTag(executeHolder.jobTag)
+ // Also set all user defined tags as Spark Job tags.
+ executeHolder.sparkSessionTags.foreach { tag =>
+ session.sparkContext.addJobTag(executeHolder.tagToSparkJobTag(tag))
+ }
session.sparkContext.setJobDescription(
s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
session.sparkContext.setInterruptOnCancel(true)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
index 0af54f034a2..5e831aaa98f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala
@@ -59,6 +59,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder,
clock: Clock) {
private def jobTag = executeHolder.jobTag
+ private def sparkSessionTags = executeHolder.sparkSessionTags
+
private def listenerBus = sessionHolder.session.sparkContext.listenerBus
private def sessionHolder = executeHolder.sessionHolder
@@ -119,7 +121,8 @@ case class ExecuteEventsManager(executeHolder:
ExecuteHolder, clock: Clock) {
Utils.redact(
sessionHolder.session.sessionState.conf.stringRedactionPattern,
ProtoUtils.abbreviate(plan,
ExecuteEventsManager.MAX_STATEMENT_TEXT_SIZE).toString),
- Some(request)))
+ Some(request),
+ sparkSessionTags))
}
/**
@@ -270,6 +273,8 @@ case class ExecuteEventsManager(executeHolder:
ExecuteHolder, clock: Clock) {
* The connect request plan converted to text.
* @param planRequest:
* The Connect request. None if the operation is not of type @link
proto.ExecutePlanRequest
+ * @param sparkSessionTags:
+ * Extra tags set by the user (via SparkSession.addTag).
* @param extraTags:
* Additional metadata during the request.
*/
@@ -282,6 +287,7 @@ case class SparkListenerConnectOperationStarted(
userName: String,
statementText: String,
planRequest: Option[proto.ExecutePlanRequest],
+ sparkSessionTags: Set[String],
extraTags: Map[String, String] = Map.empty)
extends SparkListenerEvent
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 1f70973b60e..74530ad032f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.connect.service
+import scala.collection.JavaConverters._
+
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender,
ExecuteResponseObserver, ExecuteThreadRunner}
import org.apache.spark.util.SystemClock
@@ -31,16 +34,34 @@ private[connect] class ExecuteHolder(
val sessionHolder: SessionHolder)
extends Logging {
+ /**
+ * Tag that is set for this execution on SparkContext, via
SparkContext.addJobTag. Used
+ * (internally) for cancellation of the Spark Jobs run by this execution.
+ */
val jobTag =
s"SparkConnect_Execute_" +
s"User_${sessionHolder.userId}_" +
s"Session_${sessionHolder.sessionId}_" +
s"Request_${operationId}"
+ /**
+ * Tags set by Spark Connect client users via SparkSession.addTag. Used to
identify and group
+ * executions, and for user cancellation using SparkSession.interruptTag.
+ */
+ val sparkSessionTags: Set[String] = request
+ .getTagsList()
+ .asScala
+ .toSeq
+ .map { tag =>
+ ProtoUtils.throwIfInvalidTag(tag)
+ tag
+ }
+ .toSet
+
val session = sessionHolder.session
val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] =
- new ExecuteResponseObserver[proto.ExecutePlanResponse]()
+ new ExecuteResponseObserver[proto.ExecutePlanResponse](this)
val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new
SystemClock())
@@ -85,8 +106,19 @@ private[connect] class ExecuteHolder(
/**
* Interrupt the execution. Interrupts the running thread, which cancels all
running Spark Jobs
* and makes the execution throw an OPERATION_CANCELED error.
+ * @return
+ * true if it was not interrupted before, false if it was already
interrupted.
*/
- def interrupt(): Unit = {
+ def interrupt(): Boolean = {
runner.interrupt()
}
+
+ /**
+ * Spark Connect tags are also added as SparkContext job tags, but to make
the tag unique, they
+ * need to be combined with userId and sessionId.
+ */
+ def tagToSparkJobTag(tag: String): String = {
+ "SparkConnect_Tag_" +
+ s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
+ }
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 5ac4f6db82a..ae53d1d171f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -22,10 +22,9 @@ import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
+import scala.collection.mutable
-import org.apache.spark.JobArtifactSet
-import org.apache.spark.SparkException
+import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
@@ -57,9 +56,25 @@ case class SessionHolder(userId: String, sessionId: String,
session: SparkSessio
new ConcurrentHashMap()
private[connect] def createExecuteHolder(request: proto.ExecutePlanRequest):
ExecuteHolder = {
- val operationId = UUID.randomUUID().toString
+ val operationId = if (request.hasOperationId) {
+ try {
+ UUID.fromString(request.getOperationId).toString
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.FORMAT",
+ messageParameters = Map("handle" -> request.getOperationId))
+ }
+ } else {
+ UUID.randomUUID().toString
+ }
val executePlanHolder = new ExecuteHolder(request, operationId, this)
- assert(executions.putIfAbsent(operationId, executePlanHolder) == null)
+ val oldExecute = executions.putIfAbsent(operationId, executePlanHolder)
+ if (oldExecute != null) {
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.ALREADY_EXISTS",
+ messageParameters = Map("handle" -> operationId))
+ }
executePlanHolder
}
@@ -71,17 +86,51 @@ case class SessionHolder(userId: String, sessionId: String,
session: SparkSessio
executions.remove(operationId)
}
- private[connect] def interruptAll(): Unit = {
+ /**
+ * Interrupt all executions in the session.
+ * @return
+ * list of operationIds of interrupted executions
+ */
+ private[connect] def interruptAll(): Seq[String] = {
+ val interruptedIds = new mutable.ArrayBuffer[String]()
executions.asScala.values.foreach { execute =>
- // Eat exception while trying to interrupt a given execution and move
forward.
- try {
- logDebug(s"Interrupting execution ${execute.operationId}")
- execute.interrupt()
- } catch {
- case NonFatal(e) =>
- logWarning(s"Exception $e while trying to interrupt execution
${execute.operationId}")
+ if (execute.interrupt()) {
+ interruptedIds += execute.operationId
+ }
+ }
+ interruptedIds.toSeq
+ }
+
+ /**
+ * Interrupt executions in the session with a given tag.
+ * @return
+ * list of operationIds of interrupted executions
+ */
+ private[connect] def interruptTag(tag: String): Seq[String] = {
+ val interruptedIds = new mutable.ArrayBuffer[String]()
+ executions.asScala.values.foreach { execute =>
+ if (execute.sparkSessionTags.contains(tag)) {
+ if (execute.interrupt()) {
+ interruptedIds += execute.operationId
+ }
+ }
+ }
+ interruptedIds.toSeq
+ }
+
+ /**
+ * Interrupt the execution with the given operation_id
+ * @return
+ * list of operationIds of interrupted executions (one element or empty)
+ */
+ private[connect] def interruptOperation(operationId: String): Seq[String] = {
+ val interruptedIds = new mutable.ArrayBuffer[String]()
+ Option(executions.get(operationId)).foreach { execute =>
+ if (execute.interrupt()) {
+ interruptedIds += execute.operationId
}
}
+ interruptedIds.toSeq
}
private[connect] lazy val artifactManager = new
SparkConnectArtifactManager(this)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
index b0923e277e4..a9ed391460c 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.connect.service
+import scala.collection.JavaConverters._
+
import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
@@ -30,16 +32,32 @@ class SparkConnectInterruptHandler(responseObserver:
StreamObserver[proto.Interr
SparkConnectService
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
- v.getInterruptType match {
+ val interruptedIds = v.getInterruptType match {
case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL =>
sessionHolder.interruptAll()
+ case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG =>
+ if (!v.hasOperationTag) {
+ throw new IllegalArgumentException(
+ s"INTERRUPT_TYPE_TAG requested, but no operation_tag provided.")
+ }
+ sessionHolder.interruptTag(v.getOperationTag)
+ case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID =>
+ if (!v.hasOperationId) {
+ throw new IllegalArgumentException(
+ s"INTERRUPT_TYPE_OPERATION_ID requested, but no operation_id
provided.")
+ }
+ sessionHolder.interruptOperation(v.getOperationId)
case other =>
throw new UnsupportedOperationException(s"Unknown InterruptType
$other!")
}
- val builder =
proto.InterruptResponse.newBuilder().setSessionId(v.getSessionId)
+ val response = proto.InterruptResponse
+ .newBuilder()
+ .setSessionId(v.getSessionId)
+ .addAllInterruptedIds(interruptedIds.asJava)
+ .build()
- responseObserver.onNext(builder.build())
+ responseObserver.onNext(response)
responseObserver.onCompleted()
}
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index 365b17632a7..27c57e0d759 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -64,6 +64,7 @@ class ExecuteEventsManagerSuite
DEFAULT_USER_NAME,
DEFAULT_TEXT,
Some(events.executeHolder.request),
+ Set.empty,
Map.empty))
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 80f7eaf00ed..26fdb86d299 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2975,6 +2975,7 @@ object SparkContext extends Logging {
/** Separator of tags in SPARK_JOB_TAGS property */
private[spark] val SPARK_JOB_TAGS_SEP = ","
+ // Same rules apply to Spark Connect execution tags, see
ProtoUtils.throwIfInvalidTag
private[spark] def throwIfInvalidTag(tag: String) = {
if (tag == null) {
throw new IllegalArgumentException("Spark job tag cannot be null.")
diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md
b/docs/sql-error-conditions-invalid-handle-error-class.md
new file mode 100644
index 00000000000..7c083bc5f50
--- /dev/null
+++ b/docs/sql-error-conditions-invalid-handle-error-class.md
@@ -0,0 +1,36 @@
+---
+layout: global
+title: INVALID_HANDLE error class
+displayTitle: INVALID_HANDLE error class
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+[SQLSTATE:
HY000](sql-error-conditions-sqlstates.html#class-HY-cli-specific-condition)
+
+The handle `<handle>` is invalid.
+
+This error class has the following derived error classes:
+
+## ALREADY_EXISTS
+
+Handle already exists.
+
+## FORMAT
+
+Handle has invalid format. Handle must be an UUID string of the format
'00112233-4455-6677-8899-aabbccddeeff'
+
+
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index cd04c414df3..12238b6724b 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -918,6 +918,14 @@ For more details see
[INVALID_FORMAT](sql-error-conditions-invalid-format-error-
The fraction of sec must be zero. Valid range is [0, 60]. If necessary set
`<ansiConfig>` to "false" to bypass this error.
+### [INVALID_HANDLE](sql-error-conditions-invalid-handle-error-class.html)
+
+[SQLSTATE:
HY000](sql-error-conditions-sqlstates.html#class-HY-cli-specific-condition)
+
+The handle `<handle>` is invalid.
+
+For more details see
[INVALID_HANDLE](sql-error-conditions-invalid-handle-error-class.html)
+
### INVALID_HIVE_COLUMN_NAME
SQLSTATE: none assigned
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py
b/python/pyspark/sql/connect/proto/base_pb2.py
index 7bf93ed58fa..04044d4cdcf 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -115,75 +115,75 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_start = 4482
_ANALYZEPLANRESPONSE_GETSTORAGELEVEL._serialized_end = 4565
_EXECUTEPLANREQUEST._serialized_start = 4578
- _EXECUTEPLANREQUEST._serialized_end = 4967
- _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4863
- _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 4951
- _EXECUTEPLANRESPONSE._serialized_start = 4970
- _EXECUTEPLANRESPONSE._serialized_end = 6735
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 5966
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6037
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6039
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6100
- _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6103
- _EXECUTEPLANRESPONSE_METRICS._serialized_end = 6620
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6198
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 6530
-
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start
= 6407
-
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end
= 6530
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6532
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6620
- _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6622
- _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 6718
- _KEYVALUE._serialized_start = 6737
- _KEYVALUE._serialized_end = 6802
- _CONFIGREQUEST._serialized_start = 6805
- _CONFIGREQUEST._serialized_end = 7833
- _CONFIGREQUEST_OPERATION._serialized_start = 7025
- _CONFIGREQUEST_OPERATION._serialized_end = 7523
- _CONFIGREQUEST_SET._serialized_start = 7525
- _CONFIGREQUEST_SET._serialized_end = 7577
- _CONFIGREQUEST_GET._serialized_start = 7579
- _CONFIGREQUEST_GET._serialized_end = 7604
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7606
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 7669
- _CONFIGREQUEST_GETOPTION._serialized_start = 7671
- _CONFIGREQUEST_GETOPTION._serialized_end = 7702
- _CONFIGREQUEST_GETALL._serialized_start = 7704
- _CONFIGREQUEST_GETALL._serialized_end = 7752
- _CONFIGREQUEST_UNSET._serialized_start = 7754
- _CONFIGREQUEST_UNSET._serialized_end = 7781
- _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 7783
- _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 7817
- _CONFIGRESPONSE._serialized_start = 7835
- _CONFIGRESPONSE._serialized_end = 7957
- _ADDARTIFACTSREQUEST._serialized_start = 7960
- _ADDARTIFACTSREQUEST._serialized_end = 8831
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8347
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8400
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8402
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8513
- _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8515
- _ADDARTIFACTSREQUEST_BATCH._serialized_end = 8608
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 8611
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 8804
- _ADDARTIFACTSRESPONSE._serialized_start = 8834
- _ADDARTIFACTSRESPONSE._serialized_end = 9022
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 8941
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9022
- _ARTIFACTSTATUSESREQUEST._serialized_start = 9025
- _ARTIFACTSTATUSESREQUEST._serialized_end = 9220
- _ARTIFACTSTATUSESRESPONSE._serialized_start = 9223
- _ARTIFACTSTATUSESRESPONSE._serialized_end = 9491
- _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9334
- _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9374
- _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9376
- _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9491
- _INTERRUPTREQUEST._serialized_start = 9494
- _INTERRUPTREQUEST._serialized_end = 9819
- _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9732
- _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 9803
- _INTERRUPTRESPONSE._serialized_start = 9821
- _INTERRUPTRESPONSE._serialized_end = 9871
- _SPARKCONNECTSERVICE._serialized_start = 9874
- _SPARKCONNECTSERVICE._serialized_end = 10422
+ _EXECUTEPLANREQUEST._serialized_end = 5044
+ _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 4923
+ _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5011
+ _EXECUTEPLANRESPONSE._serialized_start = 5047
+ _EXECUTEPLANRESPONSE._serialized_end = 6847
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 6078
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 6149
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 6151
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 6212
+ _EXECUTEPLANRESPONSE_METRICS._serialized_start = 6215
+ _EXECUTEPLANRESPONSE_METRICS._serialized_end = 6732
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 6310
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 6642
+
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start
= 6519
+
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end
= 6642
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 6644
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 6732
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 6734
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 6830
+ _KEYVALUE._serialized_start = 6849
+ _KEYVALUE._serialized_end = 6914
+ _CONFIGREQUEST._serialized_start = 6917
+ _CONFIGREQUEST._serialized_end = 7945
+ _CONFIGREQUEST_OPERATION._serialized_start = 7137
+ _CONFIGREQUEST_OPERATION._serialized_end = 7635
+ _CONFIGREQUEST_SET._serialized_start = 7637
+ _CONFIGREQUEST_SET._serialized_end = 7689
+ _CONFIGREQUEST_GET._serialized_start = 7691
+ _CONFIGREQUEST_GET._serialized_end = 7716
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 7718
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 7781
+ _CONFIGREQUEST_GETOPTION._serialized_start = 7783
+ _CONFIGREQUEST_GETOPTION._serialized_end = 7814
+ _CONFIGREQUEST_GETALL._serialized_start = 7816
+ _CONFIGREQUEST_GETALL._serialized_end = 7864
+ _CONFIGREQUEST_UNSET._serialized_start = 7866
+ _CONFIGREQUEST_UNSET._serialized_end = 7893
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 7895
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 7929
+ _CONFIGRESPONSE._serialized_start = 7947
+ _CONFIGRESPONSE._serialized_end = 8069
+ _ADDARTIFACTSREQUEST._serialized_start = 8072
+ _ADDARTIFACTSREQUEST._serialized_end = 8943
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 8459
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 8512
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 8514
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 8625
+ _ADDARTIFACTSREQUEST_BATCH._serialized_start = 8627
+ _ADDARTIFACTSREQUEST_BATCH._serialized_end = 8720
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 8723
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 8916
+ _ADDARTIFACTSRESPONSE._serialized_start = 8946
+ _ADDARTIFACTSRESPONSE._serialized_end = 9134
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 9053
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 9134
+ _ARTIFACTSTATUSESREQUEST._serialized_start = 9137
+ _ARTIFACTSTATUSESREQUEST._serialized_end = 9332
+ _ARTIFACTSTATUSESRESPONSE._serialized_start = 9335
+ _ARTIFACTSTATUSESRESPONSE._serialized_end = 9603
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9446
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9486
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9488
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9603
+ _INTERRUPTREQUEST._serialized_start = 9606
+ _INTERRUPTREQUEST._serialized_end = 10078
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9921
+ _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 10049
+ _INTERRUPTRESPONSE._serialized_start = 10080
+ _INTERRUPTRESPONSE._serialized_end = 10171
+ _SPARKCONNECTSERVICE._serialized_start = 10174
+ _SPARKCONNECTSERVICE._serialized_end = 10722
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 633058f33ed..651438ea438 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1031,9 +1031,11 @@ class
ExecutePlanRequest(google.protobuf.message.Message):
SESSION_ID_FIELD_NUMBER: builtins.int
USER_CONTEXT_FIELD_NUMBER: builtins.int
+ OPERATION_ID_FIELD_NUMBER: builtins.int
PLAN_FIELD_NUMBER: builtins.int
CLIENT_TYPE_FIELD_NUMBER: builtins.int
REQUEST_OPTIONS_FIELD_NUMBER: builtins.int
+ TAGS_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""(Required)
@@ -1048,6 +1050,12 @@ class
ExecutePlanRequest(google.protobuf.message.Message):
user_context.user_id and session+id both identify a unique remote
spark session on the
server side.
"""
+ operation_id: builtins.str
+ """(Optional)
+ Provide an id for this request. If not provided, it will be generated by
the server.
+ It is returned in every ExecutePlanResponse.operation_id of the
ExecutePlan response stream.
+ The id must be an UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ """
@property
def plan(self) -> global___Plan:
"""(Required) The logical plan to be executed / analyzed."""
@@ -1065,23 +1073,37 @@ class
ExecutePlanRequest(google.protobuf.message.Message):
"""Repeated element for options that can be passed to the request.
This element is currently
unused but allows to pass in an extension value used for arbitrary
options.
"""
+ @property
+ def tags(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """Tags to tag the given execution with.
+ Tags cannot contain ',' character and cannot be empty strings.
+ Used by Interrupt with interrupt.tag.
+ """
def __init__(
self,
*,
session_id: builtins.str = ...,
user_context: global___UserContext | None = ...,
+ operation_id: builtins.str | None = ...,
plan: global___Plan | None = ...,
client_type: builtins.str | None = ...,
request_options:
collections.abc.Iterable[global___ExecutePlanRequest.RequestOption]
| None = ...,
+ tags: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_client_type",
b"_client_type",
+ "_operation_id",
+ b"_operation_id",
"client_type",
b"client_type",
+ "operation_id",
+ b"operation_id",
"plan",
b"plan",
"user_context",
@@ -1093,21 +1115,32 @@ class
ExecutePlanRequest(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"_client_type",
b"_client_type",
+ "_operation_id",
+ b"_operation_id",
"client_type",
b"client_type",
+ "operation_id",
+ b"operation_id",
"plan",
b"plan",
"request_options",
b"request_options",
"session_id",
b"session_id",
+ "tags",
+ b"tags",
"user_context",
b"user_context",
],
) -> None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_client_type",
b"_client_type"]
) -> typing_extensions.Literal["client_type"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_operation_id",
b"_operation_id"]
+ ) -> typing_extensions.Literal["operation_id"] | None: ...
global___ExecutePlanRequest = ExecutePlanRequest
@@ -1290,6 +1323,7 @@ class
ExecutePlanResponse(google.protobuf.message.Message):
) -> None: ...
SESSION_ID_FIELD_NUMBER: builtins.int
+ OPERATION_ID_FIELD_NUMBER: builtins.int
ARROW_BATCH_FIELD_NUMBER: builtins.int
SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int
WRITE_STREAM_OPERATION_START_RESULT_FIELD_NUMBER: builtins.int
@@ -1301,6 +1335,12 @@ class
ExecutePlanResponse(google.protobuf.message.Message):
OBSERVED_METRICS_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
session_id: builtins.str
+ operation_id: builtins.str
+ """Identifies the ExecutePlan execution.
+ If set by the client in ExecutePlanRequest.operationId, that value is
returned.
+ Otherwise generated by the server.
+ It is an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
+ """
@property
def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@property
@@ -1348,6 +1388,7 @@ class
ExecutePlanResponse(google.protobuf.message.Message):
self,
*,
session_id: builtins.str = ...,
+ operation_id: builtins.str = ...,
arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ...,
sql_command_result: global___ExecutePlanResponse.SqlCommandResult |
None = ...,
write_stream_operation_start_result:
pyspark.sql.connect.proto.commands_pb2.WriteStreamOperationStartResult
@@ -1402,6 +1443,8 @@ class
ExecutePlanResponse(google.protobuf.message.Message):
b"metrics",
"observed_metrics",
b"observed_metrics",
+ "operation_id",
+ b"operation_id",
"response_type",
b"response_type",
"schema",
@@ -2208,17 +2251,27 @@ class InterruptRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest._InterruptType.ValueType
# 0
INTERRUPT_TYPE_ALL: InterruptRequest._InterruptType.ValueType # 1
- """Interrupt all running executions within session with provided
session_id."""
+ """Interrupt all running executions within the session with the
provided session_id."""
+ INTERRUPT_TYPE_TAG: InterruptRequest._InterruptType.ValueType # 2
+ """Interrupt all running executions within the session with the
provided operation_tag."""
+ INTERRUPT_TYPE_OPERATION_ID: InterruptRequest._InterruptType.ValueType
# 3
+ """Interrupt the running execution within the session with the
provided operation_id."""
class InterruptType(_InterruptType,
metaclass=_InterruptTypeEnumTypeWrapper): ...
INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest.InterruptType.ValueType # 0
INTERRUPT_TYPE_ALL: InterruptRequest.InterruptType.ValueType # 1
- """Interrupt all running executions within session with provided
session_id."""
+ """Interrupt all running executions within the session with the provided
session_id."""
+ INTERRUPT_TYPE_TAG: InterruptRequest.InterruptType.ValueType # 2
+ """Interrupt all running executions within the session with the provided
operation_tag."""
+ INTERRUPT_TYPE_OPERATION_ID: InterruptRequest.InterruptType.ValueType # 3
+ """Interrupt the running execution within the session with the provided
operation_id."""
SESSION_ID_FIELD_NUMBER: builtins.int
USER_CONTEXT_FIELD_NUMBER: builtins.int
CLIENT_TYPE_FIELD_NUMBER: builtins.int
INTERRUPT_TYPE_FIELD_NUMBER: builtins.int
+ OPERATION_TAG_FIELD_NUMBER: builtins.int
+ OPERATION_ID_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""(Required)
@@ -2236,6 +2289,10 @@ class InterruptRequest(google.protobuf.message.Message):
"""
interrupt_type: global___InterruptRequest.InterruptType.ValueType
"""(Required) The type of interrupt to execute."""
+ operation_tag: builtins.str
+ """if interrupt_tag == INTERRUPT_TYPE_TAG, interrupt operation with this
tag."""
+ operation_id: builtins.str
+ """if interrupt_tag == INTERRUPT_TYPE_OPERATION_ID, interrupt operation
with this operation_id."""
def __init__(
self,
*,
@@ -2243,6 +2300,8 @@ class InterruptRequest(google.protobuf.message.Message):
user_context: global___UserContext | None = ...,
client_type: builtins.str | None = ...,
interrupt_type: global___InterruptRequest.InterruptType.ValueType =
...,
+ operation_tag: builtins.str = ...,
+ operation_id: builtins.str = ...,
) -> None: ...
def HasField(
self,
@@ -2251,6 +2310,12 @@ class InterruptRequest(google.protobuf.message.Message):
b"_client_type",
"client_type",
b"client_type",
+ "interrupt",
+ b"interrupt",
+ "operation_id",
+ b"operation_id",
+ "operation_tag",
+ b"operation_tag",
"user_context",
b"user_context",
],
@@ -2262,17 +2327,28 @@ class InterruptRequest(google.protobuf.message.Message):
b"_client_type",
"client_type",
b"client_type",
+ "interrupt",
+ b"interrupt",
"interrupt_type",
b"interrupt_type",
+ "operation_id",
+ b"operation_id",
+ "operation_tag",
+ b"operation_tag",
"session_id",
b"session_id",
"user_context",
b"user_context",
],
) -> None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_client_type",
b"_client_type"]
) -> typing_extensions.Literal["client_type"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["interrupt", b"interrupt"]
+ ) -> typing_extensions.Literal["operation_tag", "operation_id"] | None: ...
global___InterruptRequest = InterruptRequest
@@ -2280,14 +2356,25 @@ class
InterruptResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SESSION_ID_FIELD_NUMBER: builtins.int
+ INTERRUPTED_IDS_FIELD_NUMBER: builtins.int
session_id: builtins.str
+ """Session id in which the interrupt was running."""
+ @property
+ def interrupted_ids(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """Operation ids of the executions which were interrupted."""
def __init__(
self,
*,
session_id: builtins.str = ...,
+ interrupted_ids: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(
- self, field_name: typing_extensions.Literal["session_id",
b"session_id"]
+ self,
+ field_name: typing_extensions.Literal[
+ "interrupted_ids", b"interrupted_ids", "session_id", b"session_id"
+ ],
) -> None: ...
global___InterruptResponse = InterruptResponse
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]