[ 
https://issues.apache.org/jira/browse/BEAM-14294?focusedWorklogId=765226&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-765226
 ]

ASF GitHub Bot logged work on BEAM-14294:
-----------------------------------------

                Author: ASF GitHub Bot
            Created on: 03/May/22 00:12
            Start Date: 03/May/22 00:12
    Worklog Time Spent: 10m 
      Work Description: robertwb commented on code in PR #17384:
URL: https://github.com/apache/beam/pull/17384#discussion_r863270883


##########
sdks/python/apache_beam/runners/worker/operations.py:
##########
@@ -223,6 +235,120 @@ def current_element_progress(self):
     return self.consumer.current_element_progress()
 
 
+class GeneralPurposeConsumerSet(ConsumerSet):
+  """ConsumerSet implementation that handles all combinations of possible 
edges.
+  """
+  def __init__(self,
+               counter_factory,
+               step_name,  # type: str
+               output_index,
+               coder,
+               producer_type_hints,
+               consumers,  # type: List[Operation]
+               producer_batch_converter):
+    super().__init__(
+        counter_factory,
+        step_name,
+        output_index,
+        consumers,
+        coder,
+        producer_type_hints)
+
+    self.producer_batch_converter = producer_batch_converter
+
+    # Partition consumers into three groups:
+    # - consumers that will be passed elements
+    # - consumers that will be passed batches (where their input batch type
+    #   matches the output of the producer)
+    # - consumers that will be passed converted batches
+    self.element_consumers: List[Operation] = []
+    self.passthrough_batch_consumers: List[Operation] = []
+    other_batch_consumers: DefaultDict[
+        BatchConverter, List[Operation]] = collections.defaultdict(lambda: [])
+
+    for consumer in consumers:
+      if not consumer.get_batching_preference().supports_batches:
+        self.element_consumers.append(consumer)
+      elif (consumer.get_input_batch_converter() ==
+            self.producer_batch_converter):
+        self.passthrough_batch_consumers.append(consumer)
+      else:
+        # Batch consumer with a mismatched batch type
+        if consumer.get_batching_preference().supports_elements:
+          # Pass it elements if we can
+          self.element_consumers.append(consumer)
+        else:
+          # As a last resort, explode and rebatch
+          consumer_batch_converter = consumer.get_input_batch_converter()
+          # This consumer supports batches, it must have a batch converter
+          assert consumer_batch_converter is not None
+          other_batch_consumers[consumer_batch_converter].append(consumer)
+
+    self.other_batch_consumers: Dict[BatchConverter, List[Operation]] = dict(
+        other_batch_consumers)
+
+    self.has_batch_consumers = (
+        self.passthrough_batch_consumers or self.other_batch_consumers)
+    self._batched_elements: List[Any] = []
+
+  def receive(self, windowed_value):
+    # type: (WindowedValue) -> None
+    self.update_counters_start(windowed_value)
+
+    for consumer in self.element_consumers:
+      cython.cast(Operation, consumer).process(windowed_value)
+
+    # TODO: Do this branching when contstructing ConsumerSet
+    if self.has_batch_consumers:
+      self._batched_elements.append(windowed_value)
+
+    self.update_counters_finish()
+
+  def receive_batch(self, windowed_batch):
+    #self.update_counters_start(windowed_value)
+    if self.element_consumers:
+      for wv in windowed_batch.as_windowed_values(
+          self.producer_batch_converter.explode_batch):
+        for consumer in self.element_consumers:
+          cython.cast(Operation, consumer).process(wv)
+
+    for consumer in self.passthrough_batch_consumers:
+      cython.cast(Operation, consumer).process_batch(windowed_batch)
+
+    for (consumer_batch_converter,
+         consumers) in self.other_batch_consumers.items():
+      # Explode and rebatch into the new batch type (ouch!)

Review Comment:
   No, it will be logged once per instance, which could be a lot (hundreds of 
times in streaming). 
   
   Ideally we could just name the two operations involved. I wouldn't use 
log(error); log(warning) is enough. It could be actionable if someone expects 
the batch converters to line up and they don't. 



##########
sdks/python/apache_beam/utils/windowed_value.py:
##########
@@ -279,6 +293,208 @@ def create(value, timestamp_micros, windows, 
pane_info=PANE_INFO_UNKNOWN):
   return wv
 
 
