This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fd8e99b9df55 [SPARK-49249][SPARK-49320] Add new tag-related APIs in
Connect back to Spark Core
fd8e99b9df55 is described below
commit fd8e99b9df55bf2ea29b6279a6a840ffef20ed4e
Author: Paddy Xu <[email protected]>
AuthorDate: Tue Sep 17 23:06:05 2024 -0400
[SPARK-49249][SPARK-49320] Add new tag-related APIs in Connect back to
Spark Core
### What changes were proposed in this pull request?
This PR adds several new tag-related APIs in Connect back to Spark Core.
Following the isolation practice in the original Connect API, the newly
introduced APIs also support isolation:
- `interrupt{Tag,All,Operation}` can only cancel jobs created by this Spark
session.
- `{add,remove}Tag` and `{get,clear}Tags` only apply to jobs created by
this Spark session.
Instead of returning query IDs like in Spark Connect, here in Spark SQL,
these methods will return SQL execution root IDs - as "query IDs" are only for
Connect.
### Why are the changes needed?
To close the API gap between Connect and Core.
### Does this PR introduce _any_ user-facing change?
Yes, Core users can use some new APIs.
### How was this patch tested?
New test added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47815 from xupefei/reverse-api-tag.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 70 +-----
.../CheckConnectJvmClientCompatibility.scala | 15 --
.../main/scala/org/apache/spark/SparkContext.scala | 56 ++++-
.../org/apache/spark/scheduler/DAGScheduler.scala | 33 ++-
.../apache/spark/scheduler/DAGSchedulerEvent.scala | 5 +-
.../org/apache/spark/sql/api/SparkSession.scala | 92 ++++++++
.../scala/org/apache/spark/sql/SparkSession.scala | 119 +++++++++-
.../apache/spark/sql/execution/SQLExecution.scala | 205 +++++++++-------
...parkSessionJobTaggingAndCancellationSuite.scala | 262 +++++++++++++++++++++
.../spark/sql/execution/SQLExecutionSuite.scala | 2 +-
10 files changed, 667 insertions(+), 192 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 989a7e0c174c..aa6258a14b81 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -420,7 +420,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptAll(): Seq[String] = {
+ override def interruptAll(): Seq[String] = {
client.interruptAll().getInterruptedIdsList.asScala.toSeq
}
@@ -433,7 +433,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptTag(tag: String): Seq[String] = {
+ override def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}
@@ -446,7 +446,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptOperation(operationId: String): Seq[String] = {
+ override def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
@@ -477,65 +477,17 @@ class SparkSession private[sql] (
SparkSession.onSessionClose(this)
}
- /**
- * Add a tag to be assigned to all the operations started by this thread in
this session.
- *
- * Often, a unit of execution in an application consists of multiple Spark
executions.
- * Application programmers can use this method to group all those jobs
together and give a group
- * tag. The application can use
`org.apache.spark.sql.SparkSession.interruptTag` to cancel all
- * running running executions with this tag. For example:
- * {{{
- * // In the main thread:
- * spark.addTag("myjobs")
- * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
- *
- * // In a separate thread:
- * spark.interruptTag("myjobs")
- * }}}
- *
- * There may be multiple tags present at the same time, so different parts
of application may
- * use different tags to perform cancellation at different levels of
granularity.
- *
- * @param tag
- * The tag to be added. Cannot contain ',' (comma) character or be an
empty string.
- *
- * @since 3.5.0
- */
- def addTag(tag: String): Unit = {
- client.addTag(tag)
- }
+ /** @inheritdoc */
+ override def addTag(tag: String): Unit = client.addTag(tag)
- /**
- * Remove a tag previously added to be assigned to all the operations
started by this thread in
- * this session. Noop if such a tag was not added earlier.
- *
- * @param tag
- * The tag to be removed. Cannot contain ',' (comma) character or be an
empty string.
- *
- * @since 3.5.0
- */
- def removeTag(tag: String): Unit = {
- client.removeTag(tag)
- }
+ /** @inheritdoc */
+ override def removeTag(tag: String): Unit = client.removeTag(tag)
- /**
- * Get the tags that are currently set to be assigned to all the operations
started by this
- * thread.
- *
- * @since 3.5.0
- */
- def getTags(): Set[String] = {
- client.getTags()
- }
+ /** @inheritdoc */
+ override def getTags(): Set[String] = client.getTags()
- /**
- * Clear the current thread's operation tags.
- *
- * @since 3.5.0
- */
- def clearTags(): Unit = {
- client.clearTags()
- }
+ /** @inheritdoc */
+ override def clearTags(): Unit = client.clearTags()
/**
* We cannot deserialize a connect [[SparkSession]] because of a class clash
on the server side.
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index f4043f19eb6a..abf03cfbc672 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -365,21 +365,6 @@ object CheckConnectJvmClientCompatibility {
// Experimental
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.registerClassFinder"),
- // public
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptAll"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptOperation"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.addTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.removeTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.getTags"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.clearTags"),
// SparkSession#Builder
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.remote"),
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 485f0abcd25e..042179d86c31 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.Map
import scala.collection.concurrent.{Map => ScalaConcurrentMap}
import scala.collection.immutable
import scala.collection.mutable.HashMap
+import scala.concurrent.{Future, Promise}
import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
@@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
- def addJobTag(tag: String): Unit = {
- SparkContext.throwIfInvalidTag(tag)
+ def addJobTag(tag: String): Unit = addJobTags(Set(tag))
+
+ /**
+ * Add multiple tags to be assigned to all the jobs started by this thread.
+ * See [[addJobTag]] for more details.
+ *
+ * @param tags The tags to be added. Cannot contain ',' (comma) character.
+ *
+ * @since 4.0.0
+ */
+ def addJobTags(tags: Set[String]): Unit = {
+ tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
- val newTags = (existingTags +
tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+ val newTags = (existingTags ++
tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
}
@@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
- def removeJobTag(tag: String): Unit = {
- SparkContext.throwIfInvalidTag(tag)
+ def removeJobTag(tag: String): Unit = removeJobTags(Set(tag))
+
+ /**
+ * Remove multiple tags previously added to be assigned to all the jobs
started by this thread.
+ * See [[removeJobTag]] for more details.
+ *
+ * @param tags The tags to be removed. Cannot contain ',' (comma) character.
+ *
+ * @since 4.0.0
+ */
+ def removeJobTags(tags: Set[String]): Unit = {
+ tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
- val newTags = (existingTags -
tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+ val newTags = (existingTags --
tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
if (newTags.isEmpty) {
clearJobTags()
} else {
@@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None)
}
+ /**
+ * Cancel active jobs that have the specified tag. See
`org.apache.spark.SparkContext.addJobTag`.
+ *
+ * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+ * @param reason reason for cancellation.
+ * @return A future with [[ActiveJob]]s, allowing extraction of information
such as Job ID and
+ * tags.
+ */
+ private[spark] def cancelJobsWithTagWithFuture(
+ tag: String,
+ reason: String): Future[Seq[ActiveJob]] = {
+ SparkContext.throwIfInvalidTag(tag)
+ assertNotStopped()
+
+ val cancelledJobs = Promise[Seq[ActiveJob]]()
+ dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs))
+ cancelledJobs.future
+ }
+
/**
* Cancel active jobs that have the specified tag. See
`org.apache.spark.SparkContext.addJobTag`.
*
@@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String, reason: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
- dagScheduler.cancelJobsWithTag(tag, Option(reason))
+ dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None)
}
/**
@@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
- dagScheduler.cancelJobsWithTag(tag, None)
+ dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None)
}
/** Cancel all jobs that have been scheduled or are running. */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6c824e2fdeae..2c89fe7885d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -27,6 +27,7 @@ import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.control.NonFatal
@@ -1116,11 +1117,18 @@ private[spark] class DAGScheduler(
/**
* Cancel all jobs with a given tag.
+ *
+ * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+ * @param reason reason for cancellation.
+ * @param cancelledJobs a promise to be completed with operation IDs being
cancelled.
*/
- def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = {
+ def cancelJobsWithTag(
+ tag: String,
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
SparkContext.throwIfInvalidTag(tag)
logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}")
- eventProcessLoop.post(JobTagCancelled(tag, reason))
+ eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs))
}
/**
@@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler(
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
}
- private[scheduler] def handleJobTagCancelled(tag: String, reason:
Option[String]): Unit = {
- // Cancel all jobs belonging that have this tag.
+ private[scheduler] def handleJobTagCancelled(
+ tag: String,
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
+ // Cancel all active jobs that have the given tag.
// First finds all active jobs with this group id, and then kill stages
for them.
- val jobIds = activeJobs.filter { activeJob =>
+ val jobsToBeCancelled = activeJobs.filter { activeJob =>
Option(activeJob.properties).exists { properties =>
Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
.split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag)
}
- }.map(_.jobId)
- val updatedReason = reason.getOrElse("part of cancelled job tag
%s".format(tag))
- jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
+ }
+ val updatedReason =
+ reason.getOrElse("part of cancelled job tags %s".format(tag))
+ jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_,
Option(updatedReason)))
+ cancelledJobs.map(_.success(jobsToBeCancelled.toSeq))
}
private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo):
Unit = {
@@ -3113,8 +3126,8 @@ private[scheduler] class
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)
- case JobTagCancelled(tag, reason) =>
- dagScheduler.handleJobTagCancelled(tag, reason)
+ case JobTagCancelled(tag, reason, cancelledJobs) =>
+ dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs)
case AllJobsCancelled =>
dagScheduler.doCancelAllJobs()
diff --git
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index c9ad54d1fdc7..8932d2ef323b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
import java.util.Properties
+import scala.concurrent.Promise
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{AccumulatorV2, CallSite}
@@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled(
private[scheduler] case class JobTagCancelled(
tagName: String,
- reason: Option[String]) extends DAGSchedulerEvent
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
index 0580931620aa..63d4a12e1183 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -401,6 +401,98 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]]
extends Serializable with C
@scala.annotation.varargs
def addArtifacts(uri: URI*): Unit
+ /**
+ * Add a tag to be assigned to all the operations started by this thread in
this session.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark
executions.
+ * Application programmers can use this method to group all those jobs
together and give a group
+ * tag. The application can use
`org.apache.spark.sql.SparkSession.interruptTag` to cancel all
+ * running executions with this tag. For example:
+ * {{{
+ * // In the main thread:
+ * spark.addTag("myjobs")
+ * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
+ *
+ * // In a separate thread:
+ * spark.interruptTag("myjobs")
+ * }}}
+ *
+ * There may be multiple tags present at the same time, so different parts
of application may
+ * use different tags to perform cancellation at different levels of
granularity.
+ *
+ * @param tag
+ * The tag to be added. Cannot contain ',' (comma) character or be an
empty string.
+ *
+ * @since 4.0.0
+ */
+ def addTag(tag: String): Unit
+
+ /**
+ * Remove a tag previously added to be assigned to all the operations
started by this thread in
+ * this session. Noop if such a tag was not added earlier.
+ *
+ * @param tag
+ * The tag to be removed. Cannot contain ',' (comma) character or be an
empty string.
+ *
+ * @since 4.0.0
+ */
+ def removeTag(tag: String): Unit
+
+ /**
+ * Get the operation tags that are currently set to be assigned to all the
operations started by
+ * this thread in this session.
+ *
+ * @since 4.0.0
+ */
+ def getTags(): Set[String]
+
+ /**
+ * Clear the current thread's operation tags.
+ *
+ * @since 4.0.0
+ */
+ def clearTags(): Unit
+
+ /**
+ * Request to interrupt all currently running operations of this session.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to
be issued.
+ *
+ * @return
+ * Sequence of operation IDs requested to be interrupted.
+ *
+ * @since 4.0.0
+ */
+ def interruptAll(): Seq[String]
+
+ /**
+ * Request to interrupt all currently running operations of this session
with the given job tag.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to
be issued.
+ *
+ * @return
+ * Sequence of operation IDs requested to be interrupted.
+ *
+ * @since 4.0.0
+ */
+ def interruptTag(tag: String): Seq[String]
+
+ /**
+ * Request to interrupt an operation of this session, given its operation ID.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to
be issued.
+ *
+ * @return
+ * The operation ID requested to be interrupted, as a single-element
sequence, or an empty
+ * sequence if the operation is not started by this session.
+ *
+ * @since 4.0.0
+ */
+ def interruptOperation(operationId: String): Seq[String]
+
/**
* Returns a [[DataFrameReader]] that can be used to read non-streaming data
in as a
* `DataFrame`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5746b942341f..720b77b0b9fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
+import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
+import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -57,7 +59,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.{CallSite, SparkFileUtils, Utils}
+import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -92,7 +94,8 @@ class SparkSession private(
@transient private val existingSharedState: Option[SharedState],
@transient private val parentSessionState: Option[SessionState],
@transient private[sql] val extensions: SparkSessionExtensions,
- @transient private[sql] val initialSessionOptions: Map[String, String])
+ @transient private[sql] val initialSessionOptions: Map[String, String],
+ @transient private val parentManagedJobTags: Map[String, String])
extends api.SparkSession[Dataset] with Logging { self =>
// The call site where this SparkSession was constructed.
@@ -107,7 +110,12 @@ class SparkSession private(
private[sql] def this(
sc: SparkContext,
initialSessionOptions: java.util.HashMap[String, String]) = {
- this(sc, None, None, applyAndLoadExtensions(sc),
initialSessionOptions.asScala.toMap)
+ this(
+ sc,
+ existingSharedState = None,
+ parentSessionState = None,
+ applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap,
+ parentManagedJobTags = Map.empty)
}
private[sql] def this(sc: SparkContext) = this(sc, new
java.util.HashMap[String, String]())
@@ -122,6 +130,18 @@ class SparkSession private(
.getOrElse(SQLConf.getFallbackConf)
})
+ /** Tag to mark all jobs owned by this session. */
+ private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
+
+ /**
+ * A map to hold the mapping from user-defined tags to the real tags
attached to Jobs.
+ * Real tag have the current session ID attached: `"tag1" ->
s"spark-session-$sessionUUID-tag1"`.
+ */
+ @transient
+ private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
+ new ConcurrentHashMap(parentManagedJobTags.asJava)
+ }
+
/** @inheritdoc */
def version: String = SPARK_VERSION
@@ -235,7 +255,8 @@ class SparkSession private(
Some(sharedState),
parentSessionState = None,
extensions,
- initialSessionOptions)
+ initialSessionOptions,
+ parentManagedJobTags = Map.empty)
}
/**
@@ -256,8 +277,10 @@ class SparkSession private(
Some(sharedState),
Some(sessionState),
extensions,
- Map.empty)
+ Map.empty,
+ managedJobTags.asScala.toMap)
result.sessionState // force copy of SessionState
+ result.managedJobTags // force eager copy of managedJobTags
result
}
@@ -636,6 +659,83 @@ class SparkSession private(
artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts))
}
+ /** @inheritdoc */
+ override def addTag(tag: String): Unit = {
+ SparkContext.throwIfInvalidTag(tag)
+ managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+ }
+
+ /** @inheritdoc */
+ override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+
+ /** @inheritdoc */
+ override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+
+ /** @inheritdoc */
+ override def clearTags(): Unit = managedJobTags.clear()
+
+ /**
+ * Request to interrupt all currently running SQL operations of this session.
+ *
+ * @note Only DataFrame/SQL operations started by this session can be
interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request
to be issued.
+
+ * @return Sequence of SQL execution IDs requested to be interrupted.
+
+ * @since 4.0.0
+ */
+ override def interruptAll(): Seq[String] =
+ doInterruptTag(sessionJobTag, "as part of cancellation of all jobs")
+
+ /**
+ * Request to interrupt all currently running SQL operations of this session
with the given
+ * job tag.
+ *
+ * @note Only DataFrame/SQL operations started by this session can be
interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request
to be issued.
+ *
+ * @return Sequence of SQL execution IDs requested to be interrupted.
+
+ * @since 4.0.0
+ */
+ override def interruptTag(tag: String): Seq[String] = {
+ val realTag = managedJobTags.get(tag)
+ if (realTag == null) return Seq.empty
+ doInterruptTag(realTag, s"part of cancelled job tags $tag")
+ }
+
+ private def doInterruptTag(tag: String, reason: String): Seq[String] = {
+ val cancelledTags =
+ sparkContext.cancelJobsWithTagWithFuture(tag, reason)
+
+ ThreadUtils.awaitResult(cancelledTags, 60.seconds)
+ .flatMap(job =>
Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY)))
+ }
+
+ /**
+ * Request to interrupt a SQL operation of this session, given its SQL
execution ID.
+ *
+ * @note Only DataFrame/SQL operations started by this session can be
interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request
to be issued.
+ *
+ * @return The execution ID requested to be interrupted, as a single-element
sequence, or an empty
+ * sequence if the operation is not started by this session.
+ *
+ * @since 4.0.0
+ */
+ override def interruptOperation(operationId: String): Seq[String] = {
+ scala.util.Try(operationId.toLong).toOption match {
+ case Some(executionIdToBeCancelled) =>
+ val tagToBeCancelled = SQLExecution.executionIdJobTag(this,
executionIdToBeCancelled)
+ doInterruptTag(tagToBeCancelled, reason = "")
+ case None =>
+ throw new IllegalArgumentException("executionId must be a number in
string form.")
+ }
+ }
+
/** @inheritdoc */
def read: DataFrameReader = new DataFrameReader(self)
@@ -722,7 +822,7 @@ class SparkSession private(
}
/**
- * Execute a block of code with the this session set as the active session,
and restore the
+ * Execute a block of code with this session set as the active session, and
restore the
* previous session on completion.
*/
private[sql] def withActive[T](block: => T): T = {
@@ -958,7 +1058,12 @@ object SparkSession extends Logging {
loadExtensions(extensions)
applyExtensions(sparkContext, extensions)
- session = new SparkSession(sparkContext, None, None, extensions,
options.toMap)
+ session = new SparkSession(sparkContext,
+ existingSharedState = None,
+ parentSessionState = None,
+ extensions,
+ initialSessionOptions = options.toMap,
+ parentManagedJobTags = Map.empty)
setDefaultSession(session)
setActiveSession(session)
registerContextListener(sparkContext)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 12ff649b621e..5db14a866213 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -44,7 +44,7 @@ object SQLExecution extends Logging {
private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
- private val executionIdToQueryExecution = new ConcurrentHashMap[Long,
QueryExecution]()
+ private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long,
QueryExecution]()
def getQueryExecution(executionId: Long): QueryExecution = {
executionIdToQueryExecution.get(executionId)
@@ -52,6 +52,9 @@ object SQLExecution extends Logging {
private val testing = sys.props.contains(IS_TESTING.key)
+ private[sql] def executionIdJobTag(session: SparkSession, id: Long) =
+ s"${session.sessionJobTag}-execution-root-id-$id"
+
private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
val sc = sparkSession.sparkContext
// only throw an exception during tests. a missing execution ID should not
fail a job.
@@ -82,6 +85,7 @@ object SQLExecution extends Logging {
// And for the root execution, rootExecutionId == executionId.
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) {
sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString)
+ sc.addJobTag(executionIdJobTag(sparkSession, executionId))
}
val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong
executionIdToQueryExecution.put(executionId, queryExecution)
@@ -116,92 +120,94 @@ object SQLExecution extends Logging {
val redactedConfigs =
sparkSession.sessionState.conf.redactOptions(modifiedConfigs)
withSQLConfPropagated(sparkSession) {
- var ex: Option[Throwable] = None
- var isExecutedPlanAvailable = false
- val startTime = System.nanoTime()
- val startEvent = SparkListenerSQLExecutionStart(
- executionId = executionId,
- rootExecutionId = Some(rootExecutionId),
- description = desc,
- details = callSite.longForm,
- physicalPlanDescription = "",
- sparkPlanInfo = SparkPlanInfo.EMPTY,
- time = System.currentTimeMillis(),
- modifiedConfigs = redactedConfigs,
- jobTags = sc.getJobTags(),
- jobGroupId =
Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
- )
- try {
- body match {
- case Left(e) =>
- sc.listenerBus.post(startEvent)
+ withSessionTagsApplied(sparkSession) {
+ var ex: Option[Throwable] = None
+ var isExecutedPlanAvailable = false
+ val startTime = System.nanoTime()
+ val startEvent = SparkListenerSQLExecutionStart(
+ executionId = executionId,
+ rootExecutionId = Some(rootExecutionId),
+ description = desc,
+ details = callSite.longForm,
+ physicalPlanDescription = "",
+ sparkPlanInfo = SparkPlanInfo.EMPTY,
+ time = System.currentTimeMillis(),
+ modifiedConfigs = redactedConfigs,
+ jobTags = sc.getJobTags(),
+ jobGroupId =
Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ )
+ try {
+ body match {
+ case Left(e) =>
+ sc.listenerBus.post(startEvent)
+ throw e
+ case Right(f) =>
+ val planDescriptionMode =
+
ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
+ val planDesc =
queryExecution.explainString(planDescriptionMode)
+ val planInfo = try {
+ SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
+ } catch {
+ case NonFatal(e) =>
+ logDebug("Failed to generate SparkPlanInfo", e)
+ // If the queryExecution already failed before this, we
are not able to generate
+ // the plan info, so we use an empty graphviz node to
make the UI happy
+ SparkPlanInfo.EMPTY
+ }
+ sc.listenerBus.post(
+ startEvent.copy(physicalPlanDescription = planDesc,
sparkPlanInfo = planInfo))
+ isExecutedPlanAvailable = true
+ f()
+ }
+ } catch {
+ case e: Throwable =>
+ ex = Some(e)
throw e
- case Right(f) =>
- val planDescriptionMode =
-
ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
- val planDesc = queryExecution.explainString(planDescriptionMode)
- val planInfo = try {
- SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
- } catch {
- case NonFatal(e) =>
- logDebug("Failed to generate SparkPlanInfo", e)
- // If the queryExecution already failed before this, we are
not able to generate
- // the the plan info, so we use and empty graphviz node to
make the UI happy
- SparkPlanInfo.EMPTY
- }
- sc.listenerBus.post(
- startEvent.copy(physicalPlanDescription = planDesc,
sparkPlanInfo = planInfo))
- isExecutedPlanAvailable = true
- f()
- }
- } catch {
- case e: Throwable =>
- ex = Some(e)
- throw e
- } finally {
- val endTime = System.nanoTime()
- val errorMessage = ex.map {
- case e: SparkThrowable =>
- SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
- case e =>
- Utils.exceptionString(e)
- }
- if (queryExecution.shuffleCleanupMode != DoNotCleanup
- && isExecutedPlanAvailable) {
- val shuffleIds = queryExecution.executedPlan match {
- case ae: AdaptiveSparkPlanExec =>
- ae.context.shuffleIds.asScala.keys
- case _ =>
- Iterable.empty
+ } finally {
+ val endTime = System.nanoTime()
+ val errorMessage = ex.map {
+ case e: SparkThrowable =>
+ SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
+ case e =>
+ Utils.exceptionString(e)
}
- shuffleIds.foreach { shuffleId =>
- queryExecution.shuffleCleanupMode match {
- case RemoveShuffleFiles =>
- // Same as what we do in ContextCleaner.doCleanupShuffle,
but do not unregister
- // the shuffle on MapOutputTracker, so that stage retries
would be triggered.
- // Set blocking to Utils.isTesting to deflake unit tests.
- sc.shuffleDriverComponents.removeShuffle(shuffleId,
Utils.isTesting)
- case SkipMigration =>
-
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
- case _ => // this should not happen
+ if (queryExecution.shuffleCleanupMode != DoNotCleanup
+ && isExecutedPlanAvailable) {
+ val shuffleIds = queryExecution.executedPlan match {
+ case ae: AdaptiveSparkPlanExec =>
+ ae.context.shuffleIds.asScala.keys
+ case _ =>
+ Iterable.empty
+ }
+ shuffleIds.foreach { shuffleId =>
+ queryExecution.shuffleCleanupMode match {
+ case RemoveShuffleFiles =>
+ // Same as what we do in ContextCleaner.doCleanupShuffle,
but do not unregister
+ // the shuffle on MapOutputTracker, so that stage retries
would be triggered.
+ // Set blocking to Utils.isTesting to deflake unit tests.
+ sc.shuffleDriverComponents.removeShuffle(shuffleId,
Utils.isTesting)
+ case SkipMigration =>
+
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
+ case _ => // this should not happen
+ }
}
}
+ val event = SparkListenerSQLExecutionEnd(
+ executionId,
+ System.currentTimeMillis(),
+ // Use empty string to indicate no error, as None may mean
events generated by old
+ // versions of Spark.
+ errorMessage.orElse(Some("")))
+ // Currently only `Dataset.withAction` and
`DataFrameWriter.runCommand` specify the
+ // `name` parameter. The `ExecutionListenerManager` only watches
SQL executions with
+ // name. We can specify the execution name in more places in the
future, so that
+ // `QueryExecutionListener` can track more cases.
+ event.executionName = name
+ event.duration = endTime - startTime
+ event.qe = queryExecution
+ event.executionFailure = ex
+ sc.listenerBus.post(event)
}
- val event = SparkListenerSQLExecutionEnd(
- executionId,
- System.currentTimeMillis(),
- // Use empty string to indicate no error, as None may mean events
generated by old
- // versions of Spark.
- errorMessage.orElse(Some("")))
- // Currently only `Dataset.withAction` and
`DataFrameWriter.runCommand` specify the `name`
- // parameter. The `ExecutionListenerManager` only watches SQL
executions with name. We
- // can specify the execution name in more places in the future, so
that
- // `QueryExecutionListener` can track more cases.
- event.executionName = name
- event.duration = endTime - startTime
- event.qe = queryExecution
- event.executionFailure = ex
- sc.listenerBus.post(event)
}
}
} finally {
@@ -211,6 +217,7 @@ object SQLExecution extends Logging {
// The current execution is the root execution if rootExecutionId ==
executionId.
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) {
sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null)
+ sc.removeJobTag(executionIdJobTag(sparkSession, executionId))
}
sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL,
originalInterruptOnCancel)
}
@@ -238,15 +245,28 @@ object SQLExecution extends Logging {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
withSQLConfPropagated(sparkSession) {
- try {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
- body
- } finally {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ withSessionTagsApplied(sparkSession) {
+ try {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+ body
+ } finally {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ }
}
}
}
+ private[sql] def withSessionTagsApplied[T](sparkSession:
SparkSession)(block: => T): T = {
+ val allTags = sparkSession.managedJobTags.values().asScala.toSet +
sparkSession.sessionJobTag
+ sparkSession.sparkContext.addJobTags(allTags)
+
+ try {
+ block
+ } finally {
+ sparkSession.sparkContext.removeJobTags(allTags)
+ }
+ }
+
/**
* Wrap an action with specified SQL configs. These configs will be
propagated to the executor
* side via job local properties.
@@ -286,10 +306,13 @@ object SQLExecution extends Logging {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
- sc.setLocalProperties(localProps)
- val res = body
- // reset active session and local props.
- sc.setLocalProperties(originalLocalProps)
+ val res = withSessionTagsApplied(activeSession) {
+ sc.setLocalProperties(localProps)
+ val res = body
+ // reset active session and local props.
+ sc.setLocalProperties(originalLocalProps)
+ res
+ }
if (originalSession.nonEmpty) {
SparkSession.setActiveSession(originalSession.get)
} else {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
new file mode 100644
index 000000000000..e9fd07ecf18b
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkException,
SparkFunSuite}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd,
SparkListenerJobStart}
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.tags.ExtendedSQLTest
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Test cases for the tagging and cancellation APIs provided by
[[SparkSession]].
+ */
+@ExtendedSQLTest
+class SparkSessionJobTaggingAndCancellationSuite
+ extends SparkFunSuite
+ with Eventually
+ with LocalSparkContext {
+
+ override def afterEach(): Unit = {
+ try {
+ // This suite should not interfere with the other test suites.
+ SparkSession.getActiveSession.foreach(_.stop())
+ SparkSession.clearActiveSession()
+ SparkSession.getDefaultSession.foreach(_.stop())
+ SparkSession.clearDefaultSession()
+ resetSparkContext()
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ test("Tags are not inherited by new sessions") {
+ val session = SparkSession.builder().master("local").getOrCreate()
+
+ assert(session.getTags() == Set())
+ session.addTag("one")
+ assert(session.getTags() == Set("one"))
+
+ val newSession = session.newSession()
+ assert(newSession.getTags() == Set())
+ }
+
+ test("Tags are inherited by cloned sessions") {
+ val session = SparkSession.builder().master("local").getOrCreate()
+
+ assert(session.getTags() == Set())
+ session.addTag("one")
+ assert(session.getTags() == Set("one"))
+
+ val clonedSession = session.cloneSession()
+ assert(clonedSession.getTags() == Set("one"))
+ clonedSession.addTag("two")
+ assert(clonedSession.getTags() == Set("one", "two"))
+
+ // Tags are not propagated back to the original session
+ assert(session.getTags() == Set("one"))
+ }
+
+ test("Tags set from session are prefixed with session UUID") {
+ sc = new SparkContext("local[2]", "test")
+ val session = SparkSession.builder().sparkContext(sc).getOrCreate()
+ import session.implicits._
+
+ val sem = new Semaphore(0)
+ sc.addSparkListener(new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ sem.release()
+ }
+ })
+
+ session.addTag("one")
+ Future {
+ session.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+ }(ExecutionContext.global)
+
+ assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+ val activeJobsFuture =
+
session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"),
"reason")
+ val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
+ val actualTags =
activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
+ .split(SparkContext.SPARK_JOB_TAGS_SEP)
+ assert(actualTags.toSet == Set(
+ session.sessionJobTag,
+ s"${session.sessionJobTag}-one",
+ SQLExecution.executionIdJobTag(
+ session,
+
activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
+ }
+
+ test("Cancellation APIs in SparkSession are isolated") {
+ sc = new SparkContext("local[2]", "test")
+ val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
+ var (sessionA, sessionB, sessionC): (SparkSession, SparkSession,
SparkSession) =
+ (null, null, null)
+
+ // global ExecutionContext has only 2 threads in Apache Spark CI
+ // create own thread pool for the three Futures used in this test
+ val numThreads = 3
+ val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool",
numThreads)
+ val executionContext = ExecutionContext.fromExecutorService(fpool)
+
+ try {
+ // Add a listener to release the semaphore once jobs are launched.
+ val sem = new Semaphore(0)
+ val jobEnded = new AtomicInteger(0)
+ val jobProperties: ConcurrentHashMap[Int, java.util.Properties] = new
ConcurrentHashMap()
+
+ sc.addSparkListener(new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ jobProperties.put(jobStart.jobId, jobStart.properties)
+ sem.release()
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+ sem.release()
+ jobEnded.incrementAndGet()
+ }
+ })
+
+ // Note: since tags are added in the Future threads, they don't need to
be cleared in between.
+ val jobA = Future {
+ sessionA = globalSession.cloneSession()
+ import globalSession.implicits._
+
+ assert(sessionA.getTags() == Set())
+ sessionA.addTag("two")
+ assert(sessionA.getTags() == Set("two"))
+ sessionA.clearTags() // check that clearing all tags works
+ assert(sessionA.getTags() == Set())
+ sessionA.addTag("one")
+ assert(sessionA.getTags() == Set("one"))
+ try {
+ sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+ } finally {
+ sessionA.clearTags() // clear for the case of thread reuse by
another Future
+ }
+ }(executionContext)
+ val jobB = Future {
+ sessionB = globalSession.cloneSession()
+ import globalSession.implicits._
+
+ assert(sessionB.getTags() == Set())
+ sessionB.addTag("one")
+ sessionB.addTag("two")
+ sessionB.addTag("one")
+ sessionB.addTag("two") // duplicates shouldn't matter
+ assert(sessionB.getTags() == Set("one", "two"))
+ try {
+ sessionB.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+ } finally {
+ sessionB.clearTags() // clear for the case of thread reuse by
another Future
+ }
+ }(executionContext)
+ val jobC = Future {
+ sessionC = globalSession.cloneSession()
+ import globalSession.implicits._
+
+ sessionC.addTag("foo")
+ sessionC.removeTag("foo")
+ assert(sessionC.getTags() == Set()) // check that remove works
removing the last tag
+ sessionC.addTag("boo")
+ try {
+ sessionC.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+ } finally {
+ sessionC.clearTags() // clear for the case of thread reuse by
another Future
+ }
+ }(executionContext)
+
+ // Block until all three jobs have started.
+ assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
+
+ // Tags are applied
+ assert(jobProperties.size == 3)
+ for (ss <- Seq(sessionA, sessionB, sessionC)) {
+ val jobProperty =
jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
+ .asInstanceOf[String].contains(ss.sessionUUID))
+ assert(jobProperty.size == 1)
+ val tags =
jobProperty.head.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String]
+ .split(SparkContext.SPARK_JOB_TAGS_SEP)
+
+ val executionRootIdTag = SQLExecution.executionIdJobTag(
+ ss,
+
jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
+ val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
+
+ ss match {
+ case s if s == sessionA => assert(tags.toSet == Set(
+ s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+ case s if s == sessionB => assert(tags.toSet == Set(
+ s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one",
s"${userTagsPrefix}two"))
+ case s if s == sessionC => assert(tags.toSet == Set(
+ s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+ }
+ }
+
+ // Global session cancels nothing
+ assert(globalSession.interruptAll().isEmpty)
+ assert(globalSession.interruptTag("one").isEmpty)
+ assert(globalSession.interruptTag("two").isEmpty)
+ for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+ assert(globalSession.interruptOperation(i.toString).isEmpty)
+ }
+ assert(jobEnded.intValue == 0)
+
+ // One job cancelled
+ for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+ sessionC.interruptOperation(i.toString)
+ }
+ val eC = intercept[SparkException] {
+ ThreadUtils.awaitResult(jobC, 1.minute)
+ }.getCause
+ assert(eC.getMessage contains "cancelled")
+ assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+ assert(jobEnded.intValue == 1)
+
+ // Another job cancelled
+ assert(sessionA.interruptTag("one").size == 1)
+ val eA = intercept[SparkException] {
+ ThreadUtils.awaitResult(jobA, 1.minute)
+ }.getCause
+ assert(eA.getMessage contains "cancelled job tags one")
+ assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+ assert(jobEnded.intValue == 2)
+
+ // The last job cancelled
+ sessionB.interruptAll()
+ val eB = intercept[SparkException] {
+ ThreadUtils.awaitResult(jobB, 1.minute)
+ }.getCause
+ assert(eB.getMessage contains "cancellation of all jobs")
+ assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+ assert(jobEnded.intValue == 3)
+ } finally {
+ fpool.shutdownNow()
+ }
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index 94d33731b6de..059a4c9b8376 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -228,7 +228,7 @@ class SQLExecutionSuite extends SparkFunSuite with
SQLConfHelper {
spark.range(1).collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(jobTags.contains(jobTag))
+ assert(jobTags.get.contains(jobTag))
assert(sqlJobTags.contains(jobTag))
} finally {
spark.sparkContext.removeJobTag(jobTag)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]