TheNeuralBit commented on a change in pull request #12469:
URL: https://github.com/apache/beam/pull/12469#discussion_r469410134



##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -35,20 +35,34 @@ def __array__(self, dtype=None):
   transform = frame_base._elementwise_method(
       'transform', restrictions={'axis': 0})
 
-  def agg(self, *args, **kwargs):
-    return frame_base.DeferredFrame.wrap(
-        expressions.ComputedExpression(
-            'agg',
-            lambda df: df.agg(*args, **kwargs), [self._expr],
-            preserves_partition_by=partitionings.Singleton(),
-            requires_partition_by=partitionings.Singleton()))
-
-  all = frame_base._associative_agg_method('all')
-  any = frame_base._associative_agg_method('any')
-  min = frame_base._associative_agg_method('min')
-  max = frame_base._associative_agg_method('max')
-  prod = product = frame_base._associative_agg_method('prod')
-  sum = frame_base._associative_agg_method('sum')
+  def agg(self, func, axis=0, *args, **kwargs):
+    if isinstance(func, list) and len(func) > 1:
+      rows = [self.agg([f], *args, **kwargs) for f in func]
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+              lambda *rows: pd.concat(rows), [row._expr for row in rows]))
+    else:
+      base_func = func[0] if isinstance(func, list) else func
+      if _is_associative(base_func) and not args and not kwargs:
+        intermediate = expressions.elementwise_expression(
+            'pre_agg',
+            lambda s: s.agg([base_func], *args, **kwargs), [self._expr])
+      else:
+        intermediate = self._expr
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'agg',
+              lambda s: s.agg(func, *args, **kwargs), [intermediate],
+              preserves_partition_by=partitionings.Singleton(),
+              requires_partition_by=partitionings.Singleton()))

Review comment:
       Here as well

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -244,6 +261,9 @@ def expr_to_stages(expr):
             # It also must be declared as an output of the producing stage.
             expr_to_stage(arg).outputs.add(arg)
       stage.ops.append(expr)
+      for arg in expr.args():
+        if arg in inputs:
+          stage.inputs.add(arg)

Review comment:
       ```suggestion
         # Ensure that any inputs for the overall transform are added in downstream stages
         for arg in expr.args():
           if arg in inputs:
             stage.inputs.add(arg)
   ```

##########
File path: sdks/python/apache_beam/dataframe/frames_test.py
##########
@@ -80,6 +81,24 @@ def test_loc(self):
     self._run_test(lambda df: df.loc[df.A > 10], df)
     self._run_test(lambda df: df.loc[lambda df: df.A > 10], df)
 
+  def test_series_agg(self):
+    s = pd.Series(list(range(16)))
+    self._run_test(lambda s: s.agg('sum'), s)
+    self._run_test(lambda s: s.agg(['sum']), s)
+    self._run_test(lambda s: s.agg(['sum', 'mean']), s)
+    self._run_test(lambda s: s.agg(['mean']), s)
+    self._run_test(lambda s: s.agg('mean'), s)
+
+  @unittest.skipIf(sys.version_info < (3, 6), 'Nondeterministic dict ordering.')

Review comment:
       Would it be reasonable to reorder the columns by name when asserting equality?

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -150,35 +164,79 @@ def at(self, *args, **kwargs):
   def loc(self):
     return _DeferredLoc(self)
 
-  @frame_base.args_to_kwargs(pd.DataFrame)
-  @frame_base.populate_defaults(pd.DataFrame)
-  def aggregate(self, axis, **kwargs):
+  def aggregate(self, func, axis=0, *args, **kwargs):
     if axis is None:
-      return self.agg(axis=1, **kwargs).agg(axis=0, **kwargs)
-    return frame_base.DeferredFrame.wrap(
+      return self.agg(func, *args, **dict(kwargs, axis=1)).agg(
+          func, *args, **dict(kwargs, axis=0))
+    elif axis in (1, 'columns'):
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'aggregate',
+              lambda df: df.agg(func, axis=1, *args, **kwargs),
+              [self._expr],
+              requires_partition_by=partitionings.Nothing()))
+    elif len(self._expr.proxy().columns) == 0 or args or kwargs:
+      return frame_base.DeferredFrame.wrap(
         expressions.ComputedExpression(
             'aggregate',
-            lambda df: df.agg(axis=axis, **kwargs),
+            lambda df: df.agg(func, *args, **kwargs),
             [self._expr],
-            # TODO(robertwb): Sub-aggregate when possible.
             requires_partition_by=partitionings.Singleton()))