+class BatchingMode(Enum):
+  CONCRETE = 1
+  HOMOGENEOUS = 2
+
+
+class WindowedBatch(object):
+  """A batch of N windowed values, each having a value, a timestamp and set of
+  windows."""
+  def with_values(self, new_values):
+    # type: (Any) -> WindowedBatch
+
+    """Creates a new WindowedBatch with the same timestamps and windows as 
this.
+
+    This is the fasted way to create a new WindowedValue.
+    """
+    raise NotImplementedError
+
+  def as_windowed_values(self, explode_fn: Callable) -> 
Iterable[WindowedValue]:
+    raise NotImplementedError
+
+  @staticmethod
+  def from_windowed_values(
+      windowed_values: Sequence[WindowedValue],
+      *,
+      produce_fn: Callable,
+      mode: BatchingMode = BatchingMode.CONCRETE) -> Iterable['WindowedBatch']:
+    if mode == BatchingMode.HOMOGENEOUS:

Review Comment:
   +1



##########
sdks/python/apache_beam/runners/common.py:
##########
@@ -1361,10 +1580,80 @@ def process_outputs(
         self.main_receivers.receive(windowed_value)
       else:
         self.tagged_receivers[tag].receive(windowed_value)
+
+    # TODO(BEAM-3937): Remove if block after output counter released.
+    # Only enable per_element_output_counter when counter cythonized
+    if self.per_element_output_counter is not None:
+      self.per_element_output_counter.add_input(output_element_count)
+
+  def process_batch_outputs(
+      self, windowed_input_batch, results, watermark_estimator=None):
+    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> 
None
+
+    """Dispatch the result of process computation to the appropriate receivers.
+
+    A value wrapped in a TaggedOutput object will be unwrapped and
+    then dispatched to the appropriate indexed output.
+    """
+    if results is None:
+      # TODO(BEAM-3937): Remove if block after output counter released.
+      # Only enable per_element_output_counter when counter cythonized.
+      if self.per_element_output_counter is not None:
+        self.per_element_output_counter.add_input(0)
+      return
+
+    # TODO(BEAM-10782): Verify that the results object is a valid iterable type
+    #  if performance_runtime_type_check is active, without harming performance
+
+    output_element_count = 0
+    for result in results:
+      # results here may be a generator, which cannot call len on it.
+      output_element_count += 1
+      tag = None
+      if isinstance(result, TaggedOutput):
+        tag = result.tag
+        if not isinstance(tag, str):
+          raise TypeError('In %s, tag %s is not a string' % (self, tag))
+        result = result.value
+      if isinstance(result, (WindowedValue, TimestampedValue)):
+        raise TypeError(
+            f"Received {type(result).__name__} from DoFn that was "
+            "expected to produce a batch.")
+      if isinstance(result, WindowedBatch):
+        if isinstance(result, ConcreteWindowedBatch):
+          # TODO: Rebatch into homogenous batches (or remove
+          # ConcreteWindowedBatch)
+          raise NotImplementedError
+        elif isinstance(result, HomogeneousWindowedBatch):
+          windowed_batch = result
+        else:
+          raise AssertionError(
+              "Unrecognized WindowedBatch implementation: "
+              f"{type(windowed_batch)}")
+
+        if (windowed_input_batch is not None and
+            len(windowed_input_batch.windows) != 1):
+          windowed_batch.windows *= len(windowed_input_batch.windows)
+      # TODO(BEAM-14292): Add TimestampedBatch, an analogue for 
TimestampedValue

Review Comment:
   Maybe. I'd actually rather deprecate emitting TimestampedValues and have a 
separate AssignTimestamps operation. 



##########
sdks/python/apache_beam/runners/common.py:
##########
@@ -1200,6 +1404,13 @@ def process(self, windowed_value):
       self._reraise_augmented(exn)
       return []
 
+  def process_batch(self, windowed_batch):

Review Comment:
   Here (or elsewhere) it seems we should document that this will only be 
called with the kind of batch that this DoFn supports, right?



##########
sdks/python/apache_beam/utils/windowed_value.py:
##########
@@ -279,6 +293,208 @@ def create(value, timestamp_micros, windows, 
pane_info=PANE_INFO_UNKNOWN):
   return wv
 
 
+class BatchingMode(Enum):
+  CONCRETE = 1
+  HOMOGENEOUS = 2
+
+
+class WindowedBatch(object):
+  """A batch of N windowed values, each having a value, a timestamp and set of
+  windows."""
+  def with_values(self, new_values):
+    # type: (Any) -> WindowedBatch
+
+    """Creates a new WindowedBatch with the same timestamps and windows as 
this.
+
+    This is the fasted way to create a new WindowedValue.
+    """
+    raise NotImplementedError
+
+  def as_windowed_values(self, explode_fn: Callable) -> 
Iterable[WindowedValue]:
+    raise NotImplementedError
+
+  @staticmethod
+  def from_windowed_values(
+      windowed_values: Sequence[WindowedValue],
+      *,
+      produce_fn: Callable,
+      mode: BatchingMode = BatchingMode.CONCRETE) -> Iterable['WindowedBatch']:
+    if mode == BatchingMode.HOMOGENEOUS:
+      import collections
+      grouped = collections.defaultdict(lambda: [])
+      for wv in windowed_values:
+        grouped[(wv.timestamp, tuple(wv.windows),
+                 wv.pane_info)].append(wv.value)
+
+      for key, values in grouped.items():
+        timestamp, windows, pane_info = key
+        yield HomogeneousWindowedBatch.of(
+            produce_fn(values), timestamp, windows, pane_info)
+    elif mode == BatchingMode.CONCRETE:
+      yield ConcreteWindowedBatch(
+          produce_fn([wv.value for wv in windowed_values]),
+          [wv.timestamp
+           for wv in windowed_values], [wv.windows for wv in windowed_values],
+          [wv.pane_info for wv in windowed_values])
+    else:
+      raise AssertionError(
+          "Unrecognized BatchingMode in "
+          f"WindowedBatch.from_windowed_values: {mode!r}")
+
+
+class HomogeneousWindowedBatch(WindowedBatch):
+  """A WindowedBatch with Homogeneous event-time information, represented
+  internally as a WindowedValue.
+  """
+  def __init__(self, wv):
+    self._wv = wv
+
+  @staticmethod
+  def of(values, timestamp, windows, pane_info=PANE_INFO_UNKNOWN):
+    return HomogeneousWindowedBatch(
+        WindowedValue(values, timestamp, windows, pane_info))
+
+  @property
+  def values(self):
+    return self._wv.value
+
+  @property
+  def timestamp(self):

Review Comment:
   Ah, yes. 



##########
sdks/python/apache_beam/runners/common.py:
##########
@@ -1361,10 +1580,80 @@ def process_outputs(
         self.main_receivers.receive(windowed_value)
       else:
         self.tagged_receivers[tag].receive(windowed_value)
+
+    # TODO(BEAM-3937): Remove if block after output counter released.
+    # Only enable per_element_output_counter when counter cythonized
+    if self.per_element_output_counter is not None:
+      self.per_element_output_counter.add_input(output_element_count)
+
+  def process_batch_outputs(
+      self, windowed_input_batch, results, watermark_estimator=None):
+    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> 
None
+
+    """Dispatch the result of process computation to the appropriate receivers.
+
+    A value wrapped in a TaggedOutput object will be unwrapped and
+    then dispatched to the appropriate indexed output.
+    """
+    if results is None:
+      # TODO(BEAM-3937): Remove if block after output counter released.
+      # Only enable per_element_output_counter when counter cythonized.
+      if self.per_element_output_counter is not None:
+        self.per_element_output_counter.add_input(0)
+      return
+
+    # TODO(BEAM-10782): Verify that the results object is a valid iterable type
+    #  if performance_runtime_type_check is active, without harming performance
+
+    output_element_count = 0
+    for result in results:
+      # results here may be a generator, which cannot call len on it.
+      output_element_count += 1
+      tag = None
+      if isinstance(result, TaggedOutput):
+        tag = result.tag
+        if not isinstance(tag, str):
+          raise TypeError('In %s, tag %s is not a string' % (self, tag))
+        result = result.value
+      if isinstance(result, (WindowedValue, TimestampedValue)):
+        raise TypeError(
+            f"Received {type(result).__name__} from DoFn that was "
+            "expected to produce a batch.")
+      if isinstance(result, WindowedBatch):
+        if isinstance(result, ConcreteWindowedBatch):
+          # TODO: Rebatch into homogenous batches (or remove
+          # ConcreteWindowedBatch)
+          raise NotImplementedError
+        elif isinstance(result, HomogeneousWindowedBatch):
+          windowed_batch = result
+        else:
+          raise AssertionError(
+              "Unrecognized WindowedBatch implementation: "
+              f"{type(windowed_batch)}")
+
+        if (windowed_input_batch is not None and
+            len(windowed_input_batch.windows) != 1):
+          windowed_batch.windows *= len(windowed_input_batch.windows)
+      # TODO(BEAM-14292): Add TimestampedBatch, an analogue for 
TimestampedValue
+      # and handle it here (see TimestampedValue logic in process_outputs).
+      else:
+        # TODO: This should error unless the DoFn was defined with
+        # @DoFn.yields_batches(output_aligned_with_input=True)
+        # We should consider also validating that the length is the same as
+        # windowed_input_batch
+        windowed_batch = windowed_input_batch.with_values(result)

Review Comment:
   This is safe if the input is a homogeneous batch. Otherwise, let's throw an 
exception for now. 



##########
sdks/python/apache_beam/runners/common.py:
##########
@@ -1361,10 +1580,80 @@ def process_outputs(
         self.main_receivers.receive(windowed_value)
       else:
         self.tagged_receivers[tag].receive(windowed_value)
+
+    # TODO(BEAM-3937): Remove if block after output counter released.
+    # Only enable per_element_output_counter when counter cythonized
+    if self.per_element_output_counter is not None:
+      self.per_element_output_counter.add_input(output_element_count)
+
+  def process_batch_outputs(
+      self, windowed_input_batch, results, watermark_estimator=None):
+    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> 
None
+
+    """Dispatch the result of process computation to the appropriate receivers.
+
+    A value wrapped in a TaggedOutput object will be unwrapped and
+    then dispatched to the appropriate indexed output.
+    """
+    if results is None:
+      # TODO(BEAM-3937): Remove if block after output counter released.
+      # Only enable per_element_output_counter when counter cythonized.
+      if self.per_element_output_counter is not None:
+        self.per_element_output_counter.add_input(0)
+      return
+
+    # TODO(BEAM-10782): Verify that the results object is a valid iterable type
+    #  if performance_runtime_type_check is active, without harming performance
+
+    output_element_count = 0
+    for result in results:
+      # results here may be a generator, which cannot call len on it.
+      output_element_count += 1

Review Comment:
   Increment by the batch size. 



##########
sdks/python/apache_beam/runners/common.py:
##########
@@ -1361,10 +1580,80 @@ def process_outputs(
         self.main_receivers.receive(windowed_value)
       else:
         self.tagged_receivers[tag].receive(windowed_value)
+
+    # TODO(BEAM-3937): Remove if block after output counter released.
+    # Only enable per_element_output_counter when counter cythonized
+    if self.per_element_output_counter is not None:
+      self.per_element_output_counter.add_input(output_element_count)
+
+  def process_batch_outputs(
+      self, windowed_input_batch, results, watermark_estimator=None):
+    # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> 
None
+
+    """Dispatch the result of process computation to the appropriate receivers.
+
+    A value wrapped in a TaggedOutput object will be unwrapped and
+    then dispatched to the appropriate indexed output.
+    """
+    if results is None:
+      # TODO(BEAM-3937): Remove if block after output counter released.
+      # Only enable per_element_output_counter when counter cythonized.
+      if self.per_element_output_counter is not None:
+        self.per_element_output_counter.add_input(0)
+      return
+
+    # TODO(BEAM-10782): Verify that the results object is a valid iterable type
+    #  if performance_runtime_type_check is active, without harming performance
+
+    output_element_count = 0
+    for result in results:
+      # results here may be a generator, which cannot call len on it.
+      output_element_count += 1
+      tag = None
+      if isinstance(result, TaggedOutput):
+        tag = result.tag
+        if not isinstance(tag, str):
+          raise TypeError('In %s, tag %s is not a string' % (self, tag))
+        result = result.value
+      if isinstance(result, (WindowedValue, TimestampedValue)):
+        raise TypeError(
+            f"Received {type(result).__name__} from DoFn that was "
+            "expected to produce a batch.")
+      if isinstance(result, WindowedBatch):
+        if isinstance(result, ConcreteWindowedBatch):
+          # TODO: Rebatch into homogenous batches (or remove
+          # ConcreteWindowedBatch)
+          raise NotImplementedError
+        elif isinstance(result, HomogeneousWindowedBatch):
+          windowed_batch = result
+        else:
+          raise AssertionError(
+              "Unrecognized WindowedBatch implementation: "
+              f"{type(windowed_batch)}")
+
+        if (windowed_input_batch is not None and
+            len(windowed_input_batch.windows) != 1):
+          windowed_batch.windows *= len(windowed_input_batch.windows)

Review Comment:
   Oh, here's where you use window assignment. I'm not sure this is safe...





Issue Time Tracking
-------------------

    Worklog Id:     (was: 765226)
    Time Spent: 2h 40m  (was: 2.5h)

> MVP for SDK worker changes to support process_batch
> ---------------------------------------------------
>
>                 Key: BEAM-14294
>                 URL: https://issues.apache.org/jira/browse/BEAM-14294
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Brian Hulette
>            Assignee: Brian Hulette
>            Priority: P2
>          Time Spent: 2h 40m
>  Remaining Estimate: 0h
>
> The initial MVP may only work in some restricted circumstances (e.g. 
> @yields_element on process_batch, or batch-to-batch without a 1:1 
> input:output mapping, might not be supported). Such unsupported cases should fail early. 



--
This message was sent by Atlassian Jira
(v8.20.7#820007)

Reply via email to