Posted to commits@spark.apache.org by gu...@apache.org on 2023/09/22 03:48:17 UTC

[spark] branch master updated: [SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bf950ee4f7 [SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas
0bf950ee4f7 is described below

commit 0bf950ee4f77eb1b50d7bd26df330094d44c0804
Author: Haejoon Lee <ha...@databricks.com>
AuthorDate: Fri Sep 22 12:48:03 2023 +0900

    [SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to match `GroupBy.nth` behavior to the latest Pandas.
    
    ### Why are the changes needed?
    
    To match the behavior of Pandas 2.0.0 and above.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the output of `GroupBy.nth` changes as shown below.

    **Test DataFrame**
    ```python
    >>> psdf = ps.DataFrame(
    ...     {
    ...         "A": [1, 2, 1, 2],
    ...         "B": [3.1, 4.1, 4.1, 3.1],
    ...         "C": ["a", "b", "b", "a"],
    ...         "D": [True, False, False, True],
    ...     }
    ... )
    >>> psdf
       A    B  C      D
    0  1  3.1  a   True
    1  2  4.1  b  False
    2  1  4.1  b  False
    3  2  3.1  a   True
    ```
    **Before fixing**
    ```python
    >>> psdf.groupby("A").nth(-1)
         B  C      D
    A
    1  4.1  b  False
    2  3.1  a   True
    >>> psdf.groupby("A")[["C"]].nth(-1)
       C
    A
    1  b
    2  a
    >>> psdf.groupby("A")["B"].nth(-1)
    A
    1    4.1
    2    3.1
    Name: B, dtype: float64
    ```
    **After fixing**
    ```python
    >>> psdf.groupby("A").nth(-1)
       A    B  C      D
    2  1  4.1  b  False
    3  2  3.1  a   True
    >>> psdf.groupby("A")[["C"]].nth(-1)
       C
    2  b
    3  a
    >>> psdf.groupby("A")["B"].nth(-1)
    2    4.1
    3    3.1
    Name: B, dtype: float64
    ```
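
    As a cross-check (not part of this patch), native pandas should produce the same result for the example above, assuming pandas >= 2.0 is installed, where `nth` behaves as a filtration and keeps the original row index:
    ```python
    >>> pdf = psdf.to_pandas()  # convert the test DataFrame above to native pandas
    >>> pdf.groupby("A").nth(-1)
       A    B  C      D
    2  1  4.1  b  False
    3  2  3.1  a   True
    ```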
    
    ### How was this patch tested?
    
    Enabled the existing tests and updated the doctests.
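
    For reference, one minimal way to run the re-enabled module locally is through the standard `unittest` loader (a sketch, assuming a Spark development environment where `pyspark` is importable and pandas >= 2.0 is installed; CI runs the same tests automatically):
    ```python
    import unittest

    # Load and run the test module touched by this patch; the dotted module
    # name is taken from the diff below.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "pyspark.pandas.tests.groupby.test_stat"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```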
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42994 from itholic/SPARK-43552.
    
    Authored-by: Haejoon Lee <ha...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/groupby.py                 | 89 +++++++++++++++---------
 python/pyspark/pandas/tests/groupby/test_stat.py |  4 --
 2 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index c7924fa3345..7bd64376152 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -143,7 +143,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def _handle_output(self, psdf: DataFrame) -> FrameLike:
+    def _handle_output(
+        self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+    ) -> FrameLike:
         pass
 
     # TODO: Series support is not implemented yet.
@@ -1091,24 +1093,22 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
 
         Examples
         --------
+        >>> import numpy as np
         >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
         ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
         >>> g = df.groupby('A')
         >>> g.nth(0)
-             B
-        A
-        1  NaN
-        2  3.0
+           A    B
+        0  1  NaN
+        2  2  3.0
         >>> g.nth(1)
-             B
-        A
-        1  2.0
-        2  5.0
+           A    B
+        1  1  2.0
+        4  2  5.0
         >>> g.nth(-1)
-             B
-        A
-        1  4.0
-        2  5.0
+           A    B
+        3  1  4.0
+        4  2  5.0
 
         See Also
         --------
@@ -1120,13 +1120,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         if not isinstance(n, int):
             raise TypeError("Invalid index %s" % type(n).__name__)
 
-        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
-        internal, agg_columns, sdf = self._prepare_reduce(
-            groupkey_names=groupkey_names,
-            accepted_spark_types=None,
-            bool_to_numeric=False,
-        )
-        psdf: DataFrame = DataFrame(internal)
+        groupkey_names: List[str] = [str(groupkey.name) for groupkey in self._groupkeys]
+        psdf = self._psdf
+        internal = psdf._internal
+        sdf = internal.spark_frame
 
         if len(psdf._internal.column_labels) > 0:
             window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
@@ -1155,14 +1152,32 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         else:
             sdf = sdf.select(*groupkey_names).distinct()
 
-        internal = internal.copy(
+        agg_columns = []
+        if not self._agg_columns_selected:
+            for psser in self._groupkeys:
+                agg_columns.append(psser)
+        for psser in self._agg_columns:
+            agg_columns.append(psser)
+        internal = InternalFrame(
             spark_frame=sdf,
-            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
-            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
-            data_fields=None,
+            index_spark_columns=[scol_for(sdf, col) for col in internal.index_spark_column_names],
+            index_names=internal.index_names,
+            index_fields=internal.index_fields,
+            data_spark_columns=[
+                scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
+            ],
+            column_labels=[psser._column_label for psser in agg_columns],
+            data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
+            column_label_names=self._psdf._internal.column_label_names,
         )
 
-        return self._prepare_return(DataFrame(internal))
+        agg_column_names = (
+            [str(agg_column.name) for agg_column in self._agg_columns]
+            if self._agg_columns_selected
+            else None
+        )
+
+        return self._prepare_return(DataFrame(internal), agg_column_names=agg_column_names)
 
     def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
         """
@@ -3595,7 +3610,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
 
         return self._prepare_return(psdf)
 
-    def _prepare_return(self, psdf: DataFrame) -> FrameLike:
+    def _prepare_return(
+        self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+    ) -> FrameLike:
         if self._dropna:
             psdf = DataFrame(
                 psdf._internal.with_new_sdf(
@@ -3622,7 +3639,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 psdf = psdf.reset_index(level=should_drop_index, drop=True)
             if len(should_drop_index) < len(self._groupkeys):
                 psdf = psdf.reset_index()
-        return self._handle_output(psdf)
+        return self._handle_output(psdf, agg_column_names)
 
     def _prepare_reduce(
         self,
@@ -3864,8 +3881,13 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
             internal = internal.resolved_copy
         return DataFrame(internal)
 
-    def _handle_output(self, psdf: DataFrame) -> DataFrame:
-        return psdf
+    def _handle_output(
+        self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+    ) -> DataFrame:
+        if agg_column_names is not None:
+            return psdf[agg_column_names]
+        else:
+            return psdf
 
     # TODO: Implement 'percentiles', 'include', and 'exclude' arguments.
     # TODO: Add ``DataFrame.select_dtypes`` to See Also when 'include'
@@ -4016,8 +4038,13 @@ class SeriesGroupBy(GroupBy[Series]):
         else:
             return psser.copy()
 
-    def _handle_output(self, psdf: DataFrame) -> Series:
-        return first_series(psdf).rename(self._psser.name)
+    def _handle_output(
+        self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+    ) -> Series:
+        if agg_column_names is not None:
+            return psdf[agg_column_names[0]].rename(self._psser.name)
+        else:
+            return first_series(psdf).rename(self._psser.name)
 
     def agg(self, *args: Any, **kwargs: Any) -> None:
         return MissingPandasLikeSeriesGroupBy.agg(self, *args, **kwargs)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py
index bc78e02c90e..695d079db49 100644
--- a/python/pyspark/pandas/tests/groupby/test_stat.py
+++ b/python/pyspark/pandas/tests/groupby/test_stat.py
@@ -244,10 +244,6 @@ class GroupbyStatMixin:
             psdf.groupby("A").last(min_count=2).sort_index(),
         )
 
-    @unittest.skipIf(
-        LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
-        "TODO(SPARK-43552): Enable GroupByTests.test_nth for pandas 2.0.0.",
-    )
     def test_nth(self):
         for n in [0, 1, 2, 128, -1, -2, -128]:
             self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
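
For context, the re-enabled `test_nth` drives `GroupbyStatMixin._test_stat_func` over several values of `n`, which essentially checks parity between pandas-on-Spark and native pandas. A rough standalone sketch of that comparison (illustrative only, assuming pandas >= 2.0; the real helper also covers out-of-range values such as 128 and -128):

```python
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame({"A": [1, 2, 1, 2], "B": [3.1, 4.1, 4.1, 3.1]})
psdf = ps.from_pandas(pdf)

for n in [0, 1, -1]:
    # Both sides should now return rows keyed by the original row index.
    expected = pdf.groupby("A").nth(n).sort_index()
    actual = psdf.groupby("A").nth(n).to_pandas().sort_index()
    pd.testing.assert_frame_equal(actual, expected)
```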

