You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/08 08:04:59 UTC
[spark] branch master updated: [SPARK-40333][PS] Implement `GroupBy.nth`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4d73552abf3 [SPARK-40333][PS] Implement `GroupBy.nth`
4d73552abf3 is described below
commit 4d73552abf39c687a1ef1f742fcecdf7492995af
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Sep 8 16:04:26 2022 +0800
[SPARK-40333][PS] Implement `GroupBy.nth`
### What changes were proposed in this pull request?
Implement `GroupBy.nth`
### Why are the changes needed?
For API coverage.
### Does this PR introduce _any_ user-facing change?
Yes, it adds a new API, `GroupBy.nth`.
```
In [4]: import pyspark.pandas as ps
In [5]: import numpy as np
In [6]: df = ps.DataFrame({'A': [1, 1, 2, 1, 2], 'B': [np.nan, 2, 3, 4, 5], 'C': ['a', 'b', 'c', 'd', 'e']}, columns=['A', 'B', 'C'])
In [7]: df.groupby('A').nth(0)
Out[7]:
B C
A
1 NaN a
2 3.0 c
In [8]: df.groupby('A').nth(2)
Out[8]:
B C
A
1 4.0 d
In [9]: df.C.groupby(df.A).nth(-1)
Out[9]:
A
1 d
2 e
Name: C, dtype: object
In [10]: df.C.groupby(df.A).nth(-2)
Out[10]:
A
1 b
2 c
Name: C, dtype: object
```
### How was this patch tested?
Added unit tests.
Closes #37801 from zhengruifeng/ps_groupby_nth.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../source/reference/pyspark.pandas/groupby.rst | 1 +
python/pyspark/pandas/groupby.py | 98 ++++++++++++++++++++++
python/pyspark/pandas/missing/groupby.py | 2 -
python/pyspark/pandas/tests/test_groupby.py | 11 +++
4 files changed, 110 insertions(+), 2 deletions(-)
diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst
index b331a49b683..24e3bde91f5 100644
--- a/python/docs/source/reference/pyspark.pandas/groupby.rst
+++ b/python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -73,6 +73,7 @@ Computations / Descriptive Stats
GroupBy.mean
GroupBy.median
GroupBy.min
+ GroupBy.nth
GroupBy.rank
GroupBy.sem
GroupBy.std
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 84a5a3377f3..01163b61375 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -895,6 +895,104 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
bool_to_numeric=True,
)
+ # TODO: (1) accept a list or slice for 'n'; (2) implement the 'dropna' parameter
+ def nth(self, n: int) -> FrameLike:
+ """
+ Take the nth row from each group.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ n : int
+ 0-based position of the row to take from each group; negative counts from the end.
+
+ Returns
+ -------
+ Series or DataFrame
+
+ Notes
+ -----
+ There is a behavior difference between pandas-on-Spark and pandas:
+
+ * when there is no aggregation column and `n` is not 0 or -1,
+ the returned empty DataFrame keeps every group key, so its index length may differ.
+
+ Examples
+ --------
+ >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+ ... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+ >>> g = df.groupby('A')
+ >>> g.nth(0)
+ B
+ A
+ 1 NaN
+ 2 3.0
+ >>> g.nth(1)
+ B
+ A
+ 1 2.0
+ 2 5.0
+ >>> g.nth(-1)
+ B
+ A
+ 1 4.0
+ 2 5.0
+
+ See Also
+ --------
+ pyspark.pandas.Series.groupby
+ pyspark.pandas.DataFrame.groupby
+ """
+ if isinstance(n, slice) or is_list_like(n):  # not supported yet, see the TODO on this method
+ raise NotImplementedError("n doesn't support slice or list for now")
+ if not isinstance(n, int):  # NOTE(review): bool subclasses int, so True/False pass this check
+ raise TypeError("Invalid index %s" % type(n).__name__)
+
+ groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+ internal, agg_columns, sdf = self._prepare_reduce(
+ groupkey_names=groupkey_names,
+ accepted_spark_types=None,  # presumably no type filtering: nth keeps every column (TODO confirm)
+ bool_to_numeric=False,
+ )
+ psdf: DataFrame = DataFrame(internal)
+
+ if len(psdf._internal.column_labels) > 0:
+ window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)  # assumes this column tracks original row order
+ tmp_row_number_col = verify_temp_column_name(sdf, "__tmp_row_number_col__")
+ if n >= 0:
+ sdf = (
+ psdf._internal.spark_frame.withColumn(
+ tmp_row_number_col, F.row_number().over(window1)
+ )
+ .where(F.col(tmp_row_number_col) == n + 1)  # row_number() is 1-based, n is 0-based
+ .drop(tmp_row_number_col)
+ )
+ else:
+ window2 = Window.partitionBy(*groupkey_names).rowsBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing  # the whole group, used only to count its rows
+ )
+ tmp_group_size_col = verify_temp_column_name(sdf, "__tmp_group_size_col__")
+ sdf = (
+ psdf._internal.spark_frame.withColumn(
+ tmp_group_size_col, F.count(F.lit(0)).over(window2)  # number of rows in the group
+ )
+ .withColumn(tmp_row_number_col, F.row_number().over(window1))
+ .where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n)  # n=-1 -> last row (row_number == size)
+ .drop(tmp_group_size_col, tmp_row_number_col)
+ )
+ else:
+ sdf = sdf.select(*groupkey_names).distinct()  # no data columns: every group key is kept regardless of n
+
+ internal = internal.copy(
+ spark_frame=sdf,
+ index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+ data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+ data_fields=None,  # presumably lets the fields be re-derived from the new spark_frame (TODO confirm)
+ )
+
+ return self._prepare_return(DataFrame(internal))
+
def all(self, skipna: bool = True) -> FrameLike:
"""
Returns True if all values in the group are truthful, else False.
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 8ae8a68b5fe..e913835ca72 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -59,7 +59,6 @@ class MissingPandasLikeDataFrameGroupBy:
# Functions
boxplot = _unsupported_function("boxplot")
ngroup = _unsupported_function("ngroup")
- nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
@@ -93,7 +92,6 @@ class MissingPandasLikeSeriesGroupBy:
aggregate = _unsupported_function("aggregate")
describe = _unsupported_function("describe")
ngroup = _unsupported_function("ngroup")
- nth = _unsupported_function("nth")
ohlc = _unsupported_function("ohlc")
pct_change = _unsupported_function("pct_change")
pipe = _unsupported_function("pipe")
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index e76fcf00faf..1076d867344 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1380,6 +1380,17 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
+ def test_nth(self):  # nth(n) for single int n, plus the rejected argument types
+ for n in [0, 1, 2, 128, -1, -2, -128]:  # positive and negative n; |n|=128 presumably exceeds every group size
+ self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))  # lambda reads n late; assumes _test_stat_func calls it before the loop advances
+
+ with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+ self.psdf.groupby("B").nth(slice(0, 2))
+ with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+ self.psdf.groupby("B").nth([0, 1, -1])
+ with self.assertRaisesRegex(TypeError, "Invalid index"):
+ self.psdf.groupby("B").nth("x")  # non-int, non-list n
+
def test_cumcount(self):
pdf = pd.DataFrame(
{
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org