This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b8743b9366 [BEAM-36736] Add state sampling for timer processing in the Python SDK (#36737)
5b8743b9366 is described below

commit 5b8743b9366ddc6e0300aa6e5ce6c2fbe1e4274f
Author: Karthik Talluri <[email protected]>
AuthorDate: Tue Nov 25 15:39:18 2025 -0800

    [BEAM-36736] Add state sampling for timer processing in the Python SDK (#36737)
    
    * [BEAM-36736] Add state sampling for timer processing
    
    * Force CI to rebuild
    
    * Fix error with no state found
    
    * Fix error for Regex test
    
    * Resolve linting error
    
    * Add test case to test full functionality
    
    * Fix suffix issue
    
    * Fix formatting issues using tox -e yapf-check
    
    * Add test cases to test code paths
    
    * Address comments and remove extra test case
    
    * Remove user state context variable
    
    * Adjust state duration for test to avoid flakiness
    
    * Add different tests, remove no op scoped state, and address formatting/lint issues
    
    * Add patch to deal with CI presubmit errors
    
    * Adjust test case to not use dofn_runner
    
    * Test case failing presubmits, attempting to fix
    
    * Fix mocking for tests and ensure all pass
    
    * Remove extra test and increase retries on the process timer tests to avoid flakiness
    
    * Remove upper bound restriction and reduce retries
    
    * Remove unused suffix param.
    
    ---------
    
    Co-authored-by: tvalentyn <[email protected]>
---
 .../apache_beam/runners/worker/operations.pxd      |   1 +
 .../apache_beam/runners/worker/operations.py       |  22 ++-
 .../runners/worker/statesampler_test.py            | 185 +++++++++++++++++++++
 3 files changed, 199 insertions(+), 9 deletions(-)

diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index f24b75a720e..52211e4d8ce 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
   cdef dict timer_specs
   cdef public object input_info
   cdef object fn
+  cdef object scoped_timer_processing_state
 
 
 cdef class SdfProcessSizedElements(DoOperation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 9f490e4ae44..d0f7cceb558 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -809,7 +809,10 @@ class DoOperation(Operation):
     self.tagged_receivers = None  # type: Optional[_TaggedReceivers]
     # A mapping of timer tags to the input "PCollections" they come in on.
     self.input_info = None  # type: Optional[OpInputInfo]
-
+    self.scoped_timer_processing_state = self.state_sampler.scoped_state(
+        self.name_context,
+        'process-timers',
+        metrics_container=self.metrics_container)
     # See fn_data in dataflow_runner.py
     # TODO: Store all the items from spec?
     self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
@@ -971,14 +974,15 @@ class DoOperation(Operation):
     self.user_state_context.add_timer_info(timer_family_id, timer_info)
 
   def process_timer(self, tag, timer_data):
-    timer_spec = self.timer_specs[tag]
-    self.dofn_runner.process_user_timer(
-        timer_spec,
-        timer_data.user_key,
-        timer_data.windows[0],
-        timer_data.fire_timestamp,
-        timer_data.paneinfo,
-        timer_data.dynamic_timer_tag)
+    with self.scoped_timer_processing_state:
+      timer_spec = self.timer_specs[tag]
+      self.dofn_runner.process_user_timer(
+          timer_spec,
+          timer_data.user_key,
+          timer_data.windows[0],
+          timer_data.fire_timestamp,
+          timer_data.paneinfo,
+          timer_data.dynamic_timer_tag)
 
   def finish(self):
     # type: () -> None
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
index c9ea7e8eef9..0d0ce1d2c8d 100644
--- a/sdks/python/apache_beam/runners/worker/statesampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -21,17 +21,56 @@
 import logging
 import time
 import unittest
+from unittest import mock
+from unittest.mock import Mock
+from unittest.mock import patch
 
 from tenacity import retry
 from tenacity import stop_after_attempt
 
+from apache_beam.internal import pickler
+from apache_beam.runners import common
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
 from apache_beam.runners.worker import statesampler
+from apache_beam.transforms import core
+from apache_beam.transforms import userstate
+from apache_beam.transforms.core import GlobalWindows
+from apache_beam.transforms.core import Windowing
+from apache_beam.transforms.window import GlobalWindow
 from apache_beam.utils.counters import CounterFactory
 from apache_beam.utils.counters import CounterName
+from apache_beam.utils.windowed_value import PaneInfo
 
 _LOGGER = logging.getLogger(__name__)
 
 
+class TimerDoFn(core.DoFn):
+  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+  def __init__(self, sleep_duration_s=0):
+    self._sleep_duration_s = sleep_duration_s
+
+  @userstate.on_timer(TIMER_SPEC)
+  def on_timer_f(self):
+    if self._sleep_duration_s:
+      time.sleep(self._sleep_duration_s)
+
+
+class ExceptionTimerDoFn(core.DoFn):
+  """A DoFn that raises an exception when its timer fires."""
+  TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)
+
+  def __init__(self, sleep_duration_s=0):
+    self._sleep_duration_s = sleep_duration_s
+
+  @userstate.on_timer(TIMER_SPEC)
+  def on_timer_f(self):
+    if self._sleep_duration_s:
+      time.sleep(self._sleep_duration_s)
+    raise RuntimeError("Test exception from timer")
+
+
 class StateSamplerTest(unittest.TestCase):
 
   # Due to somewhat non-deterministic nature of state sampling and sleep,
@@ -127,6 +166,152 @@ class StateSamplerTest(unittest.TestCase):
     # debug mode).
     self.assertLess(overhead_us, 20.0)
 
+  @retry(reraise=True, stop=stop_after_attempt(3))
+  # Patch the problematic function to return the correct timer spec
+  @patch('apache_beam.transforms.userstate.get_dofn_specs')
+  def test_do_operation_process_timer(self, mock_get_dofn_specs):
+    fn = TimerDoFn()
+    mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
+
+    if not statesampler.FAST_SAMPLER:
+      self.skipTest('DoOperation test requires FAST_SAMPLER')
+
+    state_duration_ms = 200
+    margin_of_error = 0.75
+
+    counter_factory = CounterFactory()
+    sampler = statesampler.StateSampler(
+        'test_do_op', counter_factory, sampling_period_ms=1)
+
+    fn_for_spec = TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0)
+
+    spec = operation_specs.WorkerDoFn(
+        serialized_fn=pickler.dumps(
+            (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
+        output_tags=[],
+        input=None,
+        side_inputs=[],
+        output_coders=[])
+
+    mock_user_state_context = mock.MagicMock()
+    op = operations.DoOperation(
+        common.NameContext('step1'),
+        spec,
+        counter_factory,
+        sampler,
+        user_state_context=mock_user_state_context)
+
+    op.setup()
+
+    timer_data = Mock()
+    timer_data.user_key = None
+    timer_data.windows = [GlobalWindow()]
+    timer_data.fire_timestamp = 0
+    timer_data.paneinfo = PaneInfo(
+        is_first=False,
+        is_last=False,
+        timing=0,
+        index=0,
+        nonspeculative_index=0)
+    timer_data.dynamic_timer_tag = ''
+
+    sampler.start()
+    op.process_timer('ts-timer', timer_data=timer_data)
+    sampler.stop()
+    sampler.commit_counters()
+
+    expected_name = CounterName(
+        'process-timers-msecs', step_name='step1', stage_name='test_do_op')
+
+    found_counter = None
+    for counter in counter_factory.get_counters():
+      if counter.name == expected_name:
+        found_counter = counter
+        break
+
+    self.assertIsNotNone(
+        found_counter, f"Expected counter '{expected_name}' to be created.")
+
+    actual_value = found_counter.value()
+    logging.info("Actual value %d", actual_value)
+    self.assertGreater(
+        actual_value, state_duration_ms * (1.0 - margin_of_error))
+
+  @retry(reraise=True, stop=stop_after_attempt(3))
+  @patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs')
+  def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
+    fn = ExceptionTimerDoFn()
+    mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
+
+    if not statesampler.FAST_SAMPLER:
+      self.skipTest('DoOperation test requires FAST_SAMPLER')
+
+    state_duration_ms = 200
+    margin_of_error = 0.50
+
+    counter_factory = CounterFactory()
+    sampler = statesampler.StateSampler(
+        'test_do_op_exception', counter_factory, sampling_period_ms=1)
+
+    fn_for_spec = ExceptionTimerDoFn(
+        sleep_duration_s=state_duration_ms / 1000.0)
+
+    spec = operation_specs.WorkerDoFn(
+        serialized_fn=pickler.dumps(
+            (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
+        output_tags=[],
+        input=None,
+        side_inputs=[],
+        output_coders=[])
+
+    mock_user_state_context = mock.MagicMock()
+    op = operations.DoOperation(
+        common.NameContext('step1'),
+        spec,
+        counter_factory,
+        sampler,
+        user_state_context=mock_user_state_context)
+
+    op.setup()
+
+    timer_data = Mock()
+    timer_data.user_key = None
+    timer_data.windows = [GlobalWindow()]
+    timer_data.fire_timestamp = 0
+    timer_data.paneinfo = PaneInfo(
+        is_first=False,
+        is_last=False,
+        timing=0,
+        index=0,
+        nonspeculative_index=0)
+    timer_data.dynamic_timer_tag = ''
+
+    sampler.start()
+    # Assert that the expected exception is raised
+    with self.assertRaises(RuntimeError):
+      op.process_timer('ts-ts-timer', timer_data=timer_data)
+    sampler.stop()
+    sampler.commit_counters()
+
+    expected_name = CounterName(
+        'process-timers-msecs',
+        step_name='step1',
+        stage_name='test_do_op_exception')
+
+    found_counter = None
+    for counter in counter_factory.get_counters():
+      if counter.name == expected_name:
+        found_counter = counter
+        break
+
+    self.assertIsNotNone(
+        found_counter, f"Expected counter '{expected_name}' to be created.")
+
+    actual_value = found_counter.value()
+    self.assertGreater(
+        actual_value, state_duration_ms * (1.0 - margin_of_error))
+    _LOGGER.info("Exception test finished successfully.")
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to