This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 45f25dce182 [SPARK-43331][CONNECT] Add Spark Connect SparkSession.interruptAll
45f25dce182 is described below

commit 45f25dce18253a352cd7af4d91d4dcda743f5612
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Wed May 3 11:48:16 2023 -0400

    [SPARK-43331][CONNECT] Add Spark Connect SparkSession.interruptAll
    
    ### What changes were proposed in this pull request?
    
    Currently, queries that are run using Spark Connect cannot be interrupted. Even when the
    RPC connection is broken, the Spark jobs on the server continue running.
    
    This PR proposes a
    ```
    rpc Interrupt(InterruptRequest) returns (InterruptResponse) {}
    ```
    server RPC API that can be called from the client as `SparkSession.interruptAll()` to
    interrupt all actively running Spark Jobs started by ExecutePlan executions. In most user
    scenarios a SparkSession is not used for multiple concurrent executions but sequentially,
    so `interruptAll()` should cover a large share of user needs. It can also be used for
    cleanup.
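
    For illustration, a minimal client-side usage sketch (assuming `spark` is a Spark Connect
    `SparkSession` already connected to a running server; it mirrors the E2E test added below):
    ```
    import scala.concurrent.{ExecutionContext, Future}

    val session = spark
    import session.implicits._
    implicit val ec: ExecutionContext = ExecutionContext.global

    // Start a long-running action in the background...
    val query = Future {
      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
    }

    // ...and interrupt every execution currently running in this session from the foreground.
    spark.interruptAll()
    ```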
    
    To keep track of executions, we introduce `ExecutePlanHolder` to hold the execution state,
    and make `SessionHolder` keep track of the executions currently running in the session. In
    this first PR, the interrupt only interrupts running Spark Jobs. As such, it is best effort
    to a degree: it will not interrupt commands that don't run Spark Jobs, and it will not
    interrupt anything if no Spark Job is running when the interrupt is received by the server,
    and the command will con [...]
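
    Conceptually, the server-side mechanism boils down to tagging every ExecutePlan execution
    with a Spark job group and cancelling that group on interrupt. A simplified sketch of the
    idea (not the literal implementation; it assumes `userId`, `sessionId`, `operationId` and
    an active `session: SparkSession` are in scope, and mirrors the ExecutePlanHolder /
    SparkConnectStreamHandler changes below):
    ```
    // When handling ExecutePlan: tag all Spark jobs spawned by this execution.
    val jobGroupId = s"User_${userId}_Session_${sessionId}_Request_${operationId}"
    session.sparkContext.setJobGroup(jobGroupId, "Spark Connect execution", interruptOnCancel = true)

    // When handling Interrupt: cancel whatever jobs of that group are currently running.
    session.sparkContext.cancelJobGroup(jobGroupId)
    ```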
    
    Future work I plan to design and work on will involve:
    * Interrupting any execution. This will involve moving the execution off the gRPC handler
    thread that handles ExecutePlan and launching it in a separate thread that can be
    interrupted. `ExecutionHolder`
    * Interrupting executions selectively. This will involve exposing the 
operationId to the user.
    * (Refactor) Some cleanup is needed around the use of SessionHolder. Currently, sessionId,
    userId, and various other session data are passed around separately. A SessionHolder could
    be passed instead, and then also be extended with more useful APIs for managing the session.
    
    ### Why are the changes needed?
    
    APIs are needed to interrupt running queries in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    Users of Spark Connect can now call `interruptAll()` on the client `SparkSession` object
    to send an interrupt RPC to the server, which interrupts the running queries.
    
    In followup PRs, this will be extended to the Python client, and to interrupting more than
    just Spark Jobs.
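
    From the caller's point of view, an interrupted action fails with a gRPC CANCELLED status.
    A sketch of what the client observes, based on the E2E test added below (assumes ScalaTest's
    `intercept`, `session.implicits._` in scope, and `spark.interruptAll()` being issued
    concurrently from another thread):
    ```
    val e = intercept[io.grpc.StatusRuntimeException] {
      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
    }
    assert(e.getMessage.contains("cancelled"))
    ```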
    
    ### How was this patch tested?
    
    Added E2E tests to ClientE2ETestSuite for the Scala Connect client.
    Added unit tests to test_client for the Python Connect client.
    
    Closes #41005 from juliuszsompolski/connect-cancelall.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  13 +++
 .../sql/connect/client/SparkConnectClient.scala    |  11 +++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  70 +++++++++++++++
 .../src/main/protobuf/spark/connect/base.proto     |  34 +++++++
 .../sql/connect/service/ExecutePlanHolder.scala    |  41 +++++++++
 .../spark/sql/connect/service/SessionHolder.scala  |  63 +++++++++++++
 .../service/SparkConnectInterruptHandler.scala     |  45 ++++++++++
 .../sql/connect/service/SparkConnectService.scala  |  24 +++--
 .../service/SparkConnectStreamHandler.scala        |  47 ++++++----
 python/pyspark/errors/error_classes.py             |   5 ++
 python/pyspark/sql/connect/client.py               |  42 +++++++++
 python/pyspark/sql/connect/proto/base_pb2.py       |  37 +++++++-
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 100 +++++++++++++++++++++
 python/pyspark/sql/connect/proto/base_pb2_grpc.py  |  45 ++++++++++
 python/pyspark/sql/connect/session.py              |   3 +
 python/pyspark/sql/tests/connect/test_client.py    |  14 +++
 16 files changed, 568 insertions(+), 26 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 461b18ec9c1..48640878211 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -536,6 +536,19 @@ class SparkSession private[sql] (
     planIdGenerator.set(0)
   }
 
+  /**
+   * Interrupt all operations of this session currently running on the connected server.
+   *
+   * TODO/WIP: Currently it will interrupt the Spark Jobs running on the server, triggered from
+   * ExecutePlan requests. If an operation is not running a Spark Job, it becomes a no-op and the
+   * operation will continue afterwards, possibly with more Spark Jobs.
+   *
+   * @since 3.5.0
+   */
+  def interruptAll(): Unit = {
+    client.interruptAll()
+  }
+
   /**
    * Synonym for `close()`.
    *
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 1d47f3e663f..ca00eb74a20 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -184,6 +184,17 @@ private[sql] class SparkConnectClient(
     analyze(request)
   }
 
+  private[sql] def interruptAll(): proto.InterruptResponse = {
+    val builder = proto.InterruptRequest.newBuilder()
+    val request = builder
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setClientType(userAgent)
+      
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
+      .build()
+    stub.interrupt(request)
+  }
+
   def copy(): SparkConnectClient = {
     new SparkConnectClient(userContext, channelBuilder, userAgent)
   }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index abeeaf7e483..8a01f828ef0 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -20,6 +20,9 @@ import java.io.{ByteArrayOutputStream, PrintStream}
 import java.nio.file.Files
 
 import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.util.{Failure, Success}
 
 import io.grpc.StatusRuntimeException
 import java.util.Properties
@@ -27,6 +30,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.commons.io.output.TeeOutputStream
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
+import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.SPARK_VERSION
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -36,6 +40,7 @@ import 
org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSpa
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.util.ThreadUtils
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
 
@@ -867,6 +872,71 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper {
       assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
     }
   }
+
+  test("interrupt all - background queries, foreground interrupt") {
+    val session = spark
+    import session.implicits._
+    implicit val ec = ExecutionContext.global
+    val q1 = Future {
+      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
+    }
+    val q2 = Future {
+      spark.range(10).map(n => { Thread.sleep(30000); n }).collect()
+    }
+    var q1Interrupted = false
+    var q2Interrupted = false
+    var error: Option[String] = None
+    q1.onComplete {
+      case Success(_) =>
+        error = Some("q1 shouldn't have finished!")
+      case Failure(t) if t.getMessage.contains("cancelled") =>
+        q1Interrupted = true
+      case Failure(t) =>
+        error = Some("unexpected failure in q1: " + t.toString)
+    }
+    q2.onComplete {
+      case Success(_) =>
+        error = Some("q2 shouldn't have finished!")
+      case Failure(t) if t.getMessage.contains("cancelled") =>
+        q2Interrupted = true
+      case Failure(t) =>
+        error = Some("unexpected failure in q2: " + t.toString)
+    }
+    // 20 seconds is less than the 30 seconds the queries would otherwise run,
+    // because they should be interrupted sooner.
+    eventually(timeout(20.seconds), interval(1.seconds)) {
+      // keep interrupting every second, until both queries get interrupted.
+      spark.interruptAll()
+      assert(q1Interrupted)
+      assert(q2Interrupted)
+      assert(error.isEmpty)
+    }
+  }
+
+  test("interrupt all - foreground queries, background interrupt") {
+    val session = spark
+    import session.implicits._
+    implicit val ec = ExecutionContext.global
+
+    @volatile var finished = false
+    val interruptor = Future {
+      eventually(timeout(20.seconds), interval(1.seconds)) {
+        spark.interruptAll()
+        assert(finished)
+      }
+      finished
+    }
+    val e1 = intercept[io.grpc.StatusRuntimeException] {
+      spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n 
}).collect()
+    }
+    assert(e1.getMessage.contains("cancelled"))
+    val e2 = intercept[io.grpc.StatusRuntimeException] {
+      spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n 
}).collect()
+    }
+    assert(e2.getMessage.contains("cancelled"))
+    finished = true
+    assert(ThreadUtils.awaitResult(interruptor, 10.seconds) == true)
+  }
 }
 
 private[sql] case class MyType(id: Long, a: Double, b: Double)
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 1f572178404..b304a8aea63 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -582,6 +582,37 @@ message ArtifactStatusesResponse {
   map<string, ArtifactStatus> statuses = 1;
 }
 
+message InterruptRequest {
+  // (Required)
+  //
+  // The session_id specifies a spark session for a user id (which is specified
+  // by user_context.user_id). The session_id is set by the client to be able 
to
+  // collate streaming responses from different queries within the dedicated 
session.
+  string session_id = 1;
+
+  // (Required) User context
+  UserContext user_context = 2;
+
+  // Provides optional information about the client sending the request. This 
field
+  // can be used for language or version specific information and is only 
intended for
+  // logging purposes and will not be interpreted by the server.
+  optional string client_type = 3;
+
+  // (Required) The type of interrupt to execute.
+  InterruptType interrupt_type = 4;
+
+  enum InterruptType {
+    INTERRUPT_TYPE_UNSPECIFIED = 0;
+
+    // Interrupt all running executions within session with provided 
session_id.
+    INTERRUPT_TYPE_ALL = 1;
+  }
+}
+
+message InterruptResponse {
+  string session_id = 1;
+}
+
 // Main interface for the SparkConnect service.
 service SparkConnectService {
 
@@ -602,5 +633,8 @@ service SparkConnectService {
 
   // Check statuses of artifacts in the session and returns them in a 
[[ArtifactStatusesResponse]]
   rpc ArtifactStatus(ArtifactStatusesRequest) returns 
(ArtifactStatusesResponse) {}
+
+  // Interrupts running executions
+  rpc Interrupt(InterruptRequest) returns (InterruptResponse) {}
 }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
new file mode 100644
index 00000000000..a3c17b9826e
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecutePlanHolder.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import org.apache.spark.connect.proto
+
+/**
+ * Object used to hold the Spark Connect execution state.
+ */
+case class ExecutePlanHolder(
+    operationId: String,
+    sessionHolder: SessionHolder,
+    request: proto.ExecutePlanRequest) {
+
+  val jobGroupId =
+    
s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Request_${operationId}"
+
+  def interrupt(): Unit = {
+    // TODO/WIP: This only interrupts Spark jobs that are actively running.
+    // The cancellation then surfaces as an error from ExecutePlan and terminates it.
+    // But if the query is not running a Spark job and is instead executing code on the Spark
+    // driver, this is a no-op and the execution will keep running.
+    sessionHolder.session.sparkContext.cancelJobGroup(jobGroupId)
+  }
+
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
new file mode 100644
index 00000000000..613d7a38e9e
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+
+import collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Object used to hold the Spark Connect session state.
+ */
+case class SessionHolder(userId: String, sessionId: String, session: 
SparkSession)
+    extends Logging {
+
+  val executePlanOperations: ConcurrentMap[String, ExecutePlanHolder] =
+    new ConcurrentHashMap[String, ExecutePlanHolder]()
+
+  private[connect] def createExecutePlanHolder(
+      request: proto.ExecutePlanRequest): ExecutePlanHolder = {
+
+    val operationId = UUID.randomUUID().toString
+    val executePlanHolder = ExecutePlanHolder(operationId, this, request)
+    assert(executePlanOperations.putIfAbsent(operationId, executePlanHolder) 
== null)
+    executePlanHolder
+  }
+
+  private[connect] def removeExecutePlanHolder(operationId: String): Unit = {
+    executePlanOperations.remove(operationId)
+  }
+
+  private[connect] def interruptAll(): Unit = {
+    executePlanOperations.asScala.values.foreach { execute =>
+      // Eat exception while trying to interrupt a given execution and move 
forward.
+      try {
+        execute.interrupt()
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Exception $e while trying to interrupt execution 
${execute.operationId}")
+      }
+    }
+  }
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
new file mode 100644
index 00000000000..b0923e277e4
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+
+class SparkConnectInterruptHandler(responseObserver: 
StreamObserver[proto.InterruptResponse])
+    extends Logging {
+
+  def handle(v: proto.InterruptRequest): Unit = {
+    val sessionHolder =
+      SparkConnectService
+        .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+
+    v.getInterruptType match {
+      case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL =>
+        sessionHolder.interruptAll()
+      case other =>
+        throw new UnsupportedOperationException(s"Unknown InterruptType 
$other!")
+    }
+
+    val builder = 
proto.InterruptResponse.newBuilder().setSessionId(v.getSessionId)
+
+    responseObserver.onNext(builder.build())
+    responseObserver.onCompleted()
+  }
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index bfe1512a49e..b444fc67ce1 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -245,15 +245,23 @@ class SparkConnectService(debug: Boolean)
         userId = request.getUserContext.getUserId,
         sessionId = request.getSessionId)
   }
-}
 
-/**
- * Object used for referring to SparkSessions in the SessionCache.
- *
- * @param userId
- * @param session
- */
-case class SessionHolder(userId: String, sessionId: String, session: 
SparkSession)
+  /**
+   * This is the entry point for calls interrupting running executions.
+   */
+  override def interrupt(
+      request: proto.InterruptRequest,
+      responseObserver: StreamObserver[proto.InterruptResponse]): Unit = {
+    try {
+      new SparkConnectInterruptHandler(responseObserver).handle(request)
+    } catch
+      handleError(
+        "interrupt",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+  }
+}
 
 /**
  * Static instance of the SparkConnectService.
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index c544f484381..062ef892979 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
@@ -47,18 +48,30 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
     extends Logging {
 
   def handle(v: ExecutePlanRequest): Unit = 
SparkConnectArtifactManager.withArtifactClassLoader {
-    val session =
-      SparkConnectService
-        .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
-        .session
-    session.withActive {
+    val sessionHolder = SparkConnectService
+      .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+    val session = sessionHolder.session
 
-      // Add debug information to the query execution so that the jobs are 
traceable.
-      try {
-        val debugString =
+    session.withActive {
+      val debugString =
+        try {
           Utils.redact(
             session.sessionState.conf.stringRedactionPattern,
             ProtoUtils.abbreviate(v).toString)
+        } catch {
+          case NonFatal(e) =>
+            logWarning("Fail to extract debug information", e)
+            "UNKNOWN"
+        }
+
+      val executeHolder = sessionHolder.createExecutePlanHolder(v)
+      session.sparkContext.setJobGroup(
+        executeHolder.jobGroupId,
+        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}",
+        interruptOnCancel = true)
+
+      try {
+        // Add debug information to the query execution so that the jobs are 
traceable.
         session.sparkContext.setLocalProperty(
           "callSite.short",
           s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
@@ -66,15 +79,19 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
           "callSite.long",
           StringUtils.abbreviate(debugString, 2048))
       } catch {
-        case e: Throwable =>
-          logWarning("Fail to extract or attach the debug information", e)
+        case NonFatal(e) =>
+          logWarning("Fail to attach the debug information", e)
       }
 
-      v.getPlan.getOpTypeCase match {
-        case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
-        case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
-        case _ =>
-          throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} 
not supported.")
+      try {
+        v.getPlan.getOpTypeCase match {
+          case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
+          case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+          case _ =>
+            throw new 
UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
+        }
+      } finally {
+        sessionHolder.removeExecutePlanHolder(executeHolder.operationId)
       }
     }
   }
diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index 2f559b6a538..3f52a14a607 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -566,6 +566,11 @@ ERROR_CLASSES_JSON = """
       "Unknown explain mode: '<explain_mode>'. Accepted explain modes are 
'simple', 'extended', 'codegen', 'cost', 'formatted'."
     ]
   },
+  "UNKNOWN_INTERRUPT_TYPE" : {
+    "message" : [
+      "Unknown interrupt type: '<interrupt_type>'. Accepted interrupt types 
are 'all'."
+    ]
+  },
   "UNKNOWN_RESPONSE" : {
     "message" : [
       "Unknown response: <response>."
diff --git a/python/pyspark/sql/connect/client.py 
b/python/pyspark/sql/connect/client.py
index 954ec0dfd04..ffcc4f768ae 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -1099,6 +1099,48 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
+    def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+        req = pb2.InterruptRequest()
+        req.session_id = self._session_id
+        req.client_type = self._builder.userAgent
+        if interrupt_type == "all":
+            req.interrupt_type = 
pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+        else:
+            raise PySparkValueError(
+                error_class="UNKNOWN_INTERRUPT_TYPE",
+                message_parameters={
+                    "interrupt_type": str(interrupt_type),
+                },
+            )
+        if self._user_id:
+            req.user_context.user_id = self._user_id
+        return req
+
+    def interrupt_all(self) -> None:
+        """
+        Call the interrupt RPC of Spark Connect to interrupt all executions in 
this session.
+
+        Returns
+        -------
+        None
+        """
+        req = self._interrupt_request("all")
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, 
**self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, 
metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for 
request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return
+            raise SparkConnectException("Invalid state during retry exception 
handling.")
+        except Exception as error:
+            self._handle_error(error)
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py 
b/python/pyspark/sql/connect/proto/base_pb2.py
index b6d86b3efbd..997f59911da 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -38,7 +38,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
+    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
 )
 
 
@@ -121,9 +121,12 @@ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS = 
_ARTIFACTSTATUSESRESPONSE.nested_type
 _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY = 
_ARTIFACTSTATUSESRESPONSE.nested_types_by_name[
     "StatusesEntry"
 ]
+_INTERRUPTREQUEST = DESCRIPTOR.message_types_by_name["InterruptRequest"]
+_INTERRUPTRESPONSE = DESCRIPTOR.message_types_by_name["InterruptResponse"]
 _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE = 
_ANALYZEPLANREQUEST_EXPLAIN.enum_types_by_name[
     "ExplainMode"
 ]
+_INTERRUPTREQUEST_INTERRUPTTYPE = 
_INTERRUPTREQUEST.enum_types_by_name["InterruptType"]
 Plan = _reflection.GeneratedProtocolMessageType(
     "Plan",
     (_message.Message,),
@@ -747,6 +750,28 @@ _sym_db.RegisterMessage(ArtifactStatusesResponse)
 _sym_db.RegisterMessage(ArtifactStatusesResponse.ArtifactStatus)
 _sym_db.RegisterMessage(ArtifactStatusesResponse.StatusesEntry)
 
+InterruptRequest = _reflection.GeneratedProtocolMessageType(
+    "InterruptRequest",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _INTERRUPTREQUEST,
+        "__module__": "spark.connect.base_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.InterruptRequest)
+    },
+)
+_sym_db.RegisterMessage(InterruptRequest)
+
+InterruptResponse = _reflection.GeneratedProtocolMessageType(
+    "InterruptResponse",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _INTERRUPTRESPONSE,
+        "__module__": "spark.connect.base_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.InterruptResponse)
+    },
+)
+_sym_db.RegisterMessage(InterruptResponse)
+
 _SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"]
 if _descriptor._USE_C_DESCRIPTORS == False:
 
@@ -880,6 +905,12 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9194
     _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9196
     _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9311
-    _SPARKCONNECTSERVICE._serialized_start = 9314
-    _SPARKCONNECTSERVICE._serialized_end = 9780
+    _INTERRUPTREQUEST._serialized_start = 9314
+    _INTERRUPTREQUEST._serialized_end = 9639
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 9552
+    _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 9623
+    _INTERRUPTRESPONSE._serialized_start = 9641
+    _INTERRUPTRESPONSE._serialized_end = 9691
+    _SPARKCONNECTSERVICE._serialized_start = 9694
+    _SPARKCONNECTSERVICE._serialized_end = 10242
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi 
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index df1d3055a5a..fb1bf92c3c3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2149,3 +2149,103 @@ class 
ArtifactStatusesResponse(google.protobuf.message.Message):
     ) -> None: ...
 
 global___ArtifactStatusesResponse = ArtifactStatusesResponse
+
+class InterruptRequest(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class _InterruptType:
+        ValueType = typing.NewType("ValueType", builtins.int)
+        V: typing_extensions.TypeAlias = ValueType
+
+    class _InterruptTypeEnumTypeWrapper(
+        google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+            InterruptRequest._InterruptType.ValueType
+        ],
+        builtins.type,
+    ):  # noqa: F821
+        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+        INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest._InterruptType.ValueType  
# 0
+        INTERRUPT_TYPE_ALL: InterruptRequest._InterruptType.ValueType  # 1
+        """Interrupt all running executions within session with provided 
session_id."""
+
+    class InterruptType(_InterruptType, 
metaclass=_InterruptTypeEnumTypeWrapper): ...
+    INTERRUPT_TYPE_UNSPECIFIED: InterruptRequest.InterruptType.ValueType  # 0
+    INTERRUPT_TYPE_ALL: InterruptRequest.InterruptType.ValueType  # 1
+    """Interrupt all running executions within session with provided 
session_id."""
+
+    SESSION_ID_FIELD_NUMBER: builtins.int
+    USER_CONTEXT_FIELD_NUMBER: builtins.int
+    CLIENT_TYPE_FIELD_NUMBER: builtins.int
+    INTERRUPT_TYPE_FIELD_NUMBER: builtins.int
+    session_id: builtins.str
+    """(Required)
+
+    The session_id specifies a spark session for a user id (which is specified
+    by user_context.user_id). The session_id is set by the client to be able to
+    collate streaming responses from different queries within the dedicated 
session.
+    """
+    @property
+    def user_context(self) -> global___UserContext:
+        """(Required) User context"""
+    client_type: builtins.str
+    """Provides optional information about the client sending the request. 
This field
+    can be used for language or version specific information and is only 
intended for
+    logging purposes and will not be interpreted by the server.
+    """
+    interrupt_type: global___InterruptRequest.InterruptType.ValueType
+    """(Required) The type of interrupt to execute."""
+    def __init__(
+        self,
+        *,
+        session_id: builtins.str = ...,
+        user_context: global___UserContext | None = ...,
+        client_type: builtins.str | None = ...,
+        interrupt_type: global___InterruptRequest.InterruptType.ValueType = 
...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "user_context",
+            b"user_context",
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "interrupt_type",
+            b"interrupt_type",
+            "session_id",
+            b"session_id",
+            "user_context",
+            b"user_context",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_client_type", 
b"_client_type"]
+    ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___InterruptRequest = InterruptRequest
+
+class InterruptResponse(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SESSION_ID_FIELD_NUMBER: builtins.int
+    session_id: builtins.str
+    def __init__(
+        self,
+        *,
+        session_id: builtins.str = ...,
+    ) -> None: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["session_id", 
b"session_id"]
+    ) -> None: ...
+
+global___InterruptResponse = InterruptResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py 
b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index ecbe4f9c389..74eda7dfeac 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -55,6 +55,11 @@ class SparkConnectServiceStub(object):
             
request_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.SerializeToString,
             
response_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.FromString,
         )
+        self.Interrupt = channel.unary_unary(
+            "/spark.connect.SparkConnectService/Interrupt",
+            
request_serializer=spark_dot_connect_dot_base__pb2.InterruptRequest.SerializeToString,
+            
response_deserializer=spark_dot_connect_dot_base__pb2.InterruptResponse.FromString,
+        )
 
 
 class SparkConnectServiceServicer(object):
@@ -95,6 +100,12 @@ class SparkConnectServiceServicer(object):
         context.set_details("Method not implemented!")
         raise NotImplementedError("Method not implemented!")
 
+    def Interrupt(self, request, context):
+        """Interrupts running executions"""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method not implemented!")
+        raise NotImplementedError("Method not implemented!")
+
 
 def add_SparkConnectServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -123,6 +134,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, 
server):
             
request_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.FromString,
             
response_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.SerializeToString,
         ),
+        "Interrupt": grpc.unary_unary_rpc_method_handler(
+            servicer.Interrupt,
+            
request_deserializer=spark_dot_connect_dot_base__pb2.InterruptRequest.FromString,
+            
response_serializer=spark_dot_connect_dot_base__pb2.InterruptResponse.SerializeToString,
+        ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
         "spark.connect.SparkConnectService", rpc_method_handlers
@@ -278,3 +294,32 @@ class SparkConnectService(object):
             timeout,
             metadata,
         )
+
+    @staticmethod
+    def Interrupt(
+        request,
+        target,
+        options=(),
+        channel_credentials=None,
+        call_credentials=None,
+        insecure=False,
+        compression=None,
+        wait_for_ready=None,
+        timeout=None,
+        metadata=None,
+    ):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            "/spark.connect.SparkConnectService/Interrupt",
+            spark_dot_connect_dot_base__pb2.InterruptRequest.SerializeToString,
+            spark_dot_connect_dot_base__pb2.InterruptResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+        )
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 5aecb0d3821..fde861d12b9 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,6 +489,9 @@ class SparkSession:
         except Exception:
             pass
 
+    def interrupt_all(self) -> None:
+        self.client.interrupt_all()
+
     def stop(self) -> None:
         # Stopping the session will only close the connection to the current 
session (and
         # the life cycle of the session is maintained by the server),
diff --git a/python/pyspark/sql/tests/connect/test_client.py 
b/python/pyspark/sql/tests/connect/test_client.py
index d72a541a675..191a5204bf3 100644
--- a/python/pyspark/sql/tests/connect/test_client.py
+++ b/python/pyspark/sql/tests/connect/test_client.py
@@ -68,6 +68,14 @@ class SparkConnectClientTestCase(unittest.TestCase):
 
         self.assertEqual(client._user_id, "abc")
 
+    def test_interrupt_all(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        client.interrupt_all()
+        self.assertIsNotNone(mock.req, "Interrupt API was not called when 
expected")
+
 
 class MockService:
     # Simplest mock of the SparkConnectService.
@@ -97,6 +105,12 @@ class MockService:
         resp.arrow_batch.data = buf.to_pybytes()
         return [resp]
 
+    def Interrupt(self, req: proto.InterruptRequest, metadata):
+        self.req = req
+        resp = proto.InterruptResponse()
+        resp.session_id = self._session_id
+        return resp
+
 
 if __name__ == "__main__":
     unittest.main()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

