You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/26 00:00:55 UTC
[spark] branch master updated: [SPARK-40334][PS] Implement `GroupBy.prod`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c01e524c298 [SPARK-40334][PS] Implement `GroupBy.prod`
c01e524c298 is described below
commit c01e524c2985be06027191e51bb94d9ee5637d40
Author: artsiomyudovin <a....@gmail.com>
AuthorDate: Mon Sep 26 08:00:20 2022 +0800
[SPARK-40334][PS] Implement `GroupBy.prod`
### What changes were proposed in this pull request?
Implement `GroupBy.prod`
### Why are the changes needed?
for API coverage
### Does this PR introduce _any_ user-facing change?
yes, the new API
```
df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
'B': [np.nan, 2, 3, 4, 5],
'C': [1, 2, 1, 1, 2],
'D': [True, False, True, False, True]})
Groupby one column and return the prod of the remaining columns in
each group.
df.groupby('A').prod()
B C D
A
1 8.0 2 0
2 15.0 2 1
df.groupby('A').prod(min_count=3)
B C D
A
1 NaN 2.0 0.0
2 NaN NaN NaN
```
### How was this patch tested?
added UT
Closes #37923 from ayudovin/ps_group_by_prod.
Authored-by: artsiomyudovin <a....@gmail.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../source/reference/pyspark.pandas/groupby.rst | 1 +
python/pyspark/pandas/groupby.py | 106 ++++++++++++++++++++-
python/pyspark/pandas/missing/groupby.py | 2 -
python/pyspark/pandas/tests/test_groupby.py | 10 ++
4 files changed, 114 insertions(+), 5 deletions(-)
diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst
index 4c29964966c..da1579fd723 100644
--- a/python/docs/source/reference/pyspark.pandas/groupby.rst
+++ b/python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -74,6 +74,7 @@ Computations / Descriptive Stats
GroupBy.median
GroupBy.min
GroupBy.nth
+ GroupBy.prod
GroupBy.rank
GroupBy.sem
GroupBy.std
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 2e5c9ab219a..6d36cfecce6 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -18,7 +18,6 @@
"""
A wrapper for GroupedData to behave similar to pandas GroupBy.
"""
-
from abc import ABCMeta, abstractmethod
import inspect
from collections import defaultdict, namedtuple
@@ -63,6 +62,7 @@ from pyspark.sql.types import (
StructField,
StructType,
StringType,
+ IntegralType,
)
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
@@ -1055,6 +1055,106 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
return self._prepare_return(DataFrame(internal))
+ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
+ """
+ Compute prod of groups.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ numeric_only : bool, default True
+ Include only float, int, boolean columns. If None, will attempt to use
+ everything, then use only numeric data.
+
+ min_count : int, default 0
+ The required number of valid values to perform the operation.
+ If fewer than min_count non-NA values are present the result will be NA.
+
+ Returns
+ -------
+ Series or DataFrame
+ Computed prod of values within each group.
+
+ See Also
+ --------
+ pyspark.pandas.Series.groupby
+ pyspark.pandas.DataFrame.groupby
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> df = ps.DataFrame(
+ ... {
+ ... "A": [1, 1, 2, 1, 2],
+ ... "B": [np.nan, 2, 3, 4, 5],
+ ... "C": [1, 2, 1, 1, 2],
+ ... "D": [True, False, True, False, True],
+ ... }
+ ... )
+
+ Groupby one column and return the prod of the remaining columns in
+ each group.
+
+ >>> df.groupby('A').prod().sort_index()
+ B C D
+ A
+ 1 8.0 2 0
+ 2 15.0 2 1
+
+ >>> df.groupby('A').prod(min_count=3).sort_index()
+ B C D
+ A
+ 1 NaN 2.0 0.0
+ 2 NaN NaN NaN
+ """
+
+ self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")
+
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+ internal, agg_columns, sdf = self._prepare_reduce(
+ groupkey_names=groupkey_names,
+ accepted_spark_types=(NumericType, BooleanType),
+ bool_to_numeric=True,
+ )
+
+ psdf: DataFrame = DataFrame(internal)
+ if len(psdf._internal.column_labels) > 0:
+
+ stat_exprs = []
+ for label in psdf._internal.column_labels:
+ psser = psdf._psser_for(label)
+ column = psser._dtype_op.nan_to_null(psser).spark.column
+ data_type = psser.spark.data_type
+ aggregating = (
+ F.product(column).cast("long")
+ if isinstance(data_type, IntegralType)  # integral inputs: cast the product to long
+ else F.product(column)
+ )
+
+ if min_count > 0:  # null out groups with fewer than min_count non-null values
+ prod_scol = F.when(
+ F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None)
+ ).otherwise(aggregating)
+ else:
+ prod_scol = aggregating
+
+ stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
+
+ sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+
+ else:
+ sdf = sdf.select(*groupkey_names).distinct()  # no columns to aggregate: keep distinct keys
+
+ internal = internal.copy(
+ spark_frame=sdf,
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+ data_fields=None,
+ )
+
+ return self._prepare_return(DataFrame(internal))
+
def all(self, skipna: bool = True) -> FrameLike:
"""
Returns True if all values in the group are truthful, else False.
@@ -3297,10 +3397,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
if not numeric_only:
if has_non_numeric:
warnings.warn(
- "Dropping invalid columns in DataFrameGroupBy.mean is deprecated. "
+ "Dropping invalid columns in DataFrameGroupBy.%s is deprecated. "
"In a future version, a TypeError will be raised. "
"Before calling .%s, select only columns which should be "
- "valid for the function." % function_name,
+ "valid for the function." % (function_name, function_name),
FutureWarning,
)
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 3a0e90c2151..1799fac0033 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -61,7 +61,6 @@ class MissingPandasLikeDataFrameGroupBy:
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
- prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
@@ -93,5 +92,4 @@ class MissingPandasLikeSeriesGroupBy:
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
- prod = _unsupported_function("prod")
resample = _unsupported_function("resample")
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index 6e4aa6186c6..4a57a3421df 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1433,6 +1433,16 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
with self.assertRaisesRegex(TypeError, "Invalid index"):
self.psdf.groupby("B").nth("x")
+ def test_prod(self):
+ for n in [0, 1, 2, 128, -1, -2, -128]:  # zero, positive and negative min_count values
+ self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n))
+ self._test_stat_func(
+ lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n)
+ )
+ self._test_stat_func(
+ lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n)
+ )
+
def test_cumcount(self):
pdf = pd.DataFrame(
{
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org