iindyk commented on a change in pull request #13175:
URL: https://github.com/apache/beam/pull/13175#discussion_r570716102



##########
File path: sdks/python/apache_beam/transforms/stats.py
##########
@@ -368,82 +383,126 @@ class PerKey(PTransform):
       weighted: (optional) if set to True, the transform returns weighted
         quantiles. The input PCollection is then expected to contain tuples of
         input values with the corresponding weight.
+      batch_input: (optional) if set to True, the transform expects each element
+        of input PCollection to be a batch. Provides a way to accumulate
+        multiple elements at a time more efficiently.
     """
-    def __init__(self, num_quantiles, key=None, reverse=False, weighted=False):
+    def __init__(
+        self,
+        num_quantiles,
+        key=None,
+        reverse=False,
+        weighted=False,
+        batch_input=False):
       self._num_quantiles = num_quantiles
       self._key = key
       self._reverse = reverse
       self._weighted = weighted
+      self._batch_input = batch_input
 
     def expand(self, pcoll):
       return pcoll | CombinePerKey(
           ApproximateQuantilesCombineFn.create(
               num_quantiles=self._num_quantiles,
               key=self._key,
               reverse=self._reverse,
-              weighted=self._weighted))
+              weighted=self._weighted,
+              batch_input=self._batch_input))
 
     def display_data(self):
       return ApproximateQuantiles._display_data(
           num_quantiles=self._num_quantiles,
           key=self._key,
           reverse=self._reverse,
-          weighted=self._weighted)
+          weighted=self._weighted,
+          batch_input=self._batch_input)
+
+
+class _QuantileSpec(object):
+  """Quantiles computation specifications."""
+  def __init__(self, buffer_size, num_buffers, weighted, key, reverse):
+    # type: (int, int, bool, Any, bool) -> None
+    self.buffer_size = buffer_size
+    self.num_buffers = num_buffers
+    self.weighted = weighted
+    self.key = key
+    self.reverse = reverse
+
+    # Used to sort tuples of values and weights.
+    self.weighted_key = None if key is None else (lambda x: key(x[0]))
+
+    # Used to compare values.
+    if reverse and key is None:
+      self.less_than = lambda a, b: a > b
+    elif reverse:
+      self.less_than = lambda a, b: key(a) > key(b)
+    elif key is None:
+      self.less_than = lambda a, b: a < b
+    else:
+      self.less_than = lambda a, b: key(a) < key(b)
+
+  def get_argsort_key(self, elements):
+    # type: (List) -> Any
+
+    """Returns a key for sorting indices of elements by element's value."""
+    if self.key is None:
+      return elements.__getitem__
+    else:
+      return lambda idx: self.key(elements[idx])
+
+  def __reduce__(self):
+    return (
+        self.__class__,
+        (
+            self.buffer_size,
+            self.num_buffers,
+            self.weighted,
+            self.key,
+            self.reverse))
 
 
-class _QuantileBuffer(Generic[T]):
+class _QuantileBuffer(object):
   """A single buffer in the sense of the referenced algorithm.
   (see http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.6513&rep=rep1
   &type=pdf and ApproximateQuantilesCombineFn for further information)"""
-  def __init__(self, elements, weighted, level=0, weight=1):
-    # type: (Sequence[T], bool, int, int) -> None
-    # In case of weighted quantiles, elements are tuples of values and weights.
+  def __init__(
+      self, elements, weights, weighted, level=0, min_val=None, max_val=None):

Review comment:
       Done.

##########
File path: sdks/python/apache_beam/transforms/stats.py
##########
@@ -523,29 +805,25 @@ def __init__(
       num_buffers,  # type: int
       key=None,
       reverse=False,
-      weighted=False):
-    def _comparator(a, b):
-      if key:
-        a, b = key(a), key(b)
-
-      retval = int(a > b) - int(a < b)
-
-      if reverse:
-        return -retval
-
-      return retval
-
-    self._comparator = _comparator
-
+      weighted=False,
+      batch_input=False):
     self._num_quantiles = num_quantiles
-    self._buffer_size = buffer_size
-    self._num_buffers = num_buffers
-    if weighted:
-      self._key = (lambda x: x[0]) if key is None else (lambda x: key(x[0]))
-    else:
-      self._key = key
-    self._reverse = reverse
-    self._weighted = weighted
+    self._spec = _QuantileSpec(buffer_size, num_buffers, weighted, key, reverse)
+    self._batch_input = batch_input
+    if self._batch_input:
+      setattr(self, 'add_input', self._add_inputs)

Review comment:
       Done.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to