This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 738acd16847b [SPARK-48648][PYTHON][CONNECT] Make
SparkConnectClient.tags properly threadlocal
738acd16847b is described below
commit 738acd16847b7d8dc173f3fe2cf18f349fc27af9
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jun 18 14:12:31 2024 +0900
[SPARK-48648][PYTHON][CONNECT] Make SparkConnectClient.tags properly
threadlocal
### What changes were proposed in this pull request?
This PR fixes the `threading.local` usage in `SparkConnectClient` to address the bug introduced by https://github.com/apache/spark/pull/44210. That change misused `threading.local` by subclassing it and declaring class-level attributes, which are shared across all threads instead of being thread-local.
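For illustration (this snippet is not from the PR itself), the difference between the two patterns can be reproduced in plain Python: a class-level attribute on a `threading.local` subclass is one object shared by every thread, while attributes set on a bare `threading.local()` exist per thread.
```python
import threading

# Buggy pattern (what #44210 did): the class-level set() is created once at
# class definition time, so threads that only mutate it end up sharing it.
class SharedTags(threading.local):
    tags: set = set()

shared = SharedTags()

def add_shared(tag):
    shared.tags.add(tag)  # mutates the one class-level set
    print(tag, "sees", shared.tags)

for t in ("tag1", "tag2"):
    th = threading.Thread(target=add_shared, args=(t,))
    th.start()
    th.join()
# The second thread prints {'tag1', 'tag2'}: the tags leaked across threads.

# Fixed pattern: a plain threading.local(); each thread creates and sees
# only its own `tags` attribute.
local = threading.local()

def add_local(tag):
    if not hasattr(local, "tags"):
        local.tags = set()  # created lazily, per thread
    local.tags.add(tag)
    print(tag, "sees", local.tags)

for t in ("tag1", "tag2"):
    th = threading.Thread(target=add_local, args=(t,))
    th.start()
    th.join()
# Each thread prints only its own tag.
```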
### Why are the changes needed?
So users can properly use thread-based `interruptTag`. Currently, the code below cancels both queries, because the tags added in each thread accumulate in a single shared set:
```python
import concurrent.futures
import time
import threading
from pyspark.sql.functions import udf

def run_query_with_tag(query, tag):
    try:
        spark.addTag(tag)
        print(f"starting query {tag}")
        df = spark.sql(query).select(udf(lambda: time.sleep(10))())
        print(f"collecting query {tag}")
        res = df.collect()
        print(f"done with query {tag}")
    finally:
        spark.removeTag(tag)

queries_with_tags = [
    ("SELECT * FROM range(100)", "tag1"),
    ("SELECT * FROM range(100)", "tag2"),
]

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = {executor.submit(run_query_with_tag, query, tag): (query, tag) for query, tag in queries_with_tags}
    time.sleep(5)
    print("Interrupting tag1")
    print(spark.interruptTag("tag1"))
    for f in futures:
        try:
            f.result()
            print(f"done with {f.result()}")
        except:
            print(f"failed with {f.exception()}")
```
### Does this PR introduce _any_ user-facing change?
No, this was caused by https://github.com/apache/spark/pull/44210, but that change has not been released yet.
### How was this patch tested?
A unit test was added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47005 from HyukjinKwon/thread-local.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 9 ++------
python/pyspark/sql/tests/connect/test_session.py | 28 ++++++++++++++++++++++++
2 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 4c638be3b0af..f3bbab69f271 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -659,12 +659,7 @@ class SparkConnectClient(object):
         use_reattachable_execute: bool
             Enable reattachable execution.
         """
-
-        class ClientThreadLocals(threading.local):
-            tags: set = set()
-            inside_error_handling: bool = False
-
-        self.thread_local = ClientThreadLocals()
+        self.thread_local = threading.local()
 
         # Parse the connection string.
         self._builder = (
@@ -1693,7 +1688,7 @@ class SparkConnectClient(object):
 
         Throws the appropriate internal Python exception.
         """
-        if self.thread_local.inside_error_handling:
+        if getattr(self.thread_local, "inside_error_handling", False):
             # We are already inside error handling routine,
             # avoid recursive error processing (with potentially infinite recursion)
             raise error
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 820f54b83327..6f0e4aaad3f8 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -119,6 +119,34 @@ class JobCancellationTests(ReusedConnectTestCase):
         self.assertEqual(self.spark.getTags(), set())
         self.spark.clearTags()
 
+    def test_tags_multithread(self):
+        output1 = None
+        output2 = None
+
+        def tag1():
+            nonlocal output1
+
+            self.spark.addTag("tag1")
+            output1 = self.spark.getTags()
+
+        def tag2():
+            nonlocal output2
+
+            self.spark.addTag("tag2")
+            output2 = self.spark.getTags()
+
+        t1 = threading.Thread(target=tag1)
+        t1.start()
+        t1.join()
+        t2 = threading.Thread(target=tag2)
+        t2.start()
+        t2.join()
+
+        self.assertIsNotNone(output1)
+        self.assertEquals(output1, {"tag1"})
+        self.assertIsNotNone(output2)
+        self.assertEquals(output2, {"tag2"})
+
     def test_interrupt_tag(self):
         thread_ids = range(4)
         self.check_job_cancellation(
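As a side note on the `core.py` hunk above: with a bare `threading.local()`, an attribute may not exist yet in the current thread, which is why the code now reads it via `getattr(..., False)`. A minimal sketch of that per-thread re-entrancy guard pattern (the names below are illustrative, not the actual client internals):
```python
import threading

_state = threading.local()

def handle_error(error: Exception) -> None:
    # The attribute may not exist yet in this thread, so read it with a default.
    if getattr(_state, "inside_error_handling", False):
        # Already handling an error in this thread; re-raise as-is to avoid
        # recursing into the handler forever.
        raise error
    _state.inside_error_handling = True
    try:
        # ... convert `error` into a nicer user-facing exception here ...
        raise RuntimeError(f"wrapped: {error}") from error
    finally:
        _state.inside_error_handling = False
```
Resetting the flag in a `finally` block keeps the guard correct for the thread even when the handler itself raises.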
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]