damccorm commented on code in PR #29175:
URL: https://github.com/apache/beam/pull/29175#discussion_r1378198383


##########
sdks/python/apache_beam/transforms/util.py:
##########
@@ -646,6 +647,113 @@ def finish_bundle(self):
     self._target_batch_size = self._batch_size_estimator.next_batch_size()
 
 
+def _pardo_stateful_batch_elements(
+    input_coder: coders.Coder,
+    batch_size_estimator: _BatchSizeEstimator,
+    max_buffering_duration_secs: int,
+    clock=time.time):
+  ELEMENT_STATE = BagStateSpec('values', input_coder)
+  COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
+  BATCH_SIZE_STATE = ReadModifyWriteStateSpec('batch_size', input_coder)
+  WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
+  BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)
+  BATCH_ESTIMATOR_STATE = ReadModifyWriteStateSpec(
+      'batch_estimator', coders.PickleCoder())
+
+  class _StatefulBatchElementsDoFn(DoFn):
+    def process(
+        self,
+        element,
+        window=DoFn.WindowParam,
+        element_state=DoFn.StateParam(ELEMENT_STATE),
+        count_state=DoFn.StateParam(COUNT_STATE),
+        batch_size_state=DoFn.StateParam(BATCH_SIZE_STATE),
+        batch_estimator_state=DoFn.StateParam(BATCH_ESTIMATOR_STATE),
+        window_timer=DoFn.TimerParam(WINDOW_TIMER),
+        buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
+      state_estimator = batch_estimator_state.read()
+      if state_estimator is not None:
+        batch_estimator = state_estimator
+      else:
+        # Should only happen on the first element
+        batch_estimator = batch_size_estimator

Review Comment:
   Can we move this read block into the `if target_size is None` block? That 
way we only do the state read when we actually need it.
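
   A rough sketch of the suggested shape, in plain Python so it runs without 
Beam. It assumes `target_size` comes from reading the batch-size state, which 
the quoted hunk does not show. `FakeState`, `FixedEstimator`, and 
`resolve_target_size` are hypothetical stand-ins, not names from the PR.

```python
class FakeState:
  """Hypothetical stand-in for a Beam state cell; counts reads."""
  def __init__(self, value=None):
    self._value = value
    self.reads = 0

  def read(self):
    self.reads += 1
    return self._value

  def write(self, value):
    self._value = value


class FixedEstimator:
  """Hypothetical estimator that always proposes the same batch size."""
  def __init__(self, size):
    self._size = size

  def next_batch_size(self):
    return self._size


def resolve_target_size(
    batch_size_state, batch_estimator_state, default_estimator):
  # Assumption: the target size is cached in its own state cell.
  target_size = batch_size_state.read()
  if target_size is None:
    # Only pay for the estimator state read when no size is cached yet.
    state_estimator = batch_estimator_state.read()
    if state_estimator is not None:
      batch_estimator = state_estimator
    else:
      # Should only happen on the first element.
      batch_estimator = default_estimator
    target_size = batch_estimator.next_batch_size()
    batch_size_state.write(target_size)
  return target_size
```

   With this shape, elements that arrive after the target size has been 
cached skip the estimator read entirely.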



##########
sdks/python/apache_beam/transforms/util.py:
##########
@@ -712,10 +824,20 @@ def __init__(
         clock=clock,
         record_metrics=record_metrics)
     self._element_size_fn = element_size_fn
+    self._max_batch_dur = max_batch_duration_secs
+    self._clock = clock
 
   def expand(self, pcoll):
     if getattr(pcoll.pipeline.runner, 'is_streaming', False):
       raise NotImplementedError("Requires stateful processing (BEAM-2687)")
+    elif self._max_batch_dur is not None:
+      coder = coders.registry.get_coder(pcoll)
+      return pcoll | WithKeys(0) | ParDo(

Review Comment:
   Non-blocking for this PR, but something we may want to consider: rather than 
using a single fixed key, does it make sense to try to have a single key per 
worker somehow? (One way to do this would be to use multi_process_shared.py.)
   
   That way we're still batching per machine in a parallelizable way, but we 
get stateful batching across bundles.
   
   The current implementation is likely still useful for many use cases where 
batching is not the expensive part (e.g. RunInference) or there are few workers.
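
   To make the idea concrete, here is a minimal stdlib-only sketch of a key 
that is generated once per Python process. It is an illustration only: 
`worker_key` and `with_worker_key` are hypothetical helpers, and it does not 
use multi_process_shared.py, so separate SDK processes on one machine would 
each get their own key.

```python
import functools
import uuid


@functools.lru_cache(maxsize=None)
def worker_key() -> str:
  """Hypothetical helper: a random key created once per Python process."""
  return uuid.uuid4().hex


def with_worker_key(elements):
  """Pairs each element with this process's key instead of a fixed 0."""
  key = worker_key()
  for element in elements:
    yield key, element
```

   In the pipeline, something like this would stand in for `WithKeys(0)`, so 
the number of distinct keys roughly tracks the number of worker processes. 
Whether the keyed elements then stay on the worker that assigned the key 
depends on how the runner distributes keys, so that part would need checking.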


