You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/26 00:00:55 UTC

[spark] branch master updated: [SPARK-40334][PS] Implement `GroupBy.prod`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c01e524c298 [SPARK-40334][PS] Implement `GroupBy.prod`
c01e524c298 is described below

commit c01e524c2985be06027191e51bb94d9ee5637d40
Author: artsiomyudovin <a....@gmail.com>
AuthorDate: Mon Sep 26 08:00:20 2022 +0800

    [SPARK-40334][PS] Implement `GroupBy.prod`
    
    ### What changes were proposed in this pull request?
    Implement `GroupBy.prod`
    
    ### Why are the changes needed?
    To improve pandas API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds the new `GroupBy.prod` API:
    ```
    df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
                       'B': [np.nan, 2, 3, 4, 5],
                       'C': [1, 2, 1, 1, 2],
                       'D': [True, False, True, False, True]})

    Groupby one column and return the prod of the remaining columns in
    each group.

    df.groupby('A').prod()
          B  C  D
    A
    1   8.0  2  0
    2  15.0  2  1

    df.groupby('A').prod(min_count=3)
         B    C    D
    A
    1  NaN  2.0  0.0
    2  NaN  NaN  NaN
    ```
    
    ### How was this patch tested?
    Added unit tests (`test_prod` in test_groupby.py).
    
    Closes #37923 from ayudovin/ps_group_by_prod.
    
    Authored-by: artsiomyudovin <a....@gmail.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../source/reference/pyspark.pandas/groupby.rst    |   1 +
 python/pyspark/pandas/groupby.py                   | 106 ++++++++++++++++++++-
 python/pyspark/pandas/missing/groupby.py           |   2 -
 python/pyspark/pandas/tests/test_groupby.py        |  10 ++
 4 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst
index 4c29964966c..da1579fd723 100644
--- a/python/docs/source/reference/pyspark.pandas/groupby.rst
+++ b/python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -74,6 +74,7 @@ Computations / Descriptive Stats
    GroupBy.median
    GroupBy.min
    GroupBy.nth
+   GroupBy.prod
    GroupBy.rank
    GroupBy.sem
    GroupBy.std
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 2e5c9ab219a..6d36cfecce6 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -18,7 +18,6 @@
 """
 A wrapper for GroupedData to behave similar to pandas GroupBy.
 """
-
 from abc import ABCMeta, abstractmethod
 import inspect
 from collections import defaultdict, namedtuple
@@ -63,6 +62,7 @@ from pyspark.sql.types import (
     StructField,
     StructType,
     StringType,
+    IntegralType,
 )
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -1055,6 +1055,106 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
 
         return self._prepare_return(DataFrame(internal))
 
+    def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
+        """
+        Compute prod of groups.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        numeric_only : bool, default True
+            Include only float, int, boolean columns. If None, will attempt to use
+            everything, then use only numeric data.
+
+        min_count : int, default 0
+            The required number of valid values to perform the operation.
+            If fewer than min_count non-NA values are present the result will be NA.
+
+        Returns
+        -------
+        Series or DataFrame
+            Computed prod of values within each group.
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+
+        Examples
+        --------
+        >>> import numpy as np
+        >>> df = ps.DataFrame(
+        ...     {
+        ...         "A": [1, 1, 2, 1, 2],
+        ...         "B": [np.nan, 2, 3, 4, 5],
+        ...         "C": [1, 2, 1, 1, 2],
+        ...         "D": [True, False, True, False, True],
+        ...     }
+        ... )
+
+        Groupby one column and return the prod of the remaining columns in
+        each group.
+
+        >>> df.groupby('A').prod().sort_index()
+              B  C  D
+        A
+        1   8.0  2  0
+        2  15.0  2  1
+
+        >>> df.groupby('A').prod(min_count=3).sort_index()
+             B    C    D
+        A
+        1  NaN  2.0  0.0
+        2  NaN  NaN  NaN
+        """
+
+        self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
+
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, _, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=(NumericType, BooleanType),
+            bool_to_numeric=True,
+        )
+
+        psdf: DataFrame = DataFrame(internal)
+        if len(psdf._internal.column_labels) > 0:
+
+            stat_exprs = []
+            for label in psdf._internal.column_labels:
+                psser = psdf._psser_for(label)
+                # Turn NaN into null so that it is skipped like a pandas NA value.
+                column = psser._dtype_op.nan_to_null(psser).spark.column
+                # F.product yields a double; cast back so integral inputs stay integral.
+                aggregating = F.product(column)
+                if isinstance(psser.spark.data_type, IntegralType):
+                    aggregating = aggregating.cast("long")
+
+                if min_count > 0:
+                    # F.count skips nulls, so it counts the non-NA values of the group.
+                    too_few_values = F.count(column) < min_count
+                    prod_scol = F.when(too_few_values, F.lit(None)).otherwise(aggregating)
+                else:
+                    prod_scol = aggregating
+
+                stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
+
+            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+
+        else:
+            # No aggregatable columns: keep one row per distinct group key.
+            sdf = sdf.select(*groupkey_names).distinct()
+
+        internal = internal.copy(
+            spark_frame=sdf,
+            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+            data_fields=None,
+        )
+
+        return self._prepare_return(DataFrame(internal))
+
     def all(self, skipna: bool = True) -> FrameLike:
         """
         Returns True if all values in the group are truthful, else False.
@@ -3297,10 +3397,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             if not numeric_only:
                 if has_non_numeric:
                     warnings.warn(
-                        "Dropping invalid columns in DataFrameGroupBy.mean is deprecated. "
+                        "Dropping invalid columns in DataFrameGroupBy.%s is deprecated. "
                         "In a future version, a TypeError will be raised. "
                         "Before calling .%s, select only columns which should be "
-                        "valid for the function." % function_name,
+                        "valid for the function." % (function_name, function_name),
                         FutureWarning,
                     )
 
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 3a0e90c2151..1799fac0033 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -61,7 +61,6 @@ class MissingPandasLikeDataFrameGroupBy:
     ohlc = _unsupported_function("ohlc")
     pct_change = _unsupported_function("pct_change")
     pipe = _unsupported_function("pipe")
-    prod = _unsupported_function("prod")
     resample = _unsupported_function("resample")
 
 
@@ -93,5 +92,4 @@ class MissingPandasLikeSeriesGroupBy:
     ohlc = _unsupported_function("ohlc")
     pct_change = _unsupported_function("pct_change")
     pipe = _unsupported_function("pipe")
-    prod = _unsupported_function("prod")
     resample = _unsupported_function("resample")
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index 6e4aa6186c6..4a57a3421df 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1433,6 +1433,16 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
         with self.assertRaisesRegex(TypeError, "Invalid index"):
             self.psdf.groupby("B").nth("x")
 
+    def test_prod(self):
+        # Exercise zero, positive and negative min_count values, each combined
+        # with the default, None and True settings of numeric_only.
+        numeric_only_kwargs = [{}, {"numeric_only": None}, {"numeric_only": True}]
+        for min_count in [0, 1, 2, 128, -1, -2, -128]:
+            for kwargs in numeric_only_kwargs:
+                self._test_stat_func(
+                    lambda groupby_obj: groupby_obj.prod(min_count=min_count, **kwargs)
+                )
+
     def test_cumcount(self):
         pdf = pd.DataFrame(
             {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org