This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fd8e99b9df55 [SPARK-49249][SPARK-49320] Add new tag-related APIs in Connect back to Spark Core
fd8e99b9df55 is described below

commit fd8e99b9df55bf2ea29b6279a6a840ffef20ed4e
Author: Paddy Xu <[email protected]>
AuthorDate: Tue Sep 17 23:06:05 2024 -0400

    [SPARK-49249][SPARK-49320] Add new tag-related APIs in Connect back to Spark Core
    
    ### What changes were proposed in this pull request?
    
    This PR adds several new tag-related APIs in Connect back to Spark Core. Following the isolation practice in the original Connect API, the newly introduced APIs also support isolation:
    
    - `interrupt{Tag,All,Operation}` can only cancel jobs created by this Spark session.
    - `{add,remove}Tag` and `{get,clear}Tags` only apply to jobs created by this Spark session.
    
    Instead of returning query IDs as in Spark Connect, these methods in Spark SQL return SQL execution root IDs, since "query IDs" exist only in Connect.
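
    As an illustrative sketch (the tag name is arbitrary; `spark` is an existing `SparkSession`), the new APIs compose like this:

        // In the main thread: tag the operations started by this thread in this session.
        spark.addTag("nightly-etl")
        spark.range(10).map(i => { Thread.sleep(100); i }).collect()

        // In a watchdog thread: interrupt only this session's "nightly-etl" executions;
        // the returned strings are SQL execution root IDs, not Connect query IDs.
        val interrupted: Seq[String] = spark.interruptTag("nightly-etl")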
    
    ### Why are the changes needed?
    
    To close the API gap between Connect and Core.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes: Spark Core users gain the session-level tagging and interruption APIs `addTag`, `removeTag`, `getTags`, `clearTags`, `interruptAll`, `interruptTag`, and `interruptOperation`.
    
    ### How was this patch tested?
    
    A new test suite, `SparkSessionJobTaggingAndCancellationSuite`, was added.
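
    The suite can be run with Spark's usual sbt test filter (command sketch; the `sql` project maps to sql/core):

        build/sbt "sql/testOnly org.apache.spark.sql.SparkSessionJobTaggingAndCancellationSuite"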
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47815 from xupefei/reverse-api-tag.
    
    Authored-by: Paddy Xu <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  70 +-----
 .../CheckConnectJvmClientCompatibility.scala       |  15 --
 .../main/scala/org/apache/spark/SparkContext.scala |  56 ++++-
 .../org/apache/spark/scheduler/DAGScheduler.scala  |  33 ++-
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   5 +-
 .../org/apache/spark/sql/api/SparkSession.scala    |  92 ++++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  | 119 +++++++++-
 .../apache/spark/sql/execution/SQLExecution.scala  | 205 +++++++++-------
 ...parkSessionJobTaggingAndCancellationSuite.scala | 262 +++++++++++++++++++++
 .../spark/sql/execution/SQLExecutionSuite.scala    |   2 +-
 10 files changed, 667 insertions(+), 192 deletions(-)
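
For orientation, the changes below layer three kinds of job tags onto every job a session starts. As an illustrative sketch, for a session whose sessionUUID is "abc", a job run under user tag "one" by SQL execution 5 would carry:

    spark-session-abc                      <- sessionJobTag: marks every job of this session
    spark-session-abc-one                  <- managedJobTags("one"): the user tag, session-scoped
    spark-session-abc-execution-root-id-5  <- executionIdJobTag: ties the job to its root SQL execution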

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 989a7e0c174c..aa6258a14b81 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -420,7 +420,7 @@ class SparkSession private[sql] (
    *
    * @since 3.5.0
    */
-  def interruptAll(): Seq[String] = {
+  override def interruptAll(): Seq[String] = {
     client.interruptAll().getInterruptedIdsList.asScala.toSeq
   }
 
@@ -433,7 +433,7 @@ class SparkSession private[sql] (
    *
    * @since 3.5.0
    */
-  def interruptTag(tag: String): Seq[String] = {
+  override def interruptTag(tag: String): Seq[String] = {
     client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
   }
 
@@ -446,7 +446,7 @@ class SparkSession private[sql] (
    *
    * @since 3.5.0
    */
-  def interruptOperation(operationId: String): Seq[String] = {
+  override def interruptOperation(operationId: String): Seq[String] = {
     client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
   }
 
@@ -477,65 +477,17 @@ class SparkSession private[sql] (
     SparkSession.onSessionClose(this)
   }
 
-  /**
-   * Add a tag to be assigned to all the operations started by this thread in this session.
-   *
-   * Often, a unit of execution in an application consists of multiple Spark executions.
-   * Application programmers can use this method to group all those jobs together and give a group
-   * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
-   * running running executions with this tag. For example:
-   * {{{
-   * // In the main thread:
-   * spark.addTag("myjobs")
-   * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
-   *
-   * // In a separate thread:
-   * spark.interruptTag("myjobs")
-   * }}}
-   *
-   * There may be multiple tags present at the same time, so different parts of application may
-   * use different tags to perform cancellation at different levels of granularity.
-   *
-   * @param tag
-   *   The tag to be added. Cannot contain ',' (comma) character or be an empty string.
-   *
-   * @since 3.5.0
-   */
-  def addTag(tag: String): Unit = {
-    client.addTag(tag)
-  }
+  /** @inheritdoc */
+  override def addTag(tag: String): Unit = client.addTag(tag)
 
-  /**
-   * Remove a tag previously added to be assigned to all the operations started by this thread in
-   * this session. Noop if such a tag was not added earlier.
-   *
-   * @param tag
-   *   The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
-   *
-   * @since 3.5.0
-   */
-  def removeTag(tag: String): Unit = {
-    client.removeTag(tag)
-  }
+  /** @inheritdoc */
+  override def removeTag(tag: String): Unit = client.removeTag(tag)
 
-  /**
-   * Get the tags that are currently set to be assigned to all the operations started by this
-   * thread.
-   *
-   * @since 3.5.0
-   */
-  def getTags(): Set[String] = {
-    client.getTags()
-  }
+  /** @inheritdoc */
+  override def getTags(): Set[String] = client.getTags()
 
-  /**
-   * Clear the current thread's operation tags.
-   *
-   * @since 3.5.0
-   */
-  def clearTags(): Unit = {
-    client.clearTags()
-  }
+  /** @inheritdoc */
+  override def clearTags(): Unit = client.clearTags()
 
   /**
    * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index f4043f19eb6a..abf03cfbc672 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -365,21 +365,6 @@ object CheckConnectJvmClientCompatibility {
       // Experimental
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.SparkSession.registerClassFinder"),
-      // public
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.interruptAll"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.interruptTag"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.interruptOperation"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.addTag"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.removeTag"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.getTags"),
-      ProblemFilters.exclude[DirectMissingMethodProblem](
-        "org.apache.spark.sql.SparkSession.clearTags"),
       // SparkSession#Builder
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.SparkSession#Builder.remote"),
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 485f0abcd25e..042179d86c31 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.Map
 import scala.collection.concurrent.{Map => ScalaConcurrentMap}
 import scala.collection.immutable
 import scala.collection.mutable.HashMap
+import scala.concurrent.{Future, Promise}
 import scala.jdk.CollectionConverters._
 import scala.reflect.{classTag, ClassTag}
 import scala.util.control.NonFatal
@@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging {
    *
    * @since 3.5.0
    */
-  def addJobTag(tag: String): Unit = {
-    SparkContext.throwIfInvalidTag(tag)
+  def addJobTag(tag: String): Unit = addJobTags(Set(tag))
+
+  /**
+   * Add multiple tags to be assigned to all the jobs started by this thread.
+   * See [[addJobTag]] for more details.
+   *
+   * @param tags The tags to be added. Cannot contain ',' (comma) character.
+   *
+   * @since 4.0.0
+   */
+  def addJobTags(tags: Set[String]): Unit = {
+    tags.foreach(SparkContext.throwIfInvalidTag)
     val existingTags = getJobTags()
-    val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+    val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
     setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
   }
 
@@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging {
    *
    * @since 3.5.0
    */
-  def removeJobTag(tag: String): Unit = {
-    SparkContext.throwIfInvalidTag(tag)
+  def removeJobTag(tag: String): Unit = removeJobTags(Set(tag))
+
+  /**
+   * Remove multiple tags to be assigned to all the jobs started by this thread.
+   * See [[removeJobTag]] for more details.
+   *
+   * @param tags The tags to be removed. Cannot contain ',' (comma) character.
+   *
+   * @since 4.0.0
+   */
+  def removeJobTags(tags: Set[String]): Unit = {
+    tags.foreach(SparkContext.throwIfInvalidTag)
     val existingTags = getJobTags()
-    val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+    val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
     if (newTags.isEmpty) {
       clearJobTags()
     } else {
@@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging {
     dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None)
   }
 
+  /**
+   * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
+   *
+   * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+   * @param reason reason for cancellation.
+   * @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and
+   *   tags.
+   */
+  private[spark] def cancelJobsWithTagWithFuture(
+      tag: String,
+      reason: String): Future[Seq[ActiveJob]] = {
+    SparkContext.throwIfInvalidTag(tag)
+    assertNotStopped()
+
+    val cancelledJobs = Promise[Seq[ActiveJob]]()
+    dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs))
+    cancelledJobs.future
+  }
+
   /**
    * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
    *
@@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging {
   def cancelJobsWithTag(tag: String, reason: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
     assertNotStopped()
-    dagScheduler.cancelJobsWithTag(tag, Option(reason))
+    dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None)
   }
 
   /**
@@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging {
   def cancelJobsWithTag(tag: String): Unit = {
     SparkContext.throwIfInvalidTag(tag)
     assertNotStopped()
-    dagScheduler.cancelJobsWithTag(tag, None)
+    dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None)
   }
 
   /** Cancel all jobs that have been scheduled or are running.  */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6c824e2fdeae..2c89fe7885d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -27,6 +27,7 @@ import scala.annotation.tailrec
 import scala.collection.Map
 import scala.collection.mutable
 import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+import scala.concurrent.Promise
 import scala.concurrent.duration._
 import scala.util.control.NonFatal
 
@@ -1116,11 +1117,18 @@ private[spark] class DAGScheduler(
 
   /**
    * Cancel all jobs with a given tag.
+   *
+   * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+   * @param reason reason for cancellation.
+   * @param cancelledJobs a promise to be completed with operation IDs being cancelled.
    */
-  def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = {
+  def cancelJobsWithTag(
+      tag: String,
+      reason: Option[String],
+      cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
     SparkContext.throwIfInvalidTag(tag)
     logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}")
-    eventProcessLoop.post(JobTagCancelled(tag, reason))
+    eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs))
   }
 
   /**
@@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler(
     jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
   }
 
-  private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = {
-    // Cancel all jobs belonging that have this tag.
+  private[scheduler] def handleJobTagCancelled(
+      tag: String,
+      reason: Option[String],
+      cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
+    // Cancel all jobs that have all provided tags.
     // First finds all active jobs with this group id, and then kill stages for them.
-    val jobIds = activeJobs.filter { activeJob =>
+    val jobsToBeCancelled = activeJobs.filter { activeJob =>
       Option(activeJob.properties).exists { properties =>
         Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
           .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag)
       }
-    }.map(_.jobId)
-    val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag))
-    jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
+    }
+    val updatedReason =
+      reason.getOrElse("part of cancelled job tags %s".format(tag))
+    jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason)))
+    cancelledJobs.map(_.success(jobsToBeCancelled.toSeq))
   }
 
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
@@ -3113,8 +3126,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
       dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)
 
-    case JobTagCancelled(tag, reason) =>
-      dagScheduler.handleJobTagCancelled(tag, reason)
+    case JobTagCancelled(tag, reason, cancelledJobs) =>
+      dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs)
 
     case AllJobsCancelled =>
       dagScheduler.doCancelAllJobs()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index c9ad54d1fdc7..8932d2ef323b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
 
 import java.util.Properties
 
+import scala.concurrent.Promise
+
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.{AccumulatorV2, CallSite}
@@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled(
 
 private[scheduler] case class JobTagCancelled(
     tagName: String,
-    reason: Option[String]) extends DAGSchedulerEvent
+    reason: Option[String],
+    cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent
 
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
index 0580931620aa..63d4a12e1183 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -401,6 +401,98 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
   @scala.annotation.varargs
   def addArtifacts(uri: URI*): Unit
 
+  /**
+   * Add a tag to be assigned to all the operations started by this thread in this session.
+   *
+   * Often, a unit of execution in an application consists of multiple Spark executions.
+   * Application programmers can use this method to group all those jobs together and give a group
+   * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
+   * running executions with this tag. For example:
+   * {{{
+   * // In the main thread:
+   * spark.addTag("myjobs")
+   * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
+   *
+   * // In a separate thread:
+   * spark.interruptTag("myjobs")
+   * }}}
+   *
+   * There may be multiple tags present at the same time, so different parts of application may
+   * use different tags to perform cancellation at different levels of granularity.
+   *
+   * @param tag
+   *   The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+   *
+   * @since 4.0.0
+   */
+  def addTag(tag: String): Unit
+
+  /**
+   * Remove a tag previously added to be assigned to all the operations started by this thread in
+   * this session. Noop if such a tag was not added earlier.
+   *
+   * @param tag
+   *   The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+   *
+   * @since 4.0.0
+   */
+  def removeTag(tag: String): Unit
+
+  /**
+   * Get the operation tags that are currently set to be assigned to all the operations started by
+   * this thread in this session.
+   *
+   * @since 4.0.0
+   */
+  def getTags(): Set[String]
+
+  /**
+   * Clear the current thread's operation tags.
+   *
+   * @since 4.0.0
+   */
+  def clearTags(): Unit
+
+  /**
+   * Request to interrupt all currently running operations of this session.
+   *
+   * @note
+   *   This method will wait up to 60 seconds for the interruption request to be issued.
+   *
+   * @return
+   *   Sequence of operation IDs requested to be interrupted.
+   *
+   * @since 4.0.0
+   */
+  def interruptAll(): Seq[String]
+
+  /**
+   * Request to interrupt all currently running operations of this session with the given job tag.
+   *
+   * @note
+   *   This method will wait up to 60 seconds for the interruption request to be issued.
+   *
+   * @return
+   *   Sequence of operation IDs requested to be interrupted.
+   *
+   * @since 4.0.0
+   */
+  def interruptTag(tag: String): Seq[String]
+
+  /**
+   * Request to interrupt an operation of this session, given its operation ID.
+   *
+   * @note
+   *   This method will wait up to 60 seconds for the interruption request to be issued.
+   *
+   * @return
+   *   The operation ID requested to be interrupted, as a single-element sequence, or an empty
+   *   sequence if the operation is not started by this session.
+   *
+   * @since 4.0.0
+   */
+  def interruptOperation(operationId: String): Seq[String]
+
   /**
    * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
    * `DataFrame`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5746b942341f..720b77b0b9fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql
 import java.net.URI
 import java.nio.file.Paths
 import java.util.{ServiceLoader, UUID}
+import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 
+import scala.concurrent.duration.DurationInt
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -57,7 +59,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.{CallSite, SparkFileUtils, Utils}
+import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -92,7 +94,8 @@ class SparkSession private(
     @transient private val existingSharedState: Option[SharedState],
     @transient private val parentSessionState: Option[SessionState],
     @transient private[sql] val extensions: SparkSessionExtensions,
-    @transient private[sql] val initialSessionOptions: Map[String, String])
+    @transient private[sql] val initialSessionOptions: Map[String, String],
+    @transient private val parentManagedJobTags: Map[String, String])
   extends api.SparkSession[Dataset] with Logging { self =>
 
   // The call site where this SparkSession was constructed.
@@ -107,7 +110,12 @@ class SparkSession private(
   private[sql] def this(
       sc: SparkContext,
       initialSessionOptions: java.util.HashMap[String, String]) = {
-    this(sc, None, None, applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap)
+    this(
+      sc,
+      existingSharedState = None,
+      parentSessionState = None,
+      applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap,
+      parentManagedJobTags = Map.empty)
   }
 
   private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]())
@@ -122,6 +130,18 @@ class SparkSession private(
       .getOrElse(SQLConf.getFallbackConf)
   })
 
+  /** Tag to mark all jobs owned by this session. */
+  private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
+
+  /**
+   * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
+   * Real tags have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+   */
+  @transient
+  private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
+    new ConcurrentHashMap(parentManagedJobTags.asJava)
+  }
+
   /** @inheritdoc */
   def version: String = SPARK_VERSION
 
@@ -235,7 +255,8 @@ class SparkSession private(
       Some(sharedState),
       parentSessionState = None,
       extensions,
-      initialSessionOptions)
+      initialSessionOptions,
+      parentManagedJobTags = Map.empty)
   }
 
   /**
@@ -256,8 +277,10 @@ class SparkSession private(
       Some(sharedState),
       Some(sessionState),
       extensions,
-      Map.empty)
+      Map.empty,
+      managedJobTags.asScala.toMap)
     result.sessionState // force copy of SessionState
+    result.managedJobTags // force copy of userDefinedToRealTagsMap
     result
   }
 
@@ -636,6 +659,83 @@ class SparkSession private(
     artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts))
   }
 
+  /** @inheritdoc */
+  override def addTag(tag: String): Unit = {
+    SparkContext.throwIfInvalidTag(tag)
+    managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+  }
+
+  /** @inheritdoc */
+  override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+
+  /** @inheritdoc */
+  override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+
+  /** @inheritdoc */
+  override def clearTags(): Unit = managedJobTags.clear()
+
+  /**
+   * Request to interrupt all currently running SQL operations of this session.
+   *
+   * @note Only DataFrame/SQL operations started by this session can be interrupted.
+   *
+   * @note This method will wait up to 60 seconds for the interruption request to be issued.
+
+   * @return Sequence of SQL execution IDs requested to be interrupted.
+
+   * @since 4.0.0
+   */
+  override def interruptAll(): Seq[String] =
+    doInterruptTag(sessionJobTag, "as part of cancellation of all jobs")
+
+  /**
+   * Request to interrupt all currently running SQL operations of this session with the given
+   * job tag.
+   *
+   * @note Only DataFrame/SQL operations started by this session can be interrupted.
+   *
+   * @note This method will wait up to 60 seconds for the interruption request to be issued.
+   *
+   * @return Sequence of SQL execution IDs requested to be interrupted.
+
+   * @since 4.0.0
+   */
+  override def interruptTag(tag: String): Seq[String] = {
+    val realTag = managedJobTags.get(tag)
+    if (realTag == null) return Seq.empty
+    doInterruptTag(realTag, s"part of cancelled job tags $tag")
+  }
+
+  private def doInterruptTag(tag: String, reason: String): Seq[String] = {
+    val cancelledTags =
+      sparkContext.cancelJobsWithTagWithFuture(tag, reason)
+
+    ThreadUtils.awaitResult(cancelledTags, 60.seconds)
+      .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY)))
+  }
+
+  /**
+   * Request to interrupt a SQL operation of this session, given its SQL execution ID.
+   *
+   * @note Only DataFrame/SQL operations started by this session can be interrupted.
+   *
+   * @note This method will wait up to 60 seconds for the interruption request to be issued.
+   *
+   * @return The execution ID requested to be interrupted, as a single-element sequence, or an empty
+   *    sequence if the operation is not started by this session.
+   *
+   * @since 4.0.0
+   */
+  override def interruptOperation(operationId: String): Seq[String] = {
+    scala.util.Try(operationId.toLong).toOption match {
+      case Some(executionIdToBeCancelled) =>
+        val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled)
+        doInterruptTag(tagToBeCancelled, reason = "")
+      case None =>
+        throw new IllegalArgumentException("executionId must be a number in string form.")
+    }
+  }
+
   /** @inheritdoc */
   def read: DataFrameReader = new DataFrameReader(self)
 
@@ -722,7 +822,7 @@ class SparkSession private(
   }
 
   /**
-   * Execute a block of code with the this session set as the active session, and restore the
+   * Execute a block of code with this session set as the active session, and restore the
    * previous session on completion.
    */
   private[sql] def withActive[T](block: => T): T = {
@@ -958,7 +1058,12 @@ object SparkSession extends Logging {
         loadExtensions(extensions)
         applyExtensions(sparkContext, extensions)
 
-        session = new SparkSession(sparkContext, None, None, extensions, options.toMap)
+        session = new SparkSession(sparkContext,
+          existingSharedState = None,
+          parentSessionState = None,
+          extensions,
+          initialSessionOptions = options.toMap,
+          parentManagedJobTags = Map.empty)
         setDefaultSession(session)
         setActiveSession(session)
         registerContextListener(sparkContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 12ff649b621e..5db14a866213 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -44,7 +44,7 @@ object SQLExecution extends Logging {
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
 
-  private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()
+  private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()
 
   def getQueryExecution(executionId: Long): QueryExecution = {
     executionIdToQueryExecution.get(executionId)
@@ -52,6 +52,9 @@ object SQLExecution extends Logging {
 
   private val testing = sys.props.contains(IS_TESTING.key)
 
+  private[sql] def executionIdJobTag(session: SparkSession, id: Long) =
+    s"${session.sessionJobTag}-execution-root-id-$id"
+
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
     val sc = sparkSession.sparkContext
     // only throw an exception during tests. a missing execution ID should not fail a job.
@@ -82,6 +85,7 @@ object SQLExecution extends Logging {
     // And for the root execution, rootExecutionId == executionId.
     if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) {
       sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString)
+      sc.addJobTag(executionIdJobTag(sparkSession, executionId))
     }
     val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong
     executionIdToQueryExecution.put(executionId, queryExecution)
@@ -116,92 +120,94 @@ object SQLExecution extends Logging {
       val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs)
 
       withSQLConfPropagated(sparkSession) {
-        var ex: Option[Throwable] = None
-        var isExecutedPlanAvailable = false
-        val startTime = System.nanoTime()
-        val startEvent = SparkListenerSQLExecutionStart(
-          executionId = executionId,
-          rootExecutionId = Some(rootExecutionId),
-          description = desc,
-          details = callSite.longForm,
-          physicalPlanDescription = "",
-          sparkPlanInfo = SparkPlanInfo.EMPTY,
-          time = System.currentTimeMillis(),
-          modifiedConfigs = redactedConfigs,
-          jobTags = sc.getJobTags(),
-          jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
-        )
-        try {
-          body match {
-            case Left(e) =>
-              sc.listenerBus.post(startEvent)
+        withSessionTagsApplied(sparkSession) {
+          var ex: Option[Throwable] = None
+          var isExecutedPlanAvailable = false
+          val startTime = System.nanoTime()
+          val startEvent = SparkListenerSQLExecutionStart(
+            executionId = executionId,
+            rootExecutionId = Some(rootExecutionId),
+            description = desc,
+            details = callSite.longForm,
+            physicalPlanDescription = "",
+            sparkPlanInfo = SparkPlanInfo.EMPTY,
+            time = System.currentTimeMillis(),
+            modifiedConfigs = redactedConfigs,
+            jobTags = sc.getJobTags(),
+            jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+          )
+          try {
+            body match {
+              case Left(e) =>
+                sc.listenerBus.post(startEvent)
+                throw e
+              case Right(f) =>
+                val planDescriptionMode =
+                  ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
+                val planDesc = queryExecution.explainString(planDescriptionMode)
+                val planInfo = try {
+                  SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
+                } catch {
+                  case NonFatal(e) =>
+                    logDebug("Failed to generate SparkPlanInfo", e)
+                    // If the queryExecution already failed before this, we are not able to generate
+                    // the the plan info, so we use and empty graphviz node to make the UI happy
+                    SparkPlanInfo.EMPTY
+                }
+                sc.listenerBus.post(
+                  startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
+                isExecutedPlanAvailable = true
+                f()
+            }
+          } catch {
+            case e: Throwable =>
+              ex = Some(e)
               throw e
-            case Right(f) =>
-              val planDescriptionMode =
-                ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
-              val planDesc = queryExecution.explainString(planDescriptionMode)
-              val planInfo = try {
-                SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
-              } catch {
-                case NonFatal(e) =>
-                  logDebug("Failed to generate SparkPlanInfo", e)
-                  // If the queryExecution already failed before this, we are not able to generate
-                  // the the plan info, so we use and empty graphviz node to make the UI happy
-                  SparkPlanInfo.EMPTY
-              }
-              sc.listenerBus.post(
-                startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
-              isExecutedPlanAvailable = true
-              f()
-          }
-        } catch {
-          case e: Throwable =>
-            ex = Some(e)
-            throw e
-        } finally {
-          val endTime = System.nanoTime()
-          val errorMessage = ex.map {
-            case e: SparkThrowable =>
-              SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
-            case e =>
-              Utils.exceptionString(e)
-          }
-          if (queryExecution.shuffleCleanupMode != DoNotCleanup
-            && isExecutedPlanAvailable) {
-            val shuffleIds = queryExecution.executedPlan match {
-              case ae: AdaptiveSparkPlanExec =>
-                ae.context.shuffleIds.asScala.keys
-              case _ =>
-                Iterable.empty
+          } finally {
+            val endTime = System.nanoTime()
+            val errorMessage = ex.map {
+              case e: SparkThrowable =>
+                SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
+              case e =>
+                Utils.exceptionString(e)
             }
-            shuffleIds.foreach { shuffleId =>
-              queryExecution.shuffleCleanupMode match {
-                case RemoveShuffleFiles =>
-                  // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister
-                  // the shuffle on MapOutputTracker, so that stage retries would be triggered.
-                  // Set blocking to Utils.isTesting to deflake unit tests.
-                  sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting)
-                case SkipMigration =>
-                  SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
-                case _ => // this should not happen
+            if (queryExecution.shuffleCleanupMode != DoNotCleanup
+              && isExecutedPlanAvailable) {
+              val shuffleIds = queryExecution.executedPlan match {
+                case ae: AdaptiveSparkPlanExec =>
+                  ae.context.shuffleIds.asScala.keys
+                case _ =>
+                  Iterable.empty
+              }
+              shuffleIds.foreach { shuffleId =>
+                queryExecution.shuffleCleanupMode match {
+                  case RemoveShuffleFiles =>
+                    // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister
+                    // the shuffle on MapOutputTracker, so that stage retries would be triggered.
+                    // Set blocking to Utils.isTesting to deflake unit tests.
+                    sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting)
+                  case SkipMigration =>
+                    SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
+                  case _ => // this should not happen
+                }
               }
             }
+            val event = SparkListenerSQLExecutionEnd(
+              executionId,
+              System.currentTimeMillis(),
+              // Use empty string to indicate no error, as None may mean events generated by old
+              // versions of Spark.
+              errorMessage.orElse(Some("")))
+            // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the
+            // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with
+            // name. We can specify the execution name in more places in the future, so that
+            // `QueryExecutionListener` can track more cases.
+            event.executionName = name
+            event.duration = endTime - startTime
+            event.qe = queryExecution
+            event.executionFailure = ex
+            sc.listenerBus.post(event)
           }
-          val event = SparkListenerSQLExecutionEnd(
-            executionId,
-            System.currentTimeMillis(),
-            // Use empty string to indicate no error, as None may mean events generated by old
-            // versions of Spark.
-            errorMessage.orElse(Some("")))
-          // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
-          // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We
-          // can specify the execution name in more places in the future, so that
-          // `QueryExecutionListener` can track more cases.
-          event.executionName = name
-          event.duration = endTime - startTime
-          event.qe = queryExecution
-          event.executionFailure = ex
-          sc.listenerBus.post(event)
         }
       }
     } finally {
@@ -211,6 +217,7 @@ object SQLExecution extends Logging {
       // The current execution is the root execution if rootExecutionId == executionId.
       if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) {
         sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null)
+        sc.removeJobTag(executionIdJobTag(sparkSession, executionId))
       }
       sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel)
     }
@@ -238,15 +245,28 @@ object SQLExecution extends Logging {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     withSQLConfPropagated(sparkSession) {
-      try {
-        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
-        body
-      } finally {
-        sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+      withSessionTagsApplied(sparkSession) {
+        try {
+          sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+          body
+        } finally {
+          sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+        }
       }
     }
   }
 
+  private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
+    val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+    sparkSession.sparkContext.addJobTags(allTags)
+
+    try {
+      block
+    } finally {
+      sparkSession.sparkContext.removeJobTags(allTags)
+    }
+  }
+
   /**
    * Wrap an action with specified SQL configs. These configs will be propagated to the executor
    * side via job local properties.
@@ -286,10 +306,13 @@ object SQLExecution extends Logging {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
-      sc.setLocalProperties(localProps)
-      val res = body
-      // reset active session and local props.
-      sc.setLocalProperties(originalLocalProps)
+      val res = withSessionTagsApplied(activeSession) {
+        sc.setLocalProperties(localProps)
+        val res = body
+        // reset active session and local props.
+        sc.setLocalProperties(originalLocalProps)
+        res
+      }
       if (originalSession.nonEmpty) {
         SparkSession.setActiveSession(originalSession.get)
       } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
new file mode 100644
index 000000000000..e9fd07ecf18b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart}
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.tags.ExtendedSQLTest
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Test cases for the tagging and cancellation APIs provided by [[SparkSession]].
+ */
+@ExtendedSQLTest
+class SparkSessionJobTaggingAndCancellationSuite
+  extends SparkFunSuite
+  with Eventually
+  with LocalSparkContext {
+
+  override def afterEach(): Unit = {
+    try {
+      // This suite should not interfere with the other test suites.
+      SparkSession.getActiveSession.foreach(_.stop())
+      SparkSession.clearActiveSession()
+      SparkSession.getDefaultSession.foreach(_.stop())
+      SparkSession.clearDefaultSession()
+      resetSparkContext()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("Tags are not inherited by new sessions") {
+    val session = SparkSession.builder().master("local").getOrCreate()
+
+    assert(session.getTags() == Set())
+    session.addTag("one")
+    assert(session.getTags() == Set("one"))
+
+    val newSession = session.newSession()
+    assert(newSession.getTags() == Set())
+  }
+
+  test("Tags are inherited by cloned sessions") {
+    val session = SparkSession.builder().master("local").getOrCreate()
+
+    assert(session.getTags() == Set())
+    session.addTag("one")
+    assert(session.getTags() == Set("one"))
+
+    val clonedSession = session.cloneSession()
+    assert(clonedSession.getTags() == Set("one"))
+    clonedSession.addTag("two")
+    assert(clonedSession.getTags() == Set("one", "two"))
+
+    // Tags are not propagated back to the original session
+    assert(session.getTags() == Set("one"))
+  }
+
+  test("Tags set from session are prefixed with session UUID") {
+    sc = new SparkContext("local[2]", "test")
+    val session = SparkSession.builder().sparkContext(sc).getOrCreate()
+    import session.implicits._
+
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        sem.release()
+      }
+    })
+
+    session.addTag("one")
+    Future {
+      session.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+    }(ExecutionContext.global)
+
+    assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+    val activeJobsFuture =
+      session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason")
+    val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head
+    val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS)
+      .split(SparkContext.SPARK_JOB_TAGS_SEP)
+    assert(actualTags.toSet == Set(
+      session.sessionJobTag,
+      s"${session.sessionJobTag}-one",
+      SQLExecution.executionIdJobTag(
+        session,
+        activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)))
+  }
+
+  test("Cancellation APIs in SparkSession are isolated") {
+    sc = new SparkContext("local[2]", "test")
+    val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate()
+    var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) =
+      (null, null, null)
+
+    // global ExecutionContext has only 2 threads in Apache Spark CI
+    // create own thread pool for four Futures used in this test
+    val numThreads = 3
+    val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads)
+    val executionContext = ExecutionContext.fromExecutorService(fpool)
+
+    try {
+      // Add a listener to release the semaphore once jobs are launched.
+      val sem = new Semaphore(0)
+      val jobEnded = new AtomicInteger(0)
+      val jobProperties: ConcurrentHashMap[Int, java.util.Properties] = new ConcurrentHashMap()
+
+      sc.addSparkListener(new SparkListener {
+        override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+          jobProperties.put(jobStart.jobId, jobStart.properties)
+          sem.release()
+        }
+
+        override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+          sem.release()
+          jobEnded.incrementAndGet()
+        }
+      })
+
+      // Note: since tags are added in the Future threads, they don't need to be cleared in between.
+      val jobA = Future {
+        sessionA = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        assert(sessionA.getTags() == Set())
+        sessionA.addTag("two")
+        assert(sessionA.getTags() == Set("two"))
+        sessionA.clearTags() // check that clearing all tags works
+        assert(sessionA.getTags() == Set())
+        sessionA.addTag("one")
+        assert(sessionA.getTags() == Set("one"))
+        try {
+          sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionA.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobB = Future {
+        sessionB = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        assert(sessionB.getTags() == Set())
+        sessionB.addTag("one")
+        sessionB.addTag("two")
+        sessionB.addTag("one")
+        sessionB.addTag("two") // duplicates shouldn't matter
+        assert(sessionB.getTags() == Set("one", "two"))
+        try {
+          sessionB.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionB.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+      val jobC = Future {
+        sessionC = globalSession.cloneSession()
+        import globalSession.implicits._
+
+        sessionC.addTag("foo")
+        sessionC.removeTag("foo")
+        assert(sessionC.getTags() == Set()) // check that remove works removing the last tag
+        sessionC.addTag("boo")
+        try {
+          sessionC.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count()
+        } finally {
+          sessionC.clearTags() // clear for the case of thread reuse by another Future
+        }
+      }(executionContext)
+
+      // Block until four jobs have started.
+      assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES))
+
+      // Tags are applied
+      assert(jobProperties.size == 3)
+      for (ss <- Seq(sessionA, sessionB, sessionC)) {
+        val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS)
+          .asInstanceOf[String].contains(ss.sessionUUID))
+        assert(jobProperty.size == 1)
+        val tags = jobProperty.head.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String]
+          .split(SparkContext.SPARK_JOB_TAGS_SEP)
+
+        val executionRootIdTag = SQLExecution.executionIdJobTag(
+          ss,
+          jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong)
+        val userTagsPrefix = s"spark-session-${ss.sessionUUID}-"
+
+        ss match {
+          case s if s == sessionA => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one"))
+          case s if s == sessionB => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two"))
+          case s if s == sessionC => assert(tags.toSet == Set(
+            s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo"))
+        }
+      }
+
+      // Global session cancels nothing
+      assert(globalSession.interruptAll().isEmpty)
+      assert(globalSession.interruptTag("one").isEmpty)
+      assert(globalSession.interruptTag("two").isEmpty)
+      for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+        assert(globalSession.interruptOperation(i.toString).isEmpty)
+      }
+      assert(jobEnded.intValue == 0)
+
+      // One job cancelled
+      for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) {
+        sessionC.interruptOperation(i.toString)
+      }
+      val eC = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobC, 1.minute)
+      }.getCause
+      assert(eC.getMessage contains "cancelled")
+      assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+      assert(jobEnded.intValue == 1)
+
+      // Another job cancelled
+      assert(sessionA.interruptTag("one").size == 1)
+      val eA = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobA, 1.minute)
+      }.getCause
+      assert(eA.getMessage contains "cancelled job tags one")
+      assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+      assert(jobEnded.intValue == 2)
+
+      // The last job cancelled
+      sessionB.interruptAll()
+      val eB = intercept[SparkException] {
+        ThreadUtils.awaitResult(jobB, 1.minute)
+      }.getCause
+      assert(eB.getMessage contains "cancellation of all jobs")
+      assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES))
+      assert(jobEnded.intValue == 3)
+    } finally {
+      fpool.shutdownNow()
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index 94d33731b6de..059a4c9b8376 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -228,7 +228,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper {
       spark.range(1).collect()
 
       spark.sparkContext.listenerBus.waitUntilEmpty()
-      assert(jobTags.contains(jobTag))
+      assert(jobTags.get.contains(jobTag))
       assert(sqlJobTags.contains(jobTag))
     } finally {
       spark.sparkContext.removeJobTag(jobTag)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
