damccorm commented on code in PR #29106:
URL: https://github.com/apache/beam/pull/29106#discussion_r1369282464


##########
sdks/python/apache_beam/transforms/partitioners.py:
##########
@@ -0,0 +1,111 @@
+import uuid
+import apache_beam as beam
+from apache_beam import pvalue
+from typing import Optional, TypeVar
+from typing import Tuple
+from typing import Any
+from typing import Callable
+
+T = TypeVar('T')
+
+
+class Top(beam.PTransform):
+  """
+    A PTransform that takes a PCollection and partitions it into two
+    PCollections.  The first PCollection contains the largest n elements of the
+    input PCollection, and the second PCollection contains the remaining
+    elements of the input PCollection.
+
+    Parameters:
+        n: The number of elements to take from the input PCollection.
+        key: A function that takes an element of the input PCollection and
+            returns a value to compare for the purpose of determining the top n
+            elements, similar to Python's built-in sorted function.
+        reverse: If True, the top n elements will be the n smallest elements of
+            the input PCollection.
+
+    Example usage:
+
+        >>> with beam.Pipeline() as p:
+        ...     top, remaining = (p
+        ...         | beam.Create(list(range(10)))
+        ...         | partitioners.Top(3))
+        ...     # top will contain [7, 8, 9]
+        ...     # remaining will contain [0, 1, 2, 3, 4, 5, 6]
+
+    .. note::
+
+        This transform requires that the top PCollection fit into memory.
+
+    """
+  def __init__(
+      self, n: int, key: Optional[Callable[[Any], Any]] = None, reverse=False):
+    _validate_nonzero_positive_int(n)
+    self.n = n
+    self.key = key
+    self.reverse = reverse
+
+  def expand(self,
+             pcoll) -> Tuple[pvalue.PCollection[T], pvalue.PCollection[T]]:
+    # **Illustrative Example:**
+    # Our goal is to return two pcollections, top and
+    # remaining.
+
+    # Suppose you want to take the top element from `[1, 2, 2]`. Since we have
+    # identical elements, we need to be able to uniquely identify each one,
+    # so we assign a unique ID to each:
+    # `inputs_with_ids: [(1, "A"), (2, "B"), (2, "C")]`
+
+    # Then we sample, e.g.
+    # ``` sample: [(2, "B")] ```
+    # To get our goal `top` pcollection, we just strip the uuids from
+    # that sample.
+
+    # Now to get the `remaining` pcollection, we need to return essentially
+    # `inputs_with_ids` but without any of the elements from the sample. To
+    # do this, we create a set from `sample`, getting `sample_ids:
+    # {"B"}`. Now that we have this set, we can create our
+    # `remaining_with_ids` pcollection by filtering
+    # `inputs_with_ids`, checking for each element "Does this element's
+    # corresponding ID exist in `sample_ids`?"
+
+    # Finally, we strip the IDs from `remaining` and return it alongside `top`,
+    # as we no longer need the IDs and the user doesn't care about them.
+    wrapped_key = lambda elem: self.key(elem[0]) if self.key else elem[0]
+    inputs_with_ids = (pcoll | beam.Map(_add_uuid))
+    sample = (
+        inputs_with_ids
+        | beam.combiners.Top.Of(self.n, key=wrapped_key, reverse=self.reverse))
+    sample_ids = (
+        sample
+        | beam.Map(lambda sample_list: set(ele[1] for ele in sample_list)))
+
+    def elem_is_not_sampled(elem, sampled_set):
+      return elem[1] not in sampled_set
+
+    remaining = (

Review Comment:
   I think we can both simplify this implementation and significantly improve performance using [user state](https://beam.apache.org/documentation/programming-guide/#types-of-state). Basically, the idea would be to use bag state and:
   1) Add a dummy key if the input is unkeyed.
   2) Read state to get our current running set of (at most) N elements.
   3) Scan that state to get the smallest element.
   4) If state holds fewer than N elements, add the current element to it. Otherwise, check whether the current element is larger than the smallest element grabbed from state. If yes, add it to state, remove the smallest element from state, and emit that evicted element into the non-top PCollection. If no, emit the current element into the non-top PCollection.
   5) When processing an element, set an [event time timer](https://beam.apache.org/documentation/programming-guide/#timers) that expires at the end of the window (timers are scoped to a key and window, so they can't be set from StartBundle). In that timer, emit any items buffered in state and clear the state.
   6) Remove the dummy key.
   
   This would have a few advantages:
   1) It wouldn't block execution of later steps for the non-Top PCollection. You could partition your data and wait until the end of the window to process your Top PCollection, but start processing elements in your non-Top PCollection as soon as you've seen enough data to guarantee that they can't be part of the Top PCollection.
   2) It wouldn't require doing a join, which can be somewhat expensive.
   3) You get per-key Top for free since state is per key; you would just skip adding/removing the dummy key.
   
   There are even a few optimizations we could potentially add:
   1) Store the current minimum value in value state (duplicated from the bag state; `ReadModifyWriteStateSpec` in the Python SDK) and only read that initially, since it is a potentially much cheaper read if N gets large.
   2) Cache the current minimum value in the DoFn (per key and window) and route any future entries that are no larger than it to the non-Top PCollection without reading state. The cache can get out of date, but since the minimum of a full Top set only grows as elements arrive, a stale value is still a safe lower bound; we'd just still need to do the comparison against state if our value was larger than the cached min_val.
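   To see why these optimizations are safe, here is a pure-Python model of a single key and window with no Beam APIs involved. A list stands in for the bag state, `min_key` stands in for the duplicated minimum (optimization 1, or equally the in-memory cache of optimization 2), and `partition_with_cached_min` is a hypothetical helper that also counts how often the full bag has to be read:

```python
from typing import Any, Callable, Iterable, List, Optional, Tuple


def partition_with_cached_min(
    elements: Iterable[Any],
    n: int,
    key: Callable[[Any], Any] = lambda x: x,
) -> Tuple[List[Any], List[Any], int]:
  """Models one key/window; returns (top, remaining, bag_reads)."""
  if n < 1:
    raise ValueError('n must be a positive integer')
  bag: List[Any] = []  # Stands in for the bag state.
  min_key: Optional[Any] = None  # Stands in for the duplicated minimum.
  remaining: List[Any] = []
  bag_reads = 0
  for elem in elements:
    k = key(elem)
    if min_key is not None and k <= min_key:
      # Cheap path: only the small minimum was consulted, never the bag.
      remaining.append(elem)
      continue
    bag_reads += 1  # Expensive path: read (and maybe rewrite) the whole bag.
    if len(bag) < n:
      bag.append(elem)
    else:
      smallest = min(bag, key=key)
      bag.remove(smallest)
      bag.append(elem)
      remaining.append(smallest)
    if len(bag) == n:
      # Refresh the minimum; for a full bag it never decreases.
      min_key = min(key(e) for e in bag)
  return bag, remaining, bag_reads
```

   For descending input only the first N elements ever touch the bag, while ascending input is the worst case and reads the bag for every element. The `[1, 2, 2]` example from the PR's comments also comes out right without UUIDs: with `n=1`, `top` is `[2]` and `remaining` is `[1, 2]`.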
   
   Thoughts?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.