This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e4114f67e12 [SPARK-45048][CONNECT] Add additional tests for Python client and attachable execution
e4114f67e12 is described below

commit e4114f67e12a235b4784fcbfa6f6ba9b44a5e715
Author: Martin Grund <[email protected]>
AuthorDate: Fri Sep 1 22:15:23 2023 +0800

    [SPARK-45048][CONNECT] Add additional tests for Python client and attachable execution
    
    ### What changes were proposed in this pull request?
    For better test coverage add additional tests of the attachable Spark Connect Python client.
    
    ### Why are the changes needed?
    Stability
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #42769 from grundprinzip/SPARK-45048.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/tests/connect/client/test_client.py        | 156 ++++++++++++++++++++-
 1 file changed, 154 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 2ba42cabf84..70280c1d24a 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -17,16 +17,21 @@
 
 import unittest
 import uuid
-from typing import Optional
+from collections.abc import Generator
+from typing import Optional, Any
 
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 
 if should_test_connect:
+    import grpc
     import pandas as pd
     import pyarrow as pa
     from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
     from pyspark.sql.connect.client.core import Retrying
-    from pyspark.sql.connect.client.reattach import RetryException
+    from pyspark.sql.connect.client.reattach import (
+        RetryException,
+        ExecutePlanResponseReattachableIterator,
+    )
     import pyspark.sql.connect.proto as proto
 
 
@@ -119,6 +124,153 @@ class SparkConnectClientTestCase(unittest.TestCase):
         self.assertEqual(client._session_id, chan.session_id)
 
 
[email protected](not should_test_connect, connect_requirement_message)
+class SparkConnectClientReattachTestCase(unittest.TestCase):
+    def setUp(self) -> None:
+        self.request = proto.ExecutePlanRequest()
+        self.policy = {
+            "max_retries": 3,
+            "backoff_multiplier": 4.0,
+            "initial_backoff": 10,
+            "max_backoff": 10,
+            "jitter": 10,
+            "min_jitter_threshold": 10,
+        }
+        self.response = proto.ExecutePlanResponse()
+        self.finished = proto.ExecutePlanResponse(
+            result_complete=proto.ExecutePlanResponse.ResultComplete()
+        )
+
+    def _stub_with(self, execute=None, attach=None):
+        return MockSparkConnectStub(
+            execute_ops=ResponseGenerator(execute) if execute is not None else None,
+            attach_ops=ResponseGenerator(attach) if attach is not None else None,
+        )
+
+    def test_basic_flow(self):
+        stub = self._stub_with([self.response, self.finished])
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for b in ite:
+            pass
+
+        self.assertEqual(0, stub.attach_calls)
+        self.assertGreater(1, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_during_execute(self):
+        def fatal():
+            raise TestException("Fatal")
+
+        stub = self._stub_with([self.response, fatal])
+        with self.assertRaises(TestException):
+            ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+            for b in ite:
+                pass
+
+        self.assertEqual(0, stub.attach_calls)
+        self.assertEqual(0, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_and_retry_during_execute(self):
+        def non_fatal():
+            raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, self.response, self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for b in ite:
+            pass
+
+        self.assertEqual(1, stub.attach_calls)
+        self.assertEqual(1, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_and_retry_during_reattach(self):
+        count = 0
+
+        def non_fatal():
+            nonlocal count
+            if count < 2:
+                count += 1
+                raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+            else:
+                return proto.ExecutePlanResponse()
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for b in ite:
+            pass
+
+        self.assertEqual(2, stub.attach_calls)
+        self.assertEqual(2, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+
+class TestException(grpc.RpcError, grpc.Call):
+    """Exception mock to test retryable exceptions."""
+
+    def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+        self.msg = msg
+        self._code = code
+
+    def code(self):
+        return self._code
+
+    def __str__(self):
+        return self.msg
+
+    def trailing_metadata(self):
+        return ()
+
+
+class ResponseGenerator(Generator):
+    """This class is used to generate values that are returned by the streaming
+    iterator of the GRPC stub."""
+
+    def __init__(self, funs):
+        self._funs = funs
+        self._iterator = iter(self._funs)
+
+    def send(self, value: Any) -> proto.ExecutePlanResponse:
+        val = next(self._iterator)
+        if callable(val):
+            return val()
+        else:
+            return val
+
+    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
+        super().throw(type, value, traceback)
+
+    def close(self) -> None:
+        return super().close()
+
+
+class MockSparkConnectStub:
+    """Simple mock class for the GRPC stub used by the re-attachable execution."""
+
+    def __init__(self, execute_ops=None, attach_ops=None):
+        self._execute_ops = execute_ops
+        self._attach_ops = attach_ops
+        # Call counters
+        self.execute_calls = 0
+        self.release_calls = 0
+        self.attach_calls = 0
+
+    def ExecutePlan(self, *args, **kwargs):
+        self.execute_calls += 1
+        return self._execute_ops
+
+    def ReattachExecute(self, *args, **kwargs):
+        self.attach_calls += 1
+        return self._attach_ops
+
+    def ReleaseExecute(self, *args, **kwargs):
+        self.release_calls += 1
+
+
 class MockService:
     # Simplest mock of the SparkConnectService.
     # If this needs more complex logic, it needs to be replaced with Python mocking.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to