This is an automated email from the ASF dual-hosted git repository.
derrickaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ea2fde7a020 Fix race condition in statesampler_fast.pyx (#38851)
ea2fde7a020 is described below
commit ea2fde7a020afa3d5d2f7d79a9d2668a4c1441fb
Author: Derrick Williams <[email protected]>
AuthorDate: Tue Jun 9 17:04:26 2026 -0400
Fix race condition in statesampler_fast.pyx (#38851)
* fix race condition in reads
* rename test case - doc string
* move import to top
* address gemini
* address gemini comments
* address gemini comments
---
.../runners/worker/statesampler_fast.pyx | 8 +++-
.../runners/worker/statesampler_test.py | 43 ++++++++++++++++++++++
2 files changed, 50 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
index 45700a0b0f8..7075ef47017 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx
@@ -217,7 +217,13 @@ cdef class ScopedState(object):
@property
def nsecs(self):
- return self._nsecs
+ cdef pythread.PyThread_type_lock lock = self.sampler.lock
+ cdef int64_t val
+ with nogil:
+ pythread.PyThread_acquire_lock(lock, pythread.WAIT_LOCK)
+ val = self._nsecs
+ pythread.PyThread_release_lock(lock)
+ return val
def sampled_seconds(self):
return 1e-9 * self.nsecs
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
index 0d0ce1d2c8d..0495dc507da 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -19,6 +19,7 @@
# pytype: skip-file
import logging
+import threading
import time
import unittest
from unittest import mock
@@ -312,6 +313,48 @@ class StateSamplerTest(unittest.TestCase):
actual_value, state_duration_ms * (1.0 - margin_of_error))
_LOGGER.info("Exception test finished successfully.")
+ def test_concurrent_nsecs_reads(self):
+ """Verify that concurrent reads of nsecs behave correctly under thread contention.
+
+ This test runs state transitions on the main thread and reads `nsecs` properties
+ from a secondary Python thread, while the background sampler thread is concurrently
+ updating counter states.
+ """
+ if not statesampler.FAST_SAMPLER:
+ self.skipTest('test_concurrent_nsecs_reads requires FAST_SAMPLER')
+
+ counter_factory = CounterFactory()
+ sampler = statesampler.StateSampler(
+ 'concurrent', counter_factory, sampling_period_ms=1)
+
+ sampler.start()
+ reader_thread = None
+ try:
+ state_a = sampler.scoped_state('step1', 'statea')
+ state_b = sampler.scoped_state('step1', 'stateb')
+
+ stop_signal = False
+
+ def read_nsecs_loop():
+ while not stop_signal:
+ _ = state_a.nsecs
+ _ = state_b.nsecs
+ time.sleep(0.001)
+
+ reader_thread = threading.Thread(target=read_nsecs_loop)
+ reader_thread.start()
+
+ for _ in range(100):
+ with state_a:
+ time.sleep(0.001)
+ with state_b:
+ time.sleep(0.001)
+ finally:
+ if reader_thread is not None:
+ stop_signal = True
+ reader_thread.join()
+ sampler.stop()
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)