This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 7aff3635 Use explicit timer in unit test (#1338)
7aff3635 is described below

commit 7aff3635c93d5897d470642928c39c86e7851931
Author: Tim Saucer <[email protected]>
AuthorDate: Mon Jan 12 07:22:52 2026 -0500

    Use explicit timer in unit test (#1338)
    
    * Use an explicit wait in a dataframe query during testing to check for 
keyboard interrupts
    
    * Add interrupt check when spawning futures
    
    * Update unit test to do four variations of fast/slow queries and 
interrupt either collect or stream
---
 python/tests/test_dataframe.py | 199 +++++++----------------------------------
 src/utils.rs                   |   5 ++
 2 files changed, 38 insertions(+), 166 deletions(-)

diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index f481f31f..30f9ab90 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -37,6 +37,7 @@ from datafusion import (
     WindowFrame,
     column,
     literal,
+    udf,
 )
 from datafusion import (
     col as df_col,
@@ -3190,179 +3191,42 @@ def test_fill_null_all_null_column(ctx):
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
 
 
-def test_collect_interrupted():
-    """Test that a long-running query can be interrupted with Ctrl-C.
+@udf([pa.int64()], pa.int64(), "immutable")
+def slow_udf(x: pa.Array) -> pa.Array:
+    # This must be longer than the check interval in wait_for_future
+    time.sleep(2.0)
+    return x
 
-    This test simulates a Ctrl-C keyboard interrupt by raising a 
KeyboardInterrupt
-    exception in the main thread during a long-running query execution.
-    """
-    # Create a context and a DataFrame with a query that will run for a while
-    ctx = SessionContext()
-
-    # Create a recursive computation that will run for some time
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 
1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)
-
-    # Register tables
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])
-
-    # Create a large join operation that will take time to process
-    df = ctx.sql("""
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-    """)
-
-    # Flag to track if the query was interrupted
-    interrupted = False
-    interrupt_error = None
-    main_thread = threading.main_thread()
-
-    # Shared flag to indicate query execution has started
-    query_started = threading.Event()
-    max_wait_time = 5.0  # Maximum wait time in seconds
-
-    # This function will be run in a separate thread and will raise
-    # KeyboardInterrupt in the main thread
-    def trigger_interrupt():
-        """Poll for query start, then raise KeyboardInterrupt in the main 
thread"""
-        # Poll for query to start with small sleep intervals
-        start_time = time.time()
-        while not query_started.is_set():
-            time.sleep(0.1)  # Small sleep between checks
-            if time.time() - start_time > max_wait_time:
-                msg = f"Query did not start within {max_wait_time} seconds"
-                raise RuntimeError(msg)
-
-        # Check if thread ID is available
-        thread_id = main_thread.ident
-        if thread_id is None:
-            msg = "Cannot get main thread ID"
-            raise RuntimeError(msg)
-
-        # Use ctypes to raise exception in main thread
-        exception = ctypes.py_object(KeyboardInterrupt)
-        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
-            ctypes.c_long(thread_id), exception
-        )
-        if res != 1:
-            # If res is 0, the thread ID was invalid
-            # If res > 1, we modified multiple threads
-            ctypes.pythonapi.PyThreadState_SetAsyncExc(
-                ctypes.c_long(thread_id), ctypes.py_object(0)
-            )
-            msg = "Failed to raise KeyboardInterrupt in main thread"
-            raise RuntimeError(msg)
-
-    # Start a thread to trigger the interrupt
-    interrupt_thread = threading.Thread(target=trigger_interrupt)
-    # we mark as daemon so the test process can exit even if this thread 
doesn't finish
-    interrupt_thread.daemon = True
-    interrupt_thread.start()
-
-    # Execute the query and expect it to be interrupted
-    try:
-        # Signal that we're about to start the query
-        query_started.set()
-        df.collect()
-    except KeyboardInterrupt:
-        interrupted = True
-    except Exception as e:
-        interrupt_error = e
-
-    # Assert that the query was interrupted properly
-    if not interrupted:
-        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
-
-    # Make sure the interrupt thread has finished
-    interrupt_thread.join(timeout=1.0)
 
[email protected](
+    ("slow_query", "as_c_stream"),
+    [
+        (True, True),
+        (True, False),
+        (False, True),
+        (False, False),
+    ],
+)
+def test_collect_or_stream_interrupted(slow_query, as_c_stream):  # noqa: C901 
PLR0915
+    """Ensure collection responds to ``KeyboardInterrupt`` signals.
 
-def test_arrow_c_stream_interrupted():  # noqa: C901 PLR0915
-    """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+    This test issues a long-running query, and consumes the results via
+    either a collect() call or ``__arrow_c_stream__``. It raises
+    ``KeyboardInterrupt`` in the main thread and verifies that the
+    process has interrupted.
 
-    Similar to ``test_collect_interrupted`` this test issues a long running
-    query, but consumes the results via ``__arrow_c_stream__``. It then raises
-    ``KeyboardInterrupt`` in the main thread and verifies that the stream
-    iteration stops promptly with the appropriate exception.
+    The `slow_query` determines if the query itself is slow via a
+    UDF with a timeout or if it is a fast query that generates many
+    results so it takes a long time to iterate through them all.
     """
 
     ctx = SessionContext()
+    df = ctx.sql("select * from generate_series(1, 1000000000000000000)")
+    if slow_query:
+        df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
 
-    batches = []
-    for i in range(10):
-        batch = pa.RecordBatch.from_arrays(
-            [
-                pa.array(list(range(i * 1000, (i + 1) * 1000))),
-                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 
1000)]),
-            ],
-            names=["a", "b"],
-        )
-        batches.append(batch)
-
-    ctx.register_record_batches("t1", [batches])
-    ctx.register_record_batches("t2", [batches])
-
-    df = ctx.sql(
-        """
-        WITH t1_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) / 1.5 AS c,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
-            FROM t1
-            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
-        ),
-        t2_expanded AS (
-            SELECT
-                a,
-                b,
-                CAST(a AS DOUBLE) * 2.5 AS e,
-                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
-            FROM t2
-            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
-        )
-        SELECT
-            t1.a, t1.b, t1.c, t1.d,
-            t2.a AS a2, t2.b AS b2, t2.e, t2.f
-        FROM t1_expanded t1
-        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
-        WHERE t1.a > 100 AND t2.a > 100
-        """
-    )
-
-    reader = pa.RecordBatchReader.from_stream(df)
+    if as_c_stream:
+        reader = pa.RecordBatchReader.from_stream(df)
 
     read_started = threading.Event()
     read_exception = []
@@ -3396,7 +3260,10 @@ def test_arrow_c_stream_interrupted():  # noqa: C901 
PLR0915
         read_thread_id = threading.get_ident()
         try:
             read_started.set()
-            reader.read_all()
+            if as_c_stream:
+                reader.read_all()
+            else:
+                df.collect()
             # If we get here, the read completed without interruption
             read_exception.append(RuntimeError("Read completed without 
interruption"))
         except KeyboardInterrupt:
diff --git a/src/utils.rs b/src/utils.rs
index cbc3d6d9..6038c77b 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -77,6 +77,11 @@ where
     let runtime: &Runtime = &get_tokio_runtime().0;
     const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
 
+    // Some fast running processes that generate many `wait_for_future` calls 
like
+    // PartitionedDataFrameStreamReader::next require checking for interrupts 
early
+    py.run(cr"pass", None, None)?;
+    py.check_signals()?;
+
     py.detach(|| {
         runtime.block_on(async {
             tokio::pin!(fut);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to