This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new d569c2db833 [SPARK-44509][PYTHON][CONNECT] Add job cancellation API set in Spark Connect Python client
d569c2db833 is described below
commit d569c2db833b2d63b8a29bd7af202ee788232fd1
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jul 25 11:33:12 2023 +0900
[SPARK-44509][PYTHON][CONNECT] Add job cancellation API set in Spark Connect Python client
### What changes were proposed in this pull request?
This PR proposes the Python client implementation of the job cancellation API added in https://github.com/apache/spark/pull/42009.
### Why are the changes needed?
For feature parity, and for better control of query cancellation in Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes. New APIs in the Spark Connect Python client:
```
SparkSession.addTag
SparkSession.removeTag
SparkSession.getTags
SparkSession.clearTags
SparkSession.interruptTag
SparkSession.interruptOperation
```
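For illustration only (not part of this patch), a minimal sketch of how these APIs might be combined; the connection string and tag names are placeholders:
```
from pyspark.sql import SparkSession

# Placeholder connection string; assumes a reachable Spark Connect server.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Tags are tracked per client thread and attached to every operation
# started by that thread.
spark.addTag("nightly-job")
spark.addTag("team-data")
print(spark.getTags())         # {'nightly-job', 'team-data'}
spark.removeTag("team-data")
spark.clearTags()              # drop all tags set by this thread

# Interrupt everything currently running in this session; the return value
# is the list of operation IDs that were asked to stop (an operation may
# still finish just as it is interrupted).
op_ids = spark.interruptAll()

# Or target one specific operation, or everything carrying a given tag.
for op_id in op_ids:
    spark.interruptOperation(op_id)
spark.interruptTag("nightly-job")
```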
### How was this patch tested?
Unit tests were added, and the change was also tested manually.
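As a rough sketch of the manual check (mirroring the new unit tests; the sleep durations and helper names here are illustrative, and `spark` is a Spark Connect session as above): tags are thread-local on the client, so the tag is added in the worker thread while the interrupt is issued from another thread.
```
import threading
import time

def run_tagged_query(spark):
    # Tags are thread-local, so the tag must be added in the thread that runs the query.
    spark.addTag("cancel-me")
    try:
        # mapInPandas task that sleeps long enough to be interrupted.
        def slow(batches):
            for pdf in batches:
                time.sleep(pdf._1.iloc[0])
                yield pdf

        spark.createDataFrame([[20]]).repartition(1).mapInPandas(
            slow, schema="_1 LONG"
        ).collect()
    except Exception:
        print("query was interrupted")

t = threading.Thread(target=run_tagged_query, args=(spark,))
t.start()
time.sleep(5)                     # let the query start on the server
spark.interruptTag("cancel-me")   # interrupts operations carrying the tag, whichever thread started them
t.join()
```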
Closes #42120 from HyukjinKwon/SPARK-44509.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 5ee462d4cb60d4159abe36093f61ec0e5c749826)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 6 +-
.../spark/sql/connect/service/ExecuteHolder.scala | 6 +-
.../source/reference/pyspark.sql/spark_session.rst | 7 ++
python/pyspark/sql/connect/client/core.py | 104 +++++++++++++++---
python/pyspark/sql/connect/session.py | 43 +++++++-
python/pyspark/sql/session.py | 122 +++++++++++++++++++++
python/pyspark/sql/tests/connect/test_session.py | 111 ++++++++++++++++++-
7 files changed, 377 insertions(+), 22 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 161b5a0217e..5a0f33ffd5d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -615,7 +615,7 @@ class SparkSession private[sql] (
* Interrupt all operations of this session currently running on the connected server.
*
* @return
- * sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+ * sequence of operationIds of interrupted operations. Note: there is still a possibility of
* operation finishing just as it is interrupted.
*
* @since 3.5.0
@@ -628,7 +628,7 @@ class SparkSession private[sql] (
* Interrupt all operations of this session with the given operation tag.
*
* @return
- * sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+ * sequence of operationIds of interrupted operations. Note: there is still a possibility of
* operation finishing just as it is interrupted.
*
* @since 3.5.0
@@ -641,7 +641,7 @@ class SparkSession private[sql] (
* Interrupt an operation of this session with the given operationId.
*
* @return
- * sequence of operationIds of interrupted operations. Note: there is still a possiblility of
+ * sequence of operationIds of interrupted operations. Note: there is still a possibility of
* operation finishing just as it is interrupted.
*
* @since 3.5.0
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 74530ad032f..36c96b2617f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -42,7 +42,7 @@ private[connect] class ExecuteHolder(
s"SparkConnect_Execute_" +
s"User_${sessionHolder.userId}_" +
s"Session_${sessionHolder.sessionId}_" +
- s"Request_${operationId}"
+ s"Operation_${operationId}"
/**
* Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -118,7 +118,7 @@ private[connect] class ExecuteHolder(
* need to be combined with userId and sessionId.
*/
def tagToSparkJobTag(tag: String): String = {
- "SparkConnect_Tag_" +
- s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}"
+ "SparkConnect_Execute_" +
+ s"User_${sessionHolder.userId}_Session_${sessionHolder.sessionId}_Tag_${tag}"
}
}
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 9867a9cd121..a5f8bc47d44 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -63,3 +63,10 @@ Spark Connect Only
SparkSession.addArtifacts
SparkSession.copyFromLocalToFs
SparkSession.client
+ SparkSession.interruptAll
+ SparkSession.interruptTag
+ SparkSession.interruptOperation
+ SparkSession.addTag
+ SparkSession.removeTag
+ SparkSession.getTags
+ SparkSession.clearTags
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 56236892122..482482123c0 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -23,6 +23,7 @@ from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
+import threading
import logging
import os
import platform
@@ -41,6 +42,7 @@ from typing import (
List,
Tuple,
Dict,
+ Set,
NoReturn,
cast,
Callable,
@@ -574,6 +576,8 @@ class SparkConnectClient(object):
the $USER environment. Defining the user ID as part of the connection string
takes precedence.
"""
+ self.thread_local = threading.local()
+
# Parse the connection string.
self._builder = (
connection
@@ -922,9 +926,11 @@ class SparkConnectClient(object):
return self._builder._token
def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
- req = pb2.ExecutePlanRequest()
- req.session_id = self._session_id
- req.client_type = self._builder.userAgent
+ req = pb2.ExecutePlanRequest(
+ session_id=self._session_id,
+ client_type=self._builder.userAgent,
+ tags=list(self.get_tags()),
+ )
if self._user_id:
req.user_context.user_id = self._user_id
return req
@@ -1243,12 +1249,22 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
- def _interrupt_request(self, interrupt_type: str) -> pb2.InterruptRequest:
+ def _interrupt_request(
+ self, interrupt_type: str, id_or_tag: Optional[str] = None
+ ) -> pb2.InterruptRequest:
req = pb2.InterruptRequest()
req.session_id = self._session_id
req.client_type = self._builder.userAgent
if interrupt_type == "all":
req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL
+ elif interrupt_type == "tag":
+ assert id_or_tag is not None
+ req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG
+ req.operation_tag = id_or_tag
+ elif interrupt_type == "operation":
+ assert id_or_tag is not None
+ req.interrupt_type = pb2.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID
+ req.operation_id = id_or_tag
else:
raise PySparkValueError(
error_class="UNKNOWN_INTERRUPT_TYPE",
@@ -1260,14 +1276,7 @@ class SparkConnectClient(object):
req.user_context.user_id = self._user_id
return req
- def interrupt_all(self) -> None:
- """
- Call the interrupt RPC of Spark Connect to interrupt all executions in this session.
-
- Returns
- -------
- None
- """
+ def interrupt_all(self) -> Optional[List[str]]:
req = self._interrupt_request("all")
try:
for attempt in Retrying(
@@ -1280,11 +1289,80 @@ class SparkConnectClient(object):
"Received incorrect session identifier for request:"
f"{resp.session_id} != {self._session_id}"
)
- return
+ return list(resp.interrupted_ids)
+ raise SparkConnectException("Invalid state during retry exception handling.")
+ except Exception as error:
+ self._handle_error(error)
+
+ def interrupt_tag(self, tag: str) -> Optional[List[str]]:
+ req = self._interrupt_request("tag", tag)
+ try:
+ for attempt in Retrying(
+ can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+ ):
+ with attempt:
+ resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+ if resp.session_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for request:"
+ f"{resp.session_id} != {self._session_id}"
+ )
+ return list(resp.interrupted_ids)
raise SparkConnectException("Invalid state during retry exception handling.")
except Exception as error:
self._handle_error(error)
+ def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
+ req = self._interrupt_request("operation", op_id)
+ try:
+ for attempt in Retrying(
+ can_retry=SparkConnectClient.retry_exception, **self._retry_policy
+ ):
+ with attempt:
+ resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
+ if resp.session_id != self._session_id:
+ raise SparkConnectException(
+ "Received incorrect session identifier for request:"
+ f"{resp.session_id} != {self._session_id}"
+ )
+ return list(resp.interrupted_ids)
+ raise SparkConnectException("Invalid state during retry exception handling.")
+ except Exception as error:
+ self._handle_error(error)
+
+ def add_tag(self, tag: str) -> None:
+ self._throw_if_invalid_tag(tag)
+ if not hasattr(self.thread_local, "tags"):
+ self.thread_local.tags = set()
+ self.thread_local.tags.add(tag)
+
+ def remove_tag(self, tag: str) -> None:
+ self._throw_if_invalid_tag(tag)
+ if not hasattr(self.thread_local, "tags"):
+ self.thread_local.tags = set()
+ self.thread_local.tags.remove(tag)
+
+ def get_tags(self) -> Set[str]:
+ if not hasattr(self.thread_local, "tags"):
+ self.thread_local.tags = set()
+ return self.thread_local.tags
+
+ def clear_tags(self) -> None:
+ self.thread_local.tags = set()
+
+ def _throw_if_invalid_tag(self, tag: str) -> None:
+ """
+ Validate if a tag for ExecutePlanRequest.tags is valid. Throw ``ValueError`` if
+ not.
+ """
+ spark_job_tags_sep = ","
+ if tag is None:
+ raise ValueError("Spark Connect tag cannot be null.")
+ if spark_job_tags_sep in tag:
+ raise ValueError(f"Spark Connect tag cannot contain '{spark_job_tags_sep}'.")
+ if len(tag) == 0:
+ raise ValueError("Spark Connect tag cannot be an empty string.")
+
def _handle_error(self, error: Exception) -> NoReturn:
"""
Handle errors that occur during RPC calls.
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a49e4cdd0f4..8cd39ba7a79 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -31,6 +31,7 @@ from typing import (
Dict,
List,
Tuple,
+ Set,
cast,
overload,
Iterable,
@@ -550,8 +551,46 @@ class SparkSession:
except Exception:
pass
- def interrupt_all(self) -> None:
- self.client.interrupt_all()
+ def interruptAll(self) -> List[str]:
+ op_ids = self.client.interrupt_all()
+ assert op_ids is not None
+ return op_ids
+
+ interruptAll.__doc__ = PySparkSession.interruptAll.__doc__
+
+ def interruptTag(self, tag: str) -> List[str]:
+ op_ids = self.client.interrupt_tag(tag)
+ assert op_ids is not None
+ return op_ids
+
+ interruptTag.__doc__ = PySparkSession.interruptTag.__doc__
+
+ def interruptOperation(self, op_id: str) -> List[str]:
+ op_ids = self.client.interrupt_operation(op_id)
+ assert op_ids is not None
+ return op_ids
+
+ interruptOperation.__doc__ = PySparkSession.interruptOperation.__doc__
+
+ def addTag(self, tag: str) -> None:
+ self.client.add_tag(tag)
+
+ addTag.__doc__ = PySparkSession.addTag.__doc__
+
+ def removeTag(self, tag: str) -> None:
+ self.client.remove_tag(tag)
+
+ removeTag.__doc__ = PySparkSession.removeTag.__doc__
+
+ def getTags(self) -> Set[str]:
+ return self.client.get_tags()
+
+ getTags.__doc__ = PySparkSession.getTags.__doc__
+
+ def clearTags(self) -> None:
+ return self.client.clear_tags()
+
+ clearTags.__doc__ = PySparkSession.clearTags.__doc__
def stop(self) -> None:
# Stopping the session will only close the connection to the current session (and
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 00a0047dfd1..834b0307238 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -31,6 +31,7 @@ from typing import (
Tuple,
Type,
Union,
+ Set,
cast,
no_type_check,
overload,
@@ -1858,6 +1859,127 @@ class SparkSession(SparkConversionMixin):
"however, the current Spark session does not use Spark Connect."
)
+ def interruptAll(self) -> List[str]:
+ """
+ Interrupt all operations of this session currently running on the connected server.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ list of str
+ List of operationIds of interrupted operations.
+
+ Notes
+ -----
+ There is still a possibility of operation finishing just as it is interrupted.
+ """
+ raise RuntimeError(
+ "SparkSession.interruptAll is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def interruptTag(self, tag: str) -> List[str]:
+ """
+ Interrupt all operations of this session with the given operation tag.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ list of str
+ List of operationIds of interrupted operations.
+
+ Notes
+ -----
+ There is still a possibility of operation finishing just as it is interrupted.
+ """
+ raise RuntimeError(
+ "SparkSession.interruptTag is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def interruptOperation(self, op_id: str) -> List[str]:
+ """
+ Interrupt an operation of this session with the given operationId.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ list of str
+ List of operationIds of interrupted operations.
+
+ Notes
+ -----
+ There is still a possibility of operation finishing just as it is interrupted.
+ """
+ raise RuntimeError(
+ "SparkSession.interruptOperation is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def addTag(self, tag: str) -> None:
+ """
+ Add a tag to be assigned to all the operations started by this thread in this session.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ tag : str
+ The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+ """
+ raise RuntimeError(
+ "SparkSession.addTag is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def removeTag(self, tag: str) -> None:
+ """
+ Remove a tag previously added to be assigned to all the operations started by this thread in
+ this session. Noop if such a tag was not added earlier.
+
+ .. versionadded:: 3.5.0
+
+ Parameters
+ ----------
+ tag : str
+ The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+ """
+ raise RuntimeError(
+ "SparkSession.removeTag is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def getTags(self) -> Set[str]:
+ """
+ Get the tags that are currently set to be assigned to all the operations started by this
+ thread.
+
+ .. versionadded:: 3.5.0
+
+ Returns
+ -------
+ set of str
+ Set of tags currently set to be assigned to operations started by this thread.
+ """
+ raise RuntimeError(
+ "SparkSession.getTags is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
+ def clearTags(self) -> None:
+ """
+ Clear the current thread's operation tags.
+
+ .. versionadded:: 3.5.0
+ """
+ raise RuntimeError(
+ "SparkSession.clearTags is only supported with Spark Connect; "
+ "however, the current Spark session does not use Spark Connect."
+ )
+
def _test() -> None:
import os
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 2f14eeddc1e..0482f119d63 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -14,12 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import threading
+import time
import unittest
from typing import Optional
from pyspark.sql.connect.client import ChannelBuilder
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
+from pyspark.testing.connectutils import should_test_connect
+
+if should_test_connect:
+ from pyspark.testing.connectutils import ReusedConnectTestCase
class CustomChannelBuilder(ChannelBuilder):
@@ -70,3 +75,107 @@ class SparkSessionTestCase(unittest.TestCase):
self.assertIs(session, session2)
session.stop()
+
+
+class ArrowParityTests(ReusedConnectTestCase):
+ def test_tags(self):
+ self.spark.clearTags()
+ self.spark.addTag("a")
+ self.assertEqual(self.spark.getTags(), {"a"})
+ self.spark.addTag("b")
+ self.spark.removeTag("a")
+ self.assertEqual(self.spark.getTags(), {"b"})
+ self.spark.addTag("c")
+ self.spark.clearTags()
+ self.assertEqual(self.spark.getTags(), set())
+ self.spark.clearTags()
+
+ def test_interrupt_tag(self):
+ thread_ids = range(4)
+ self.check_job_cancellation(
+ lambda job_group: self.spark.addTag(job_group),
+ lambda job_group: self.spark.interruptTag(job_group),
+ thread_ids,
+ [i for i in thread_ids if i % 2 == 0],
+ [i for i in thread_ids if i % 2 != 0],
+ )
+ self.spark.clearTags()
+
+ def test_interrupt_all(self):
+ thread_ids = range(4)
+ self.check_job_cancellation(
+ lambda job_group: None,
+ lambda job_group: self.spark.interruptAll(),
+ thread_ids,
+ thread_ids,
+ [],
+ )
+ self.spark.clearTags()
+
+ def check_job_cancellation(
+ self, setter, canceller, thread_ids, thread_ids_to_cancel, thread_ids_to_run
+ ):
+
+ job_id_a = "job_ids_to_cancel"
+ job_id_b = "job_ids_to_run"
+ threads = []
+
+ # A list which records whether each job was cancelled.
+ # The index in the list is the index of the thread the job ran in.
+ is_job_cancelled = [False for _ in thread_ids]
+
+ def run_job(job_id, index):
+ """
+ Executes a job with the group ``job_id``. Each job sleeps for 20 seconds in the task
+ and then exits.
+ """
+ try:
+ setter(job_id)
+
+ def func(itr):
+ for pdf in itr:
+ time.sleep(pdf._1.iloc[0])
+ yield pdf
+
+ self.spark.createDataFrame([[20]]).repartition(1).mapInPandas(
+ func, schema="_1 LONG"
+ ).collect()
+ is_job_cancelled[index] = False
+ except Exception:
+ # Assume that exception means job cancellation.
+ is_job_cancelled[index] = True
+
+ # Test if job succeeded when not cancelled.
+ run_job(job_id_a, 0)
+ self.assertFalse(is_job_cancelled[0])
+ self.spark.clearTags()
+
+ # Run jobs
+ for i in thread_ids_to_cancel:
+ t = threading.Thread(target=run_job, args=(job_id_a, i))
+ t.start()
+ threads.append(t)
+
+ for i in thread_ids_to_run:
+ t = threading.Thread(target=run_job, args=(job_id_b, i))
+ t.start()
+ threads.append(t)
+
+ # Wait to make sure all jobs are executed.
+ time.sleep(10)
+ # And then, cancel one job group.
+ canceller(job_id_a)
+
+ # Wait until all threads launching jobs are finished.
+ for t in threads:
+ t.join()
+
+ for i in thread_ids_to_cancel:
+ self.assertTrue(
+ is_job_cancelled[i], "Thread {i}: Job in group A was not cancelled.".format(i=i)
+ )
+
+ for i in thread_ids_to_run:
+ self.assertFalse(
+ is_job_cancelled[i], "Thread {i}: Job in group B did not succeed.".format(i=i)
+ )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]