This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d569c2db833 [SPARK-44509][PYTHON][CONNECT] Add job cancellation API set in Spark Connect Python client
d569c2db833 is described below

commit d569c2db833b2d63b8a29bd7af202ee788232fd1
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jul 25 11:33:12 2023 +0900

    [SPARK-44509][PYTHON][CONNECT] Add job cancellation API set in Spark Connect Python client
    
    ### What changes were proposed in this pull request?
    
    This PR proposes the Python implementation of the job cancellation API set added in https://github.com/apache/spark/pull/42009.
    
    ### Why are the changes needed?
    
    For feature parity, and for better control of query cancellation in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. New APIs in the Spark Connect Python client:
    
    ```
    SparkSession.addTag
    SparkSession.removeTag
    SparkSession.getTags
    SparkSession.clearTags
    SparkSession.interruptTag
    SparkSession.interruptOperation
    ```
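    
    As a rough usage sketch (illustrative only, not part of this change), the new APIs could be combined as below; the session variable `spark` and the tag name `"long-query"` are hypothetical:
    
    ```
    # Hypothetical sketch: tag work in one thread, interrupt it from another
    # thread of the same Spark Connect session.
    spark.addTag("long-query")         # operations started by this thread carry the tag
    try:
        df = spark.range(10)           # stand-in for a long-running query
        df.toPandas()                  # this action is tagged with "long-query"
    finally:
        spark.removeTag("long-query")  # stop tagging new operations from this thread
    
    # From another thread:
    interrupted = spark.interruptTag("long-query")  # list of interrupted operation IDs
    ```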
    
    ### How was this patch tested?
    
    Unit tests were added, and the change was also tested manually.
    
    Closes #42120 from HyukjinKwon/SPARK-44509.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 5ee462d4cb60d4159abe36093f61ec0e5c749826)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   6 +-
 .../spark/sql/connect/service/ExecuteHolder.scala  |   6 +-
 .../source/reference/pyspark.sql/spark_session.rst |   7 ++
 python/pyspark/sql/connect/client/core.py          | 104 +++++++++++++++---
 python/pyspark/sql/connect/session.py              |  43 +++++++-
 python/pyspark/sql/session.py                      | 122 +++++++++++++++++++++
 python/pyspark/sql/tests/connect/test_session.py   | 111 ++++++++++++++++++-
 7 files changed, 377 insertions(+), 22 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 161b5a0217e..5a0f33ffd5d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -615,7 +615,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session currently running on the connected server.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -628,7 +628,7 @@ class SparkSession private[sql] (
    * Interrupt all operations of this session with the given operation tag.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
@@ -641,7 +641,7 @@ class SparkSession private[sql] (
    * Interrupt an operation of this session with the given operationId.
    *
    * @return
-   *   sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+   *   sequence of operationIds of interrupted operations. Note: there is still a possibility of
    *   operation finishing just as it is interrupted.
    *
    * @since 3.5.0
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 74530ad032f..36c96b2617f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -42,7 +42,7 @@ private[connect] class ExecuteHolder(
     s"SparkConnect_Execute_" +
       s"User_${sessionHolder.userId}_" +
       s"Session_${sessionHolder.sessionId}_" +
-      s"Request_${operationId}"
+      s"Operation_${operationId}"
 
   /**
    * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -118,7 +118,7 @@ private[connect] class ExecuteHolder(
    * need to be combined with userId and sessionId.
    */
   def tagToSparkJobTag(tag: String): String = {
-    "SparkConnect_Tag_" +
-      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
+    "SparkConnect_Execute_" +
+      s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
   }
 }
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 9867a9cd121..a5f8bc47d44 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -63,3 +63,10 @@ Spark Connect Only
     SparkSession.addArtifacts
     SparkSession.copyFromLocalToFs
     SparkSession.client
+    SparkSession.interruptAll
+    SparkSession.interruptTag
+    SparkSession.interruptOperation
+    SparkSession.addTag
+    SparkSession.removeTag
+    SparkSession.getTags
+    SparkSession.clearTags
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 56236892122..482482123c0 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -23,6 +23,7 @@ from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
+import threading
 import logging
 import os
 import platform
@@ -41,6 +42,7 @@ from typing import (
     List,
     Tuple,
     Dict,
+    Set,
     NoReturn,
     cast,
     Callable,
@@ -574,6 +576,8 @@ class SparkConnectClient(object):
            the $USER environment. Defining the user ID as part of the connection string
             takes precedence.
         """
+        self.thread_local = threading.local()
+
         # Parse the connection string.
         self._builder = (
             connection
@@ -922,9 +926,11 @@ class SparkConnectClient(object):
         return self._builder._token
 
     def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
-        req = pb2.ExecutePlanRequest()
-        req.session_id = self._session_id
-        req.client_type = self._builder.userAgent
+        req = pb2.ExecutePlanRequest(
+            session_id=self._session_id,
+            client_type=self._builder.userAgent,
+            tags=list(self.get_tags()),
+        )
         if self._user_id:
             req.user_context.user_id = self._user_id
         return req
@@ -1243,12 +1249,22 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
-    def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+    def _interrupt_request(
+        self, interrupt_type: str, id_or_tag: Optional[str] = None
+    ) -> pb2.InterruptRequest:
         req = pb2.InterruptRequest()
         req.session_id = self._session_id
         req.client_type = self._builder.userAgent
         if interrupt_type == "all":
            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+        elif interrupt_type == "tag":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG
+            req.operation_tag = id_or_tag
+        elif interrupt_type == "operation":
+            assert id_or_tag is not None
+            req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID
+            req.operation_id = id_or_tag
         else:
             raise PySparkValueError(
                 error_class="UNKNOWN_INTERRUPT_TYPE",
@@ -1260,14 +1276,7 @@ class SparkConnectClient(object):
             req.user_context.user_id = self._user_id
         return req
 
-    def interrupt_all(self) -> None:
-        """
-        Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
-
-        Returns
-        -------
-        None
-        """
+    def interrupt_all(self) -> Optional[List[str]]:
         req = self._interrupt_request("all")
         try:
             for attempt in Retrying(
@@ -1280,11 +1289,80 @@ class SparkConnectClient(object):
                             "Received incorrect session identifier for 
request:"
                             f"{resp.session_id} != {self._session_id}"
                         )
-                    return
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def interrupt_tag(self, tag: str) -> Optional[List[str]]:
+        req = self._interrupt_request("tag", tag)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for 
request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
         except Exception as error:
             self._handle_error(error)
 
+    def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
+        req = self._interrupt_request("operation", op_id)
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+            ):
+                with attempt:
+                    resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+                    if resp.session_id != self._session_id:
+                        raise SparkConnectException(
+                            "Received incorrect session identifier for 
request:"
+                            f"{resp.session_id} != {self._session_id}"
+                        )
+                    return list(resp.interrupted_ids)
+            raise SparkConnectException("Invalid state during retry exception handling.")
+        except Exception as error:
+            self._handle_error(error)
+
+    def add_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.add(tag)
+
+    def remove_tag(self, tag: str) -> None:
+        self._throw_if_invalid_tag(tag)
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        self.thread_local.tags.remove(tag)
+
+    def get_tags(self) -> Set[str]:
+        if not hasattr(self.thread_local, "tags"):
+            self.thread_local.tags = set()
+        return self.thread_local.tags
+
+    def clear_tags(self) -> None:
+        self.thread_local.tags = set()
+
+    def _throw_if_invalid_tag(self, tag: str) -> None:
+        """
+        Validate if a tag for ExecutePlanRequest.tags is valid. Throw ``ValueError`` if
+        not.
+        """
+        spark_job_tags_sep = ","
+        if tag is None:
+            raise ValueError("Spark Connect tag cannot be null.")
+        if spark_job_tags_sep in tag:
+            raise ValueError(f"Spark Connect tag cannot contain 
'{spark_job_tags_sep}'.")
+        if len(tag) == 0:
+            raise ValueError("Spark Connect tag cannot be an empty string.")
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a49e4cdd0f4..8cd39ba7a79 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@ from typing import (
     Dict,
     List,
     Tuple,
+    Set,
     cast,
     overload,
     Iterable,
@@ -550,8 +551,46 @@ class SparkSession:
         except Exception:
             pass
 
-    def interrupt_all(self) -> None:
-        self.client.interrupt_all()
+    def interruptAll(self) -> List[str]:
+        op_ids = self.client.interrupt_all()
+        assert op_ids is not None
+        return op_ids
+
+    interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
+
+    def interruptTag(self, tag: str) -> List[str]:
+        op_ids = self.client.interrupt_tag(tag)
+        assert op_ids is not None
+        return op_ids
+
+    interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        op_ids = self.client.interrupt_operation(op_id)
+        assert op_ids is not None
+        return op_ids
+
+    interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
+
+    def addTag(self, tag: str) -> None:
+        self.client.add_tag(tag)
+
+    addTag.__doc__ = PySparkSession.addTag.__doc__
+
+    def removeTag(self, tag: str) -> None:
+        self.client.remove_tag(tag)
+
+    removeTag.__doc__ = PySparkSession.removeTag.__doc__
+
+    def getTags(self) -> Set[str]:
+        return self.client.get_tags()
+
+    getTags.__doc__ = PySparkSession.getTags.__doc__
+
+    def clearTags(self) -> None:
+        return self.client.clear_tags()
+
+    clearTags.__doc__ = PySparkSession.clearTags.__doc__
 
     def stop(self) -> None:
        # Stopping the session will only close the connection to the current session (and
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00a0047dfd1..834b0307238 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -31,6 +31,7 @@ from typing import (
     Tuple,
     Type,
     Union,
+    Set,
     cast,
     no_type_check,
     overload,
@@ -1858,6 +1859,127 @@ class SparkSession(SparkConversionMixin):
             "however, the current Spark session does not use Spark Connect."
         )
 
+    def interruptAll(self) -> List[str]:
+        """
+        Interrupt all operations of this session currently running on the connected server.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptAll is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def interruptTag(self, tag: str) -> List[str]:
+        """
+        Interrupt all operations of this session with the given operation tag.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def interruptOperation(self, op_id: str) -> List[str]:
+        """
+        Interrupt an operation of this session with the given operationId.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        list of str
+            List of operationIds of interrupted operations.
+
+        Notes
+        -----
+        There is still a possibility of operation finishing just as it is interrupted.
+        """
+        raise RuntimeError(
+            "SparkSession.interruptOperation is only supported with Spark 
Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def addTag(self, tag: str) -> None:
+        """
+        Add a tag to be assigned to all the operations started by this thread in this session.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+        """
+        raise RuntimeError(
+            "SparkSession.addTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def removeTag(self, tag: str) -> None:
+        """
+        Remove a tag previously added to be assigned to all the operations started by this thread in
+        this session. Noop if such a tag was not added earlier.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+        """
+        raise RuntimeError(
+            "SparkSession.removeTag is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def getTags(self) -> Set[str]:
+        """
+        Get the tags that are currently set to be assigned to all the operations started by this
+        thread.
+
+        .. versionadded:: 3.5.0
+
+        Returns
+        -------
+        set of str
+            Set of tags currently set for operations started by this thread.
+        """
+        raise RuntimeError(
+            "SparkSession.getTags is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
+    def clearTags(self) -> None:
+        """
+        Clear the current thread's operation tags.
+
+        .. versionadded:: 3.5.0
+        """
+        raise RuntimeError(
+            "SparkSession.clearTags is only supported with Spark Connect; "
+            "however, the current Spark session does not use Spark Connect."
+        )
+
 
 def _test() -> None:
     import os
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 2f14eeddc1e..0482f119d63 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -14,12 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import threading
+import time
 import unittest
 from typing import Optional
 
 from pyspark.sql.connect.client import ChannelBuilder
 from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+from pyspark.testing.connectutils import should_test_connect
+
+if should_test_connect:
+    from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class CustomChannelBuilder(ChannelBuilder):
@@ -70,3 +75,107 @@ class SparkSessionTestCase(unittest.TestCase):
 
         self.assertIs(session, session2)
         session.stop()
+
+
+class ArrowParityTests(ReusedConnectTestCase):
+    def test_tags(self):
+        self.spark.clearTags()
+        self.spark.addTag("a")
+        self.assertEqual(self.spark.getTags(), {"a"})
+        self.spark.addTag("b")
+        self.spark.removeTag("a")
+        self.assertEqual(self.spark.getTags(), {"b"})
+        self.spark.addTag("c")
+        self.spark.clearTags()
+        self.assertEqual(self.spark.getTags(), set())
+        self.spark.clearTags()
+
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
+    def check_job_cancellation(
+        self, setter, canceller, thread_ids, thread_ids_to_cancel, thread_ids_to_run
+    ):
+
+        job_id_a = "job_ids_to_cancel"
+        job_id_b = "job_ids_to_run"
+        threads = []
+
+        # A list which records whether each job was cancelled.
+        # The index of the list is the index of the thread the job ran in.
+        is_job_cancelled = [False for _ in thread_ids]
+
+        def run_job(job_id, index):
+            """
+            Executes a job tagged with ``job_id``. Each job waits for 20 seconds
+            and then exits.
+            """
+            try:
+                setter(job_id)
+
+                def func(itr):
+                    for pdf in itr:
+                        time.sleep(pdf._1.iloc[0])
+                        yield pdf
+
+                self.spark.createDataFrame([[20]]).repartition(1).mapInPandas(
+                    func, schema="_1 LONG"
+                ).collect()
+                is_job_cancelled[index] = False
+            except Exception:
+                # Assume that exception means job cancellation.
+                is_job_cancelled[index] = True
+
+        # Test if job succeeded when not cancelled.
+        run_job(job_id_a, 0)
+        self.assertFalse(is_job_cancelled[0])
+        self.spark.clearTags()
+
+        # Run jobs
+        for i in thread_ids_to_cancel:
+            t = threading.Thread(target=run_job, args=(job_id_a, i))
+            t.start()
+            threads.append(t)
+
+        for i in thread_ids_to_run:
+            t = threading.Thread(target=run_job, args=(job_id_b, i))
+            t.start()
+            threads.append(t)
+
+        # Wait to make sure all jobs have started running.
+        time.sleep(10)
+        # And then, cancel one job group.
+        canceller(job_id_a)
+
+        # Wait until all threads launching jobs are finished.
+        for t in threads:
+            t.join()
+
+        for i in thread_ids_to_cancel:
+            self.assertTrue(
+                is_job_cancelled[i], "Thread {i}: Job in group A was not 
cancelled.".format(i=i)
+            )
+
+        for i in thread_ids_to_run:
+            self.assertFalse(
+                is_job_cancelled[i], "Thread {i}: Job in group B did not 
succeeded.".format(i=i)
+            )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
