Posted to commits@spark.apache.org by ru...@apache.org on 2022/09/08 08:04:59 UTC

[spark] branch master updated: [SPARK-40333][PS] Implement `GroupBy.nth`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d73552abf3 [SPARK-40333][PS] Implement `GroupBy.nth`
4d73552abf3 is described below

commit 4d73552abf39c687a1ef1f742fcecdf7492995af
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Sep 8 16:04:26 2022 +0800

    [SPARK-40333][PS] Implement `GroupBy.nth`
    
    ### What changes were proposed in this pull request?
    Implement `GroupBy.nth`
    
    ### Why are the changes needed?
    for API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    yes, new API
    
    ```
    In [4]: import pyspark.pandas as ps
    
    In [5]: import numpy as np
    
    In [6]: df = ps.DataFrame({'A': [1, 1, 2, 1, 2], 'B': [np.nan, 2, 3, 4, 5], 'C': ['a', 'b', 'c', 'd', 'e']}, columns=['A', 'B', 'C'])
    
    In [7]: df.groupby('A').nth(0)
    Out[7]:
         B  C
    A
    1  NaN  a
    2  3.0  c
    
    In [8]: df.groupby('A').nth(2)
    Out[8]:
         B  C
    A
    1  4.0  d
    
    In [9]: df.C.groupby(df.A).nth(-1)
    Out[9]:
    A
    1    d
    2    e
    Name: C, dtype: object
    
    In [10]: df.C.groupby(df.A).nth(-2)
    Out[10]:
    A
    1    b
    2    c
    Name: C, dtype: object
    ```
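    
    Under the hood this takes the nth row of each group with a `row_number`
    window function over the group keys. A minimal standalone sketch of the
    same technique (toy column names, with `monotonically_increasing_id`
    standing in for the internal natural-order column used by the actual
    implementation) is:
    
    ```
    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F
    
    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(
        [(1, None), (1, 2.0), (2, 3.0), (1, 4.0), (2, 5.0)], ["A", "B"]
    )
    
    n = 1  # take the second row of each group (0-based n, as in GroupBy.nth)
    # Number the rows of each group in (an approximation of) input order;
    # the real implementation orders by the internal natural-order column.
    w = Window.partitionBy("A").orderBy(F.monotonically_increasing_id())
    result = (
        sdf.withColumn("__rn", F.row_number().over(w))
        .where(F.col("__rn") == n + 1)  # row_number() is 1-based
        .drop("__rn")
    )
    result.show()
    ```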
    
    ### How was this patch tested?
    added UT
    
    Closes #37801 from zhengruifeng/ps_groupby_nth.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../source/reference/pyspark.pandas/groupby.rst    |  1 +
 python/pyspark/pandas/groupby.py                   | 98 ++++++++++++++++++++++
 python/pyspark/pandas/missing/groupby.py           |  2 -
 python/pyspark/pandas/tests/test_groupby.py        | 11 +++
 4 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst
index b331a49b683..24e3bde91f5 100644
--- a/python/docs/source/reference/pyspark.pandas/groupby.rst
+++ b/python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -73,6 +73,7 @@ Computations / Descriptive Stats
    GroupBy.mean
    GroupBy.median
    GroupBy.min
+   GroupBy.nth
    GroupBy.rank
    GroupBy.sem
    GroupBy.std
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 84a5a3377f3..01163b61375 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -895,6 +895,104 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+
+        Returns
+        -------
+        Series or DataFrame
+
+        Notes
+        -----
+        There is a behavior difference between pandas-on-Spark and pandas:
+
+        * when there is no aggregation column and `n` is not equal to 0 or -1,
+            the returned empty DataFrame may have an index with a different length `__len__`.
+
+        Examples
+        --------
+        >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+        ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+        >>> g = df.groupby('A')
+        >>> g.nth(0)
+             B
+        A
+        1  NaN
+        2  3.0
+        >>> g.nth(1)
+             B
+        A
+        1  2.0
+        2  5.0
+        >>> g.nth(-1)
+             B
+        A
+        1  4.0
+        2  5.0
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+        """
+        if isinstance(n, slice) or is_list_like(n):
+            raise NotImplementedError("n doesn't support slice or list for now")
+        if not isinstance(n, int):
+            raise TypeError("Invalid index %s" % type(n).__name__)
+
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=None,
+            bool_to_numeric=False,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        if len(psdf._internal.column_labels) > 0:
+            window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
+            tmp_row_number_col = verify_temp_column_name(sdf, "__tmp_row_number_col__")
+            if n >= 0:
+                sdf = (
+                    psdf._internal.spark_frame.withColumn(
+                        tmp_row_number_col, F.row_number().over(window1)
+                    )
+                    .where(F.col(tmp_row_number_col) == n + 1)
+                    .drop(tmp_row_number_col)
+                )
+            else:
+                window2 = Window.partitionBy(*groupkey_names).rowsBetween(
+                    Window.unboundedPreceding, Window.unboundedFollowing
+                )
+                tmp_group_size_col = verify_temp_column_name(sdf, "__tmp_group_size_col__")
+                sdf = (
+                    psdf._internal.spark_frame.withColumn(
+                        tmp_group_size_col, F.count(F.lit(0)).over(window2)
+                    )
+                    .withColumn(tmp_row_number_col, F.row_number().over(window1))
+                    .where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n)
+                    .drop(tmp_group_size_col, tmp_row_number_col)
+                )
+        else:
+            sdf = sdf.select(*groupkey_names).distinct()
+
+        internal = internal.copy(
+            spark_frame=sdf,
+            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
+            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
+            data_fields=None,
+        )
+
+        return self._prepare_return(DataFrame(internal))
+
     def all(self, skipna: bool = True) -> FrameLike:
         """
         Returns True if all values in the group are truthful, else False.
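
A note on the negative-`n` branch above: `row_number` is 1-based, so for a
group of size `s` the predicate `row_number == s + 1 + n` selects the row
`|n|` positions from the end. A quick plain-Python check of the arithmetic
(hypothetical group sizes):

```
# For 1-based row numbers and group size s, row_number == s + 1 + n
# selects the |n|-th row from the end; out-of-range n matches nothing.
for s in (1, 2, 3):
    for n in (-1, -2):
        target = s + 1 + n  # 1-based position selected, if any
        if 1 <= target <= s:
            print(f"size={s}, n={n} -> row {target}")
        else:
            print(f"size={s}, n={n} -> no row")
```
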
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 8ae8a68b5fe..e913835ca72 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -59,7 +59,6 @@ class MissingPandasLikeDataFrameGroupBy:
     # Functions
     boxplot = _unsupported_function("boxplot")
     ngroup = _unsupported_function("ngroup")
-    nth = _unsupported_function("nth")
     ohlc = _unsupported_function("ohlc")
     pct_change = _unsupported_function("pct_change")
     pipe = _unsupported_function("pipe")
@@ -93,7 +92,6 @@ class MissingPandasLikeSeriesGroupBy:
     aggregate = _unsupported_function("aggregate")
     describe = _unsupported_function("describe")
     ngroup = _unsupported_function("ngroup")
-    nth = _unsupported_function("nth")
     ohlc = _unsupported_function("ohlc")
     pct_change = _unsupported_function("pct_change")
     pipe = _unsupported_function("pipe")
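
The test added below runs `nth` through the shared stat-function harness for
a range of `n` values. As a standalone cross-check against pandas (a sketch,
assuming a local Spark session and pandas' pre-2.0 `nth` semantics, where the
result is indexed by the group keys):

```
import numpy as np
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame({"A": [1, 1, 2, 1, 2], "B": [np.nan, 2, 3, 4, 5]})
psdf = ps.from_pandas(pdf)

for n in (0, 1, -1):
    expected = pdf.groupby("A").nth(n)  # pandas reference result
    actual = psdf.groupby("A").nth(n).to_pandas().sort_index()
    pd.testing.assert_frame_equal(expected, actual)
```
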
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index e76fcf00faf..1076d867344 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1380,6 +1380,17 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
         self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None))
         self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True))
 
+    def test_nth(self):
+        for n in [0, 1, 2, 128, -1, -2, -128]:
+            self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
+
+        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+            self.psdf.groupby("B").nth(slice(0, 2))
+        with self.assertRaisesRegex(NotImplementedError, "slice or list"):
+            self.psdf.groupby("B").nth([0, 1, -1])
+        with self.assertRaisesRegex(TypeError, "Invalid index"):
+            self.psdf.groupby("B").nth("x")
+
     def test_cumcount(self):
         pdf = pd.DataFrame(
             {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org