This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 7aff3635 Use explicit timer in unit test (#1338)
7aff3635 is described below
commit 7aff3635c93d5897d470642928c39c86e7851931
Author: Tim Saucer <[email protected]>
AuthorDate: Mon Jan 12 07:22:52 2026 -0500
Use explicit timer in unit test (#1338)
* Use an explicit wait in a dataframe query during testing to check for
keyboard interrupts
* Add interrupt check when spawning futures
* Update unit test to do four variations of fast/slow queries and
interrupt either collect or stream (see the sketch below)
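The explicit wait is a scalar UDF that sleeps inside the query. A minimal,
self-contained sketch of that pattern, assembled from the patch below (the
decorator signature, from_pydict, and the 2-second sleep all come from the
diff; nothing here is invented):

    import time

    import pyarrow as pa
    from datafusion import SessionContext, column, udf

    @udf([pa.int64()], pa.int64(), "immutable")
    def slow_udf(x: pa.Array) -> pa.Array:
        # Sleep longer than the signal-check interval in wait_for_future
        # so an interrupt arrives while the query is still executing.
        time.sleep(2.0)
        return x

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
    # df.collect() now blocks long enough to be interrupted reliably.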
---
python/tests/test_dataframe.py | 199 +++++++----------------------------------
src/utils.rs | 5 ++
2 files changed, 38 insertions(+), 166 deletions(-)
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index f481f31f..30f9ab90 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -37,6 +37,7 @@ from datafusion import (
WindowFrame,
column,
literal,
+ udf,
)
from datafusion import (
col as df_col,
@@ -3190,179 +3191,42 @@ def test_fill_null_all_null_column(ctx):
assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
-def test_collect_interrupted():
- """Test that a long-running query can be interrupted with Ctrl-C.
+@udf([pa.int64()], pa.int64(), "immutable")
+def slow_udf(x: pa.Array) -> pa.Array:
+ # This must be longer than the check interval in wait_for_future
+ time.sleep(2.0)
+ return x
- This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
- exception in the main thread during a long-running query execution.
- """
- # Create a context and a DataFrame with a query that will run for a while
- ctx = SessionContext()
-
- # Create a recursive computation that will run for some time
- batches = []
- for i in range(10):
- batch = pa.RecordBatch.from_arrays(
- [
- pa.array(list(range(i * 1000, (i + 1) * 1000))),
- pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) *
1000)]),
- ],
- names=["a", "b"],
- )
- batches.append(batch)
-
- # Register tables
- ctx.register_record_batches("t1", [batches])
- ctx.register_record_batches("t2", [batches])
-
- # Create a large join operation that will take time to process
- df = ctx.sql("""
- WITH t1_expanded AS (
- SELECT
- a,
- b,
- CAST(a AS DOUBLE) / 1.5 AS c,
- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
- FROM t1
- CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
- ),
- t2_expanded AS (
- SELECT
- a,
- b,
- CAST(a AS DOUBLE) * 2.5 AS e,
- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
- FROM t2
- CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
- )
- SELECT
- t1.a, t1.b, t1.c, t1.d,
- t2.a AS a2, t2.b AS b2, t2.e, t2.f
- FROM t1_expanded t1
- JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
- WHERE t1.a > 100 AND t2.a > 100
- """)
-
- # Flag to track if the query was interrupted
- interrupted = False
- interrupt_error = None
- main_thread = threading.main_thread()
-
- # Shared flag to indicate query execution has started
- query_started = threading.Event()
- max_wait_time = 5.0 # Maximum wait time in seconds
-
- # This function will be run in a separate thread and will raise
- # KeyboardInterrupt in the main thread
- def trigger_interrupt():
- """Poll for query start, then raise KeyboardInterrupt in the main
thread"""
- # Poll for query to start with small sleep intervals
- start_time = time.time()
- while not query_started.is_set():
- time.sleep(0.1) # Small sleep between checks
- if time.time() - start_time > max_wait_time:
- msg = f"Query did not start within {max_wait_time} seconds"
- raise RuntimeError(msg)
-
- # Check if thread ID is available
- thread_id = main_thread.ident
- if thread_id is None:
- msg = "Cannot get main thread ID"
- raise RuntimeError(msg)
-
- # Use ctypes to raise exception in main thread
- exception = ctypes.py_object(KeyboardInterrupt)
- res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
- ctypes.c_long(thread_id), exception
- )
- if res != 1:
- # If res is 0, the thread ID was invalid
- # If res > 1, we modified multiple threads
- ctypes.pythonapi.PyThreadState_SetAsyncExc(
- ctypes.c_long(thread_id), ctypes.py_object(0)
- )
- msg = "Failed to raise KeyboardInterrupt in main thread"
- raise RuntimeError(msg)
-
- # Start a thread to trigger the interrupt
- interrupt_thread = threading.Thread(target=trigger_interrupt)
- # we mark as daemon so the test process can exit even if this thread doesn't finish
- interrupt_thread.daemon = True
- interrupt_thread.start()
-
- # Execute the query and expect it to be interrupted
- try:
- # Signal that we're about to start the query
- query_started.set()
- df.collect()
- except KeyboardInterrupt:
- interrupted = True
- except Exception as e:
- interrupt_error = e
-
- # Assert that the query was interrupted properly
- if not interrupted:
- pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
-
- # Make sure the interrupt thread has finished
- interrupt_thread.join(timeout=1.0)
+@pytest.mark.parametrize(
+ ("slow_query", "as_c_stream"),
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+)
+def test_collect_or_stream_interrupted(slow_query, as_c_stream): # noqa: C901 PLR0915
+ """Ensure collection responds to ``KeyboardInterrupt`` signals.
-def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915
- """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+ This test issues a long-running query and consumes the results via
+ either a collect() call or ``__arrow_c_stream__``. It raises
+ ``KeyboardInterrupt`` in the main thread and verifies that
+ execution is interrupted.
- Similar to ``test_collect_interrupted`` this test issues a long running
- query, but consumes the results via ``__arrow_c_stream__``. It then raises
- ``KeyboardInterrupt`` in the main thread and verifies that the stream
- iteration stops promptly with the appropriate exception.
+ The `slow_query` parameter determines whether the query itself is slow,
+ via a UDF that sleeps, or whether it is a fast query that generates so
+ many results that iterating through them all takes a long time.
"""
ctx = SessionContext()
+ df = ctx.sql("select * from generate_series(1, 1000000000000000000)")
+ if slow_query:
+ df = ctx.from_pydict({"a": [1, 2, 3]}).select(slow_udf(column("a")))
- batches = []
- for i in range(10):
- batch = pa.RecordBatch.from_arrays(
- [
- pa.array(list(range(i * 1000, (i + 1) * 1000))),
- pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) *
1000)]),
- ],
- names=["a", "b"],
- )
- batches.append(batch)
-
- ctx.register_record_batches("t1", [batches])
- ctx.register_record_batches("t2", [batches])
-
- df = ctx.sql(
- """
- WITH t1_expanded AS (
- SELECT
- a,
- b,
- CAST(a AS DOUBLE) / 1.5 AS c,
- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
- FROM t1
- CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
- ),
- t2_expanded AS (
- SELECT
- a,
- b,
- CAST(a AS DOUBLE) * 2.5 AS e,
- CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
- FROM t2
- CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
- )
- SELECT
- t1.a, t1.b, t1.c, t1.d,
- t2.a AS a2, t2.b AS b2, t2.e, t2.f
- FROM t1_expanded t1
- JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
- WHERE t1.a > 100 AND t2.a > 100
- """
- )
-
- reader = pa.RecordBatchReader.from_stream(df)
+ if as_c_stream:
+ reader = pa.RecordBatchReader.from_stream(df)
read_started = threading.Event()
read_exception = []
@@ -3396,7 +3260,10 @@ def test_arrow_c_stream_interrupted(): # noqa: C901 PLR0915
read_thread_id = threading.get_ident()
try:
read_started.set()
- reader.read_all()
+ if as_c_stream:
+ reader.read_all()
+ else:
+ df.collect()
# If we get here, the read completed without interruption
read_exception.append(RuntimeError("Read completed without interruption"))
except KeyboardInterrupt:
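The interrupt itself is delivered with the ctypes pattern visible in the
removed lines above (the surviving test captures read_thread_id for the same
purpose). Condensed into a self-contained sketch:

    import ctypes

    def trigger_interrupt(thread_id: int) -> None:
        # Ask CPython to raise KeyboardInterrupt asynchronously in the
        # thread identified by thread_id, e.g. threading.main_thread().ident.
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(KeyboardInterrupt)
        )
        if res != 1:
            # res == 0 means the thread id was invalid; res > 1 means more
            # than one thread state was modified and must be reverted.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_long(thread_id), ctypes.py_object(0)
            )
            raise RuntimeError("Failed to raise KeyboardInterrupt")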
diff --git a/src/utils.rs b/src/utils.rs
index cbc3d6d9..6038c77b 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -77,6 +77,11 @@ where
let runtime: &Runtime = &get_tokio_runtime().0;
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
+ // Some fast-running processes that generate many `wait_for_future` calls like
+ // PartitionedDataFrameStreamReader::next require checking for interrupts early
+ py.run(cr"pass", None, None)?;
+ py.check_signals()?;
+
py.detach(|| {
runtime.block_on(async {
tokio::pin!(fut);
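The early check matters for fast futures: wait_for_future previously checked
signals only on a 1-second interval inside block_on, so a long run of
sub-second calls (one per record batch, say) could starve the check
indefinitely. A rough Python model of the resulting control flow; poll_ready
and check_signals are hypothetical stand-ins for the Tokio future and PyO3's
Python::check_signals, an illustration rather than the actual implementation:

    import time

    INTERVAL_CHECK_SIGNALS = 1.0  # mirrors the 1_000 ms constant in src/utils.rs

    def wait_for_future(poll_ready, check_signals):
        # Check once before blocking: a fast future can resolve before the
        # first interval tick, so a stream of many short calls would
        # otherwise never observe a pending Ctrl-C.
        check_signals()
        next_check = time.monotonic() + INTERVAL_CHECK_SIGNALS
        while not poll_ready():
            if time.monotonic() >= next_check:
                check_signals()
                next_check = time.monotonic() + INTERVAL_CHECK_SIGNALS
            time.sleep(0.01)

The no-op py.run(cr"pass", None, None) presumably forces a trip through the
CPython eval loop, which is where exceptions injected via
PyThreadState_SetAsyncExc (as in the tests above) are delivered;
py.check_signals() then covers real OS signals.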
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]