+    else:
+      if not isinstance(func, dict):
+        col_names = list(self._expr.proxy().columns)
+        func = {col: func for col in col_names}
+      else:
+        col_names = list(func.keys())
+      aggregated_cols = []
+      for col in col_names:
+        funcs = func[col]
+        if not isinstance(funcs, list):
+          funcs = [funcs]
+        aggregated_cols.append(self[col].agg(funcs, *args, **kwargs))
+      if any(isinstance(funcs, list) for funcs in func.values()):
+        return frame_base.DeferredFrame.wrap(
+            expressions.ComputedExpression(
+                'join_aggregate',
+                lambda *cols: pd.DataFrame(
+                    {col: value for col, value in zip(col_names, cols)}),
+                [col._expr for col in aggregated_cols]))
+      else:
+        return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+                lambda *cols: pd.Series(
+                    {col: value[0] for col, value in zip(col_names, cols)}),
+              [col._expr for col in aggregated_cols],
+              proxy=self._expr.proxy().agg(func, *args, **kwargs)))

Review comment:
       Could you add some comments describing the case each `if` branch is handling? I had a hard time making sense of them all.

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -150,35 +164,79 @@ def at(self, *args, **kwargs):
   def loc(self):
     return _DeferredLoc(self)
 
-  @frame_base.args_to_kwargs(pd.DataFrame)
-  @frame_base.populate_defaults(pd.DataFrame)
-  def aggregate(self, axis, **kwargs):
+  def aggregate(self, func, axis=0, *args, **kwargs):
     if axis is None:
-      return self.agg(axis=1, **kwargs).agg(axis=0, **kwargs)
-    return frame_base.DeferredFrame.wrap(
+      return self.agg(func, *args, **dict(kwargs, axis=1)).agg(
+          func, *args, **dict(kwargs, axis=0))
+    elif axis in (1, 'columns'):
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'aggregate',
+              lambda df: df.agg(func, axis=1, *args, **kwargs),
+              [self._expr],
+              requires_partition_by=partitionings.Nothing()))
+    elif len(self._expr.proxy().columns) == 0 or args or kwargs:
+      return frame_base.DeferredFrame.wrap(
         expressions.ComputedExpression(
             'aggregate',
-            lambda df: df.agg(axis=axis, **kwargs),
+            lambda df: df.agg(func, *args, **kwargs),
             [self._expr],
-            # TODO(robertwb): Sub-aggregate when possible.
             requires_partition_by=partitionings.Singleton()))
+    else:
+      if not isinstance(func, dict):
+        col_names = list(self._expr.proxy().columns)
+        func = {col: func for col in col_names}
+      else:
+        col_names = list(func.keys())
+      aggregated_cols = []
+      for col in col_names:
+        funcs = func[col]
+        if not isinstance(funcs, list):
+          funcs = [funcs]
+        aggregated_cols.append(self[col].agg(funcs, *args, **kwargs))
+      if any(isinstance(funcs, list) for funcs in func.values()):
+        return frame_base.DeferredFrame.wrap(
+            expressions.ComputedExpression(
+                'join_aggregate',
+                lambda *cols: pd.DataFrame(
+                    {col: value for col, value in zip(col_names, cols)}),
+                [col._expr for col in aggregated_cols]))
+      else:
+        return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+                lambda *cols: pd.Series(
+                    {col: value[0] for col, value in zip(col_names, cols)}),
+              [col._expr for col in aggregated_cols],
+              proxy=self._expr.proxy().agg(func, *args, **kwargs)))
 
   agg = aggregate

Review comment:
       I think we're missing the corresponding alias in `Series` (i.e. `aggregate = agg`).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to