Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2020/08/05 00:10:55 UTC

[GitHub] [beam] robertwb opened a new pull request #12469: [BEAM-9547] Lift associative aggregations.

robertwb opened a new pull request #12469:
URL: https://github.com/apache/beam/pull/12469


   Also fix an issue where transform inputs used in downstream stages were not registered as inputs of those stages.
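A minimal sketch of what lifting an associative aggregation means, using plain pandas. It assumes the data is split into a list of in-memory chunks that stand in for a runner's partitions; the chunking is illustrative only and is not how Beam partitions data. An associative function such as `sum` can be applied to each partition first and then applied again to the small set of partial results. Re-applying `mean` to partial means, by contrast, is only correct when every partition has the same size, which is presumably why the PR gates lifting on an associativity check.

```python
import pandas as pd

s = pd.Series(range(16))

# Illustrative, uneven "partitions"; a real runner decides these itself.
partitions = [s.iloc[:4], s.iloc[4:]]

# Phase 1 (pre-aggregation): reduce each partition independently.
partial_sums = pd.concat([p.agg(['sum']) for p in partitions])

# Phase 2: gather the small partial results in one place and reduce again.
lifted_sum = partial_sums.agg('sum')

# Re-applying mean to per-partition means weights each partition equally,
# so it differs from the true mean when partition sizes differ.
mean_of_means = pd.Series([p.mean() for p in partitions]).mean()

print(lifted_sum, s.sum())
print(mean_of_means, s.mean())
```

Only the partial results (one value per partition) need to move to a single worker in phase 2, which is what makes lifting worthwhile for large inputs.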
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [beam] TheNeuralBit commented on a change in pull request #12469: [BEAM-9547] Lift associative aggregations.

Posted by GitBox <gi...@apache.org>.
TheNeuralBit commented on a change in pull request #12469:
URL: https://github.com/apache/beam/pull/12469#discussion_r469410134



##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -35,20 +35,34 @@ def __array__(self, dtype=None):
   transform = frame_base._elementwise_method(
       'transform', restrictions={'axis': 0})
 
-  def agg(self, *args, **kwargs):
-    return frame_base.DeferredFrame.wrap(
-        expressions.ComputedExpression(
-            'agg',
-            lambda df: df.agg(*args, **kwargs), [self._expr],
-            preserves_partition_by=partitionings.Singleton(),
-            requires_partition_by=partitionings.Singleton()))
-
-  all = frame_base._associative_agg_method('all')
-  any = frame_base._associative_agg_method('any')
-  min = frame_base._associative_agg_method('min')
-  max = frame_base._associative_agg_method('max')
-  prod = product = frame_base._associative_agg_method('prod')
-  sum = frame_base._associative_agg_method('sum')
+  def agg(self, func, axis=0, *args, **kwargs):
+    if isinstance(func, list) and len(func) > 1:
+      rows = [self.agg([f], *args, **kwargs) for f in func]
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+              lambda *rows: pd.concat(rows), [row._expr for row in rows]))
+    else:
+      base_func = func[0] if isinstance(func, list) else func
+      if _is_associative(base_func) and not args and not kwargs:
+        intermediate = expressions.elementwise_expression(
+            'pre_agg',
+            lambda s: s.agg([base_func], *args, **kwargs), [self._expr])
+      else:
+        intermediate = self._expr
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'agg',
+              lambda s: s.agg(func, *args, **kwargs), [intermediate],
+              preserves_partition_by=partitionings.Singleton(),
+              requires_partition_by=partitionings.Singleton()))

Review comment:
       Here as well
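The Series `agg` in the hunk above handles a list of several functions by aggregating with each function separately and joining the one-row results with `pd.concat`. A plain-pandas sketch of why that decomposition is sound, assuming an ordinary in-memory `pd.Series` rather than a deferred one:

```python
import pandas as pd

s = pd.Series(range(16))

# Aggregate with each function separately; each result is a one-row
# Series indexed by the function name.
rows = [s.agg([f]) for f in ['sum', 'mean']]

# Joining the rows reproduces the multi-function aggregation.
joined = pd.concat(rows)
direct = s.agg(['sum', 'mean'])

pd.testing.assert_series_equal(joined, direct)
```

Splitting this way lets each single-function aggregation be lifted (or not) independently, depending on whether that particular function is associative.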

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -244,6 +261,9 @@ def expr_to_stages(expr):
             # It also must be declared as an output of the producing stage.
             expr_to_stage(arg).outputs.add(arg)
       stage.ops.append(expr)
+      for arg in expr.args():
+        if arg in inputs:
+          stage.inputs.add(arg)

Review comment:
       ```suggestion
         # Ensure that any inputs for the overall transform are added in downstream stages
         for arg in expr.args():
           if arg in inputs:
             stage.inputs.add(arg)
   ```
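The change in `expr_to_stages` registers any overall transform input that an op reads as an input of the op's own stage, even when that stage is downstream of the one where the input was first consumed. A simplified sketch of that bookkeeping follows; `Expr`, `Stage`, and `add_op` are hypothetical stand-ins, and the real classes in `transforms.py` carry more state (for example, partitioning and outputs).

```python
class Expr:
  """Hypothetical stand-in for an expression node with argument expressions."""

  def __init__(self, name, args=()):
    self.name = name
    self._args = tuple(args)

  def args(self):
    return self._args


class Stage:
  """Hypothetical, simplified fused stage (only inputs and ops are modeled)."""

  def __init__(self):
    self.inputs = set()
    self.ops = []


def add_op(stage, expr, inputs):
  """Appends expr to stage, registering any transform-level inputs it reads."""
  stage.ops.append(expr)
  # Without this loop, an op placed in a downstream stage could read an
  # overall transform input that was never declared as that stage's input.
  for arg in expr.args():
    if arg in inputs:
      stage.inputs.add(arg)


# Usage: the raw input `df` is read both by a pre-aggregation in the first
# stage and by a later op that was placed in a second, downstream stage.
df = Expr('df')
pre_agg = Expr('pre_agg', [df])
combine = Expr('combine', [pre_agg, df])

first, second = Stage(), Stage()
add_op(first, pre_agg, inputs={df})
add_op(second, combine, inputs={df})
```

Intermediate results such as `pre_agg` are not transform inputs, so they are not added here; in the real code they are wired up through the producing stage's outputs instead.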

##########
File path: sdks/python/apache_beam/dataframe/frames_test.py
##########
@@ -80,6 +81,24 @@ def test_loc(self):
     self._run_test(lambda df: df.loc[df.A > 10], df)
     self._run_test(lambda df: df.loc[lambda df: df.A > 10], df)
 
+  def test_series_agg(self):
+    s = pd.Series(list(range(16)))
+    self._run_test(lambda s: s.agg('sum'), s)
+    self._run_test(lambda s: s.agg(['sum']), s)
+    self._run_test(lambda s: s.agg(['sum', 'mean']), s)
+    self._run_test(lambda s: s.agg(['mean']), s)
+    self._run_test(lambda s: s.agg('mean'), s)
+
+  @unittest.skipIf(sys.version_info < (3, 6), 'Nondeterministic dict ordering.')

Review comment:
       Would it be reasonable to re-order the columns by name when asserting equality?
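One way to act on this suggestion, sketched with plain pandas: sort both frames' columns by name before comparing, or pass `check_like=True` to `pd.testing.assert_frame_equal`, which ignores the order of both the index and the columns. This only illustrates the reviewer's idea; it is not the comparison performed by `_run_test`.

```python
import pandas as pd

df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})

# Same data, different column order (e.g. from dict ordering differences).
expected = df[['A', 'B']]
actual = df[['B', 'A']]

# A strict comparison treats column order as significant.
try:
  pd.testing.assert_frame_equal(actual, expected)
  strict_ok = True
except AssertionError:
  strict_ok = False

# Re-ordering columns by name first makes the comparison order-insensitive.
pd.testing.assert_frame_equal(
    actual.sort_index(axis=1), expected.sort_index(axis=1))

# check_like=True ignores the order of both the index and the columns.
pd.testing.assert_frame_equal(actual, expected, check_like=True)
```

The trade-off is that either approach also stops verifying column order, which is part of a DataFrame's observable behavior.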

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -150,35 +164,79 @@ def at(self, *args, **kwargs):
   def loc(self):
     return _DeferredLoc(self)
 
-  @frame_base.args_to_kwargs(pd.DataFrame)
-  @frame_base.populate_defaults(pd.DataFrame)
-  def aggregate(self, axis, **kwargs):
+  def aggregate(self, func, axis=0, *args, **kwargs):
     if axis is None:
-      return self.agg(axis=1, **kwargs).agg(axis=0, **kwargs)
-    return frame_base.DeferredFrame.wrap(
+      return self.agg(func, *args, **dict(kwargs, axis=1)).agg(
+          func, *args, **dict(kwargs, axis=0))
+    elif axis in (1, 'columns'):
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'aggregate',
+              lambda df: df.agg(func, axis=1, *args, **kwargs),
+              [self._expr],
+              requires_partition_by=partitionings.Nothing()))
+    elif len(self._expr.proxy().columns) == 0 or args or kwargs:
+      return frame_base.DeferredFrame.wrap(
         expressions.ComputedExpression(
             'aggregate',
-            lambda df: df.agg(axis=axis, **kwargs),
+            lambda df: df.agg(func, *args, **kwargs),
             [self._expr],
-            # TODO(robertwb): Sub-aggregate when possible.
             requires_partition_by=partitionings.Singleton()))
+    else:
+      if not isinstance(func, dict):
+        col_names = list(self._expr.proxy().columns)
+        func = {col: func for col in col_names}
+      else:
+        col_names = list(func.keys())
+      aggregated_cols = []
+      for col in col_names:
+        funcs = func[col]
+        if not isinstance(funcs, list):
+          funcs = [funcs]
+        aggregated_cols.append(self[col].agg(funcs, *args, **kwargs))
+      if any(isinstance(funcs, list) for funcs in func.values()):
+        return frame_base.DeferredFrame.wrap(
+            expressions.ComputedExpression(
+                'join_aggregate',
+                lambda *cols: pd.DataFrame(
+                    {col: value for col, value in zip(col_names, cols)}),
+                [col._expr for col in aggregated_cols]))
+      else:
+        return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+                lambda *cols: pd.Series(
+                    {col: value[0] for col, value in zip(col_names, cols)}),
+              [col._expr for col in aggregated_cols],
+              proxy=self._expr.proxy().agg(func, *args, **kwargs)))

Review comment:
       Could you add some comments describing the case each `if` branch handles? I had a hard time making sense of them all.
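A plain-pandas sketch of the column-wise branch (the final `else` above), with a comment per case. It assumes an in-memory `pd.DataFrame` and no extra `args`/`kwargs`; the helper name `per_column_agg` is hypothetical. It uses `.iloc[0]` instead of `value[0]` to avoid relying on positional fallback for `[]` on a Series indexed by function name.

```python
import pandas as pd


def per_column_agg(df, func):
  """Hypothetical helper mirroring the column-wise axis=0 branch."""
  # Case 1: a single function (or list) applies to every column, so
  # normalize it to a {column: func} mapping. Case 2: a dict already
  # names the columns, and their order, explicitly.
  if not isinstance(func, dict):
    func = {col: func for col in df.columns}
  col_names = list(func)

  # Aggregate each column on its own, always with a list of functions so
  # every partial result is a Series indexed by function name.
  aggregated = [
      df[col].agg(f if isinstance(f, list) else [f])
      for col, f in func.items()
  ]

  if any(isinstance(f, list) for f in func.values()):
    # Case A: some column requested a list of functions, so the result is
    # a DataFrame with one row per function and one column per input column.
    return pd.DataFrame(dict(zip(col_names, aggregated)))

  # Case B: exactly one function per column, so the result is a Series
  # indexed by column name; .iloc[0] takes each column's single value.
  return pd.Series(
      {col: value.iloc[0] for col, value in zip(col_names, aggregated)})


df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
print(per_column_agg(df, {'A': 'sum', 'B': 'max'}))
```

Because each column is aggregated independently, each per-column aggregation can in turn be lifted when its function is associative.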

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
 
   agg = aggregate

Review comment:
       I think we're missing this alias in `Series`
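A class-body assignment such as `agg = aggregate` simply binds a second name to the same function object. The Series class in the first hunk defines `agg`, so the missing alias there would presumably be `aggregate = agg`. A minimal illustration with a hypothetical class:

```python
class DeferredSeriesSketch:
  """Hypothetical stand-in; only the aliasing pattern is illustrated."""

  def agg(self, func):
    return 'agg({!r})'.format(func)

  # Second name bound to the same function object in the class body.
  aggregate = agg


s = DeferredSeriesSketch()
```

Since both names refer to one function, later changes to `agg` automatically apply to `aggregate` as well.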







[GitHub] [beam] robertwb commented on pull request #12469: [BEAM-9547] Lift associative aggregations.

robertwb commented on pull request #12469:
URL: https://github.com/apache/beam/pull/12469#issuecomment-673626876


   Run Python2_PVR_Flink PreCommit





[GitHub] [beam] robertwb merged pull request #12469: [BEAM-9547] Lift associative aggregations.

robertwb merged pull request #12469:
URL: https://github.com/apache/beam/pull/12469


   





[GitHub] [beam] robertwb commented on a change in pull request #12469: [BEAM-9547] Lift associative aggregations.

robertwb commented on a change in pull request #12469:
URL: https://github.com/apache/beam/pull/12469#discussion_r469616761



##########
File path: sdks/python/apache_beam/dataframe/frames_test.py
##########
@@ -80,6 +81,24 @@ def test_loc(self):
     self._run_test(lambda df: df.loc[df.A > 10], df)
     self._run_test(lambda df: df.loc[lambda df: df.A > 10], df)
 
+  def test_series_agg(self):
+    s = pd.Series(list(range(16)))
+    self._run_test(lambda s: s.agg('sum'), s)
+    self._run_test(lambda s: s.agg(['sum']), s)
+    self._run_test(lambda s: s.agg(['sum', 'mean']), s)
+    self._run_test(lambda s: s.agg(['mean']), s)
+    self._run_test(lambda s: s.agg('mean'), s)
+
+  @unittest.skipIf(sys.version_info < (3, 6), 'Nondeterministic dict ordering.')

Review comment:
       Column ordering seems to be a fairly fundamental property of dataframes that I'd prefer to check in general, and Python 3.5 won't be supported for long.

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -150,35 +164,79 @@ def at(self, *args, **kwargs):
   def loc(self):
     return _DeferredLoc(self)
 
-  @frame_base.args_to_kwargs(pd.DataFrame)
-  @frame_base.populate_defaults(pd.DataFrame)
-  def aggregate(self, axis, **kwargs):
+  def aggregate(self, func, axis=0, *args, **kwargs):
     if axis is None:
-      return self.agg(axis=1, **kwargs).agg(axis=0, **kwargs)
-    return frame_base.DeferredFrame.wrap(
+      return self.agg(func, *args, **dict(kwargs, axis=1)).agg(
+          func, *args, **dict(kwargs, axis=0))
+    elif axis in (1, 'columns'):
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'aggregate',
+              lambda df: df.agg(func, axis=1, *args, **kwargs),
+              [self._expr],
+              requires_partition_by=partitionings.Nothing()))
+    elif len(self._expr.proxy().columns) == 0 or args or kwargs:
+      return frame_base.DeferredFrame.wrap(
         expressions.ComputedExpression(
             'aggregate',
-            lambda df: df.agg(axis=axis, **kwargs),
+            lambda df: df.agg(func, *args, **kwargs),
             [self._expr],
-            # TODO(robertwb): Sub-aggregate when possible.
             requires_partition_by=partitionings.Singleton()))
+    else:
+      if not isinstance(func, dict):
+        col_names = list(self._expr.proxy().columns)
+        func = {col: func for col in col_names}
+      else:
+        col_names = list(func.keys())
+      aggregated_cols = []
+      for col in col_names:
+        funcs = func[col]
+        if not isinstance(funcs, list):
+          funcs = [funcs]
+        aggregated_cols.append(self[col].agg(funcs, *args, **kwargs))
+      if any(isinstance(funcs, list) for funcs in func.values()):
+        return frame_base.DeferredFrame.wrap(
+            expressions.ComputedExpression(
+                'join_aggregate',
+                lambda *cols: pd.DataFrame(
+                    {col: value for col, value in zip(col_names, cols)}),
+                [col._expr for col in aggregated_cols]))
+      else:
+        return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+                lambda *cols: pd.Series(
+                    {col: value[0] for col, value in zip(col_names, cols)}),
+              [col._expr for col in aggregated_cols],
+              proxy=self._expr.proxy().agg(func, *args, **kwargs)))

Review comment:
       Done.

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
@@ -35,20 +35,34 @@ def __array__(self, dtype=None):
   transform = frame_base._elementwise_method(
       'transform', restrictions={'axis': 0})
 
-  def agg(self, *args, **kwargs):
-    return frame_base.DeferredFrame.wrap(
-        expressions.ComputedExpression(
-            'agg',
-            lambda df: df.agg(*args, **kwargs), [self._expr],
-            preserves_partition_by=partitionings.Singleton(),
-            requires_partition_by=partitionings.Singleton()))
-
-  all = frame_base._associative_agg_method('all')
-  any = frame_base._associative_agg_method('any')
-  min = frame_base._associative_agg_method('min')
-  max = frame_base._associative_agg_method('max')
-  prod = product = frame_base._associative_agg_method('prod')
-  sum = frame_base._associative_agg_method('sum')
+  def agg(self, func, axis=0, *args, **kwargs):
+    if isinstance(func, list) and len(func) > 1:
+      rows = [self.agg([f], *args, **kwargs) for f in func]
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'join_aggregate',
+              lambda *rows: pd.concat(rows), [row._expr for row in rows]))
+    else:
+      base_func = func[0] if isinstance(func, list) else func
+      if _is_associative(base_func) and not args and not kwargs:
+        intermediate = expressions.elementwise_expression(
+            'pre_agg',
+            lambda s: s.agg([base_func], *args, **kwargs), [self._expr])
+      else:
+        intermediate = self._expr
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'agg',
+              lambda s: s.agg(func, *args, **kwargs), [intermediate],
+              preserves_partition_by=partitionings.Singleton(),
+              requires_partition_by=partitionings.Singleton()))

Review comment:
       Done.

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -244,6 +261,9 @@ def expr_to_stages(expr):
             # It also must be declared as an output of the producing stage.
             expr_to_stage(arg).outputs.add(arg)
       stage.ops.append(expr)
+      for arg in expr.args():
+        if arg in inputs:
+          stage.inputs.add(arg)

Review comment:
       Done.

##########
File path: sdks/python/apache_beam/dataframe/frames.py
##########
 
   agg = aggregate

Review comment:
       Good call. Done.



