damccorm commented on code in PR #29106:
URL: https://github.com/apache/beam/pull/29106#discussion_r1373707395
##########
sdks/python/apache_beam/transforms/partitioners.py:
##########
@@ -0,0 +1,111 @@
+import uuid
+import apache_beam as beam
+from apache_beam import pvalue
+from typing import Optional, TypeVar
+from typing import Tuple
+from typing import Any
+from typing import Callable
+
+T = TypeVar('T')
+
+
+class Top(beam.PTransform):
+ """
+ A PTransform that takes a PCollection and partitions it into two
+ PCollections. The first PCollection contains the largest n elements of the
+ input PCollection, and the second PCollection contains the remaining
+ elements of the input PCollection.
+
+ Parameters:
+ n: The number of elements to take from the input PCollection.
+ key: A function that takes an element of the input PCollection and
+ returns a value to compare for the purpose of determining the top n
+ elements, similar to Python's built-in sorted function.
+ reverse: If True, the top n elements will be the n smallest elements of
+ the input PCollection.
+
+ Example usage:
+
+ >>> with beam.Pipeline() as p:
+ ... top, remaining = (p
+ ... | beam.Create(list(range(10)))
+ ... | partitioners.Top(3))
+ ... # top will contain [7, 8, 9]
+ ... # remaining will contain [0, 1, 2, 3, 4, 5, 6]
+
+ .. note::
+
+ This transform requires that the top PCollection fit into memory.
+
+ """
+ def __init__(
+ self, n: int, key: Optional[Callable[[Any], Any]] = None, reverse=False):
+ _validate_nonzero_positive_int(n)
+ self.n = n
+ self.key = key
+ self.reverse = reverse
+
+ def expand(self,
+ pcoll) -> Tuple[pvalue.PCollection[T], pvalue.PCollection[T]]:
+ # **Illustrative Example:**
+ # Our goal is to return two pcollections, top and
+ # remaining.
+
+ # Suppose you want to take the top element from `[1, 2, 2]`. Since we have
+ # identical elements, we need to be able to uniquely identify each one,
+ # so we assign a unique ID to each:
+ # `inputs_with_ids: [(1, "A"), (2, "B"), (2, "C")]`
+
+ # Then we sample, e.g.
+ # ``` sample: [(2, "B")] ```
+ # To get our goal `top` pcollection, we just strip the uuids from
+ # that sample.
+
+ # Now to get the `remaining` pcollection, we need to return essentially
+ # `inputs_with_ids` but without any of the elements from the sample. To
+ # do this, we create a set from `sample`, getting `sample_ids:
+ # [set("B")]`. Now that we have this set, we can create our
+ # `remaining_with_ids` pcollection by just filtering out
+ # `inputs_with_ids` and checking for each element "Does this element's
+ # corresponding ID exist in `sample_ids`?"
+
+ # Finally, we just return `top` and strip the IDs as we no longer
+ # need them and the user doesn't care about them.
+ wrapped_key = lambda elem: self.key(elem[0]) if self.key else elem[0]
+ inputs_with_ids = (pcoll | beam.Map(_add_uuid))
+ sample = (
+ inputs_with_ids
+ | beam.combiners.Top.Of(self.n, key=wrapped_key, reverse=self.reverse))
+ sample_ids = (
+ sample
+ | beam.Map(lambda sample_list: set(ele[1] for ele in sample_list)))
+
+ def elem_is_not_sampled(elem, sampled_set):
+ return elem[1] not in sampled_set
+
+ remaining = (
Review Comment:
It does, though in exchange you're getting the ability to start processing
your data in the rest of your pipeline before you reach the end of the window
(which probably saves time/resources in most use cases, especially in
streaming). I think this is a fair concern though, especially in batch. One
optimization that could greatly reduce the bottleneck when Top is small would
be to add a pre-step that filters out elements which cannot be in Top and
sends them straight downstream, where a `Flatten` joins them with the
remaining non-Top elements (something like
[_TopPerBundle](https://github.com/apache/beam/blob/206042eb3e9529a92c13d1a050ffd89f307f6e7b/sdks/python/apache_beam/transforms/combiners.py#L327)).
Cases where Top is large will still run into a similar bottleneck even if we
use a combiner, since we won't be able to do much combiner lifting (local
reduction before sending data over the wire).
I'll also note that state is per key/window, so we're really only talking
about the set of inputs for a single window, and you will still get
parallelization across multiple windows.
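
To make the pre-filter idea concrete, here is a rough pure-Python sketch. It is not Beam code and not what `_TopPerBundle` literally does; `split_bundle` and `top_and_remaining` are hypothetical names, each inner list stands in for a bundle, and the list concatenation at the end stands in for `Flatten`. It also assumes plain comparable elements with no `key` or `reverse`.

```python
import heapq
from typing import Iterable, List, Tuple


def split_bundle(bundle: Iterable[int], n: int) -> Tuple[List[int], List[int]]:
    """Split one bundle into Top candidates and definite non-Top elements."""
    heap: List[int] = []      # min-heap holding the n largest elements seen so far
    rejected: List[int] = []  # elements that can no longer be in the global top n
    for elem in bundle:
        if len(heap) < n:
            heapq.heappush(heap, elem)
        else:
            # heappushpop returns the smallest of (heap + elem); n elements in
            # this bundle are at least as large, so it cannot be in the top n.
            rejected.append(heapq.heappushpop(heap, elem))
    return heap, rejected


def top_and_remaining(bundles: Iterable[Iterable[int]],
                      n: int) -> Tuple[List[int], List[int]]:
    """Hypothetical driver: only per-bundle candidates reach the global step."""
    candidates: List[int] = []
    early_remaining: List[int] = []  # sent downstream without waiting for Top
    for bundle in bundles:
        kept, rejected = split_bundle(bundle, n)
        candidates.extend(kept)
        early_remaining.extend(rejected)

    top = heapq.nlargest(n, candidates)

    # Multiset subtraction so duplicate values are removed exactly once each.
    late_remaining = list(candidates)
    for elem in top:
        late_remaining.remove(elem)

    # Stand-in for Flatten: merge early and late non-Top elements.
    return top, early_remaining + late_remaining


top, remaining = top_and_remaining([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], 3)
print(sorted(top), sorted(remaining))
```

With small `n`, at most `n` candidates per bundle wait on the global Top computation; everything in `early_remaining` bypasses it. With large `n`, almost every element stays a candidate, which is the bottleneck case described above.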