You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/09/22 03:48:17 UTC
[spark] branch master updated: [SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0bf950ee4f7 [SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas
0bf950ee4f7 is described below
commit 0bf950ee4f77eb1b50d7bd26df330094d44c0804
Author: Haejoon Lee <ha...@databricks.com>
AuthorDate: Fri Sep 22 12:48:03 2023 +0900
[SPARK-43433][PS] Match `GroupBy.nth` behavior to the latest Pandas
### What changes were proposed in this pull request?
This PR proposes to match the behavior of `GroupBy.nth` to that of the latest Pandas.
### Why are the changes needed?
To match the behavior of Pandas 2.0.0 and above.
### Does this PR introduce _any_ user-facing change?
**Test DataFrame**
```python
>>> psdf = ps.DataFrame(
... {
... "A": [1, 2, 1, 2],
... "B": [3.1, 4.1, 4.1, 3.1],
... "C": ["a", "b", "b", "a"],
... "D": [True, False, False, True],
... }
... )
>>> psdf
A B C D
0 1 3.1 a True
1 2 4.1 b False
2 1 4.1 b False
3 2 3.1 a True
```
**Before fixing**
```python
>>> psdf.groupby("A").nth(-1)
B C D
A
1 4.1 b False
2 3.1 a True
>>> psdf.groupby("A")[["C"]].nth(-1)
C
A
1 b
2 a
>>> psdf.groupby("A")["B"].nth(-1)
A
1 4.1
2 3.1
Name: B, dtype: float64
```
**After fixing**
```python
>>> psdf.groupby("A").nth(-1)
A B C D
2 1 4.1 b False
3 2 3.1 a True
>>> psdf.groupby("A")[["C"]].nth(-1)
C
2 b
3 a
>>> psdf.groupby("A")["B"].nth(-1)
2 4.1
3 3.1
Name: B, dtype: float64
```
### How was this patch tested?
Enabling the existing tests and updating the doctests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42994 from itholic/SPARK-43552.
Authored-by: Haejoon Lee <ha...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/groupby.py | 89 +++++++++++++++---------
python/pyspark/pandas/tests/groupby/test_stat.py | 4 --
2 files changed, 58 insertions(+), 35 deletions(-)
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index c7924fa3345..7bd64376152 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -143,7 +143,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
pass
@abstractmethod
- def _handle_output(self, psdf: DataFrame) -> FrameLike:
+ def _handle_output(
+ self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+ ) -> FrameLike:
pass
# TODO: Series support is not implemented yet.
@@ -1091,24 +1093,22 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
Examples
--------
+ >>> import numpy as np
>>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
... 'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
>>> g = df.groupby('A')
>>> g.nth(0)
- B
- A
- 1 NaN
- 2 3.0
+ A B
+ 0 1 NaN
+ 2 2 3.0
>>> g.nth(1)
- B
- A
- 1 2.0
- 2 5.0
+ A B
+ 1 1 2.0
+ 4 2 5.0
>>> g.nth(-1)
- B
- A
- 1 4.0
- 2 5.0
+ A B
+ 3 1 4.0
+ 4 2 5.0
See Also
--------
@@ -1120,13 +1120,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
if not isinstance(n, int):
raise TypeError("Invalid index %s" % type(n).__name__)
- groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
- internal, agg_columns, sdf = self._prepare_reduce(
- groupkey_names=groupkey_names,
- accepted_spark_types=None,
- bool_to_numeric=False,
- )
- psdf: DataFrame = DataFrame(internal)
+ groupkey_names: List[str] = [str(groupkey.name) for groupkey in self._groupkeys]
+ psdf = self._psdf
+ internal = psdf._internal
+ sdf = internal.spark_frame
if len(psdf._internal.column_labels) > 0:
window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
@@ -1155,14 +1152,32 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
else:
sdf = sdf.select(*groupkey_names).distinct()
- internal = internal.copy(
+ agg_columns = []
+ if not self._agg_columns_selected:
+ for psser in self._groupkeys:
+ agg_columns.append(psser)
+ for psser in self._agg_columns:
+ agg_columns.append(psser)
+ internal = InternalFrame(
spark_frame=sdf,
- index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
- data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
- data_fields=None,
+ index_spark_columns=[scol_for(sdf, col) for col in internal.index_spark_column_names],
+ index_names=internal.index_names,
+ index_fields=internal.index_fields,
+ data_spark_columns=[
+ scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns
+ ],
+ column_labels=[psser._column_label for psser in agg_columns],
+ data_fields=[psser._internal.data_fields[0] for psser in agg_columns],
+ column_label_names=self._psdf._internal.column_label_names,
)
- return self._prepare_return(DataFrame(internal))
+ agg_column_names = (
+ [str(agg_column.name) for agg_column in self._agg_columns]
+ if self._agg_columns_selected
+ else None
+ )
+
+ return self._prepare_return(DataFrame(internal), agg_column_names=agg_column_names)
def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
"""
@@ -3595,7 +3610,9 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
return self._prepare_return(psdf)
- def _prepare_return(self, psdf: DataFrame) -> FrameLike:
+ def _prepare_return(
+ self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+ ) -> FrameLike:
if self._dropna:
psdf = DataFrame(
psdf._internal.with_new_sdf(
@@ -3622,7 +3639,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
psdf = psdf.reset_index(level=should_drop_index, drop=True)
if len(should_drop_index) < len(self._groupkeys):
psdf = psdf.reset_index()
- return self._handle_output(psdf)
+ return self._handle_output(psdf, agg_column_names)
def _prepare_reduce(
self,
@@ -3864,8 +3881,13 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
internal = internal.resolved_copy
return DataFrame(internal)
- def _handle_output(self, psdf: DataFrame) -> DataFrame:
- return psdf
+ def _handle_output(
+ self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+ ) -> DataFrame:
+ if agg_column_names is not None:
+ return psdf[agg_column_names]
+ else:
+ return psdf
# TODO: Implement 'percentiles', 'include', and 'exclude' arguments.
# TODO: Add ``DataFrame.select_dtypes`` to See Also when 'include'
@@ -4016,8 +4038,13 @@ class SeriesGroupBy(GroupBy[Series]):
else:
return psser.copy()
- def _handle_output(self, psdf: DataFrame) -> Series:
- return first_series(psdf).rename(self._psser.name)
+ def _handle_output(
+ self, psdf: DataFrame, agg_column_names: Optional[List[str]] = None
+ ) -> Series:
+ if agg_column_names is not None:
+ return psdf[agg_column_names[0]].rename(self._psser.name)
+ else:
+ return first_series(psdf).rename(self._psser.name)
def agg(self, *args: Any, **kwargs: Any) -> None:
return MissingPandasLikeSeriesGroupBy.agg(self, *args, **kwargs)
diff --git a/python/pyspark/pandas/tests/groupby/test_stat.py b/python/pyspark/pandas/tests/groupby/test_stat.py
index bc78e02c90e..695d079db49 100644
--- a/python/pyspark/pandas/tests/groupby/test_stat.py
+++ b/python/pyspark/pandas/tests/groupby/test_stat.py
@@ -244,10 +244,6 @@ class GroupbyStatMixin:
psdf.groupby("A").last(min_count=2).sort_index(),
)
- @unittest.skipIf(
- LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
- "TODO(SPARK-43552): Enable GroupByTests.test_nth for pandas 2.0.0.",
- )
def test_nth(self):
for n in [0, 1, 2, 128, -1, -2, -128]:
self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org