This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 738acd16847b [SPARK-48648][PYTHON][CONNECT] Make 
SparkConnectClient.tags properly threadlocal
738acd16847b is described below

commit 738acd16847b7d8dc173f3fe2cf18f349fc27af9
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jun 18 14:12:31 2024 +0900

    [SPARK-48648][PYTHON][CONNECT] Make SparkConnectClient.tags properly 
threadlocal
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the use of `thread.local` in `SparkConnectClient` to resolve
a bug introduced by https://github.com/apache/spark/pull/44210. That change
misused `thread.local` by subclassing it and defining class-level attributes,
which are shared across all threads instead of being thread-local.
    
    ### Why are the changes needed?
    
    So users can properly use thread-based `interruptTag`. Now the code below 
cancels both queries:
    
    ```python
    import concurrent.futures
    import time
    import threading
    from pyspark.sql.functions import udf
    
    def run_query_with_tag(query, tag):
        try:
            spark.addTag(tag)
            print(f"starting query {tag}")
            df = spark.sql(query).select(udf(lambda: time.sleep(10))())
            print(f"collecting query {tag}")
            res = df.collect()
            print(f"done with query {tag}")
        finally:
            spark.removeTag(tag)
    
    queries_with_tags = [
        ("SELECT * FROM range(100)", "tag1"),
        ("SELECT * FROM range(100)", "tag2"),
    ]
    
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {executor.submit(run_query_with_tag, query, tag): (query, 
tag) for query, tag in queries_with_tags}
        time.sleep(5)
        print("Interrupting tag1")
        print(spark.interruptTag("tag1"))
        for f in futures:
            try:
                f.result()
                print(f"done with {f.result()}")
            except:
                print(f"failed with {f.exception()}")
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this was caused by https://github.com/apache/spark/pull/44210, but that
change has not been released yet.
    
    ### How was this patch tested?
    
    A unit test was added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47005 from HyukjinKwon/thread-local.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py        |  9 ++------
 python/pyspark/sql/tests/connect/test_session.py | 28 ++++++++++++++++++++++++
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 4c638be3b0af..f3bbab69f271 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -659,12 +659,7 @@ class SparkConnectClient(object):
         use_reattachable_execute: bool
             Enable reattachable execution.
         """
-
-        class ClientThreadLocals(threading.local):
-            tags: set = set()
-            inside_error_handling: bool = False
-
-        self.thread_local = ClientThreadLocals()
+        self.thread_local = threading.local()
 
         # Parse the connection string.
         self._builder = (
@@ -1693,7 +1688,7 @@ class SparkConnectClient(object):
         Throws the appropriate internal Python exception.
         """
 
-        if self.thread_local.inside_error_handling:
+        if getattr(self.thread_local, "inside_error_handling", False):
             # We are already inside error handling routine,
             # avoid recursive error processing (with potentially infinite 
recursion)
             raise error
diff --git a/python/pyspark/sql/tests/connect/test_session.py 
b/python/pyspark/sql/tests/connect/test_session.py
index 820f54b83327..6f0e4aaad3f8 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -119,6 +119,34 @@ class JobCancellationTests(ReusedConnectTestCase):
         self.assertEqual(self.spark.getTags(), set())
         self.spark.clearTags()
 
+    def test_tags_multithread(self):
+        output1 = None
+        output2 = None
+
+        def tag1():
+            nonlocal output1
+
+            self.spark.addTag("tag1")
+            output1 = self.spark.getTags()
+
+        def tag2():
+            nonlocal output2
+
+            self.spark.addTag("tag2")
+            output2 = self.spark.getTags()
+
+        t1 = threading.Thread(target=tag1)
+        t1.start()
+        t1.join()
+        t2 = threading.Thread(target=tag2)
+        t2.start()
+        t2.join()
+
+        self.assertIsNotNone(output1)
+        self.assertEquals(output1, {"tag1"})
+        self.assertIsNotNone(output2)
+        self.assertEquals(output2, {"tag2"})
+
     def test_interrupt_tag(self):
         thread_ids = range(4)
         self.check_job_cancellation(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to