Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/09/05 09:45:25 UTC

[GitHub] [spark] Yikun commented on a diff in pull request #37801: [SPARK-40333][PS] Implement `GroupBy.nth`

Yikun commented on code in PR #37801:
URL: https://github.com/apache/spark/pull/37801#discussion_r962662977


##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+
+        Examples
+        --------
+
+        >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+        ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+        >>> g = df.groupby('A')
+        >>> g.nth(0)
+             B
+        A
+        1  NaN
+        2  3.0
+        >>> g.nth(1)
+             B
+        A
+        1  2.0
+        2  5.0
+        >>> g.nth(-1)
+             B
+        A
+        1  4.0
+        2  5.0
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+        """
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=None,
+            bool_to_numeric=False,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        if len(psdf._internal.column_labels) > 0:
+            window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
+            tmp_row_number_col = "__tmp_row_number_col__"

Review Comment:
   Use `verify_temp_column_name` here instead of hard-coding the temporary column name?
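   
   e.g. roughly something like this for the lines above (untested sketch, assuming the `verify_temp_column_name` helper from `pyspark.pandas.utils` is importable in this module):
   
   ```python
   from pyspark.pandas.utils import verify_temp_column_name
   
   # Let the helper assert that the temporary name does not collide with an
   # existing column before it is used.
   tmp_row_number_col = verify_temp_column_name(
       psdf._internal.spark_frame, "__tmp_row_number_col__"
   )
   ```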



##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+
+        Examples
+        --------
+
+        >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+        ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+        >>> g = df.groupby('A')
+        >>> g.nth(0)
+             B
+        A
+        1  NaN
+        2  3.0
+        >>> g.nth(1)
+             B
+        A
+        1  2.0
+        2  5.0
+        >>> g.nth(-1)
+             B
+        A
+        1  4.0
+        2  5.0
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+        """
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=None,
+            bool_to_numeric=False,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        if len(psdf._internal.column_labels) > 0:
+            window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
+            tmp_row_number_col = "__tmp_row_number_col__"
+            if n >= 0:
+                sdf = (
+                    psdf._internal.spark_frame.withColumn(
+                        tmp_row_number_col, F.row_number().over(window1)
+                    )
+                    .where(F.col(tmp_row_number_col) == n + 1)
+                    .drop(tmp_row_number_col)
+                )
+            else:
+                window2 = Window.partitionBy(*groupkey_names).rowsBetween(
+                    Window.unboundedPreceding, Window.unboundedFollowing
+                )
+                tmp_group_size_col = "__tmp_group_size_col__"

Review Comment:
   `verify_temp_column_name` here as well (same as above).



##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+
+        Examples
+        --------
+
+        >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+        ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+        >>> g = df.groupby('A')
+        >>> g.nth(0)
+             B
+        A
+        1  NaN
+        2  3.0
+        >>> g.nth(1)
+             B
+        A
+        1  2.0
+        2  5.0
+        >>> g.nth(-1)
+             B
+        A
+        1  4.0
+        2  5.0
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+        """
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=None,
+            bool_to_numeric=False,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        if len(psdf._internal.column_labels) > 0:
+            window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
+            tmp_row_number_col = "__tmp_row_number_col__"
+            if n >= 0:

Review Comment:
   Validate `n` with a friendly exception?
   
   ```python
   >>> g.nth('C')
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
     File "/Users/yikun/venv/lib/python3.9/site-packages/pandas/core/groupby/groupby.py", line 2304, in nth
       raise TypeError("n needs to be an int or a list/set/tuple of ints")
   TypeError: n needs to be an int or a list/set/tuple of ints
   ```
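   
   e.g. a minimal sketch (message wording is just illustrative; only int is supported until the list/slice TODO lands):
   
   ```python
   # Fail fast with a friendly TypeError, similar to pandas, before building any window.
   if not isinstance(n, int):
       raise TypeError("Invalid index %s" % type(n).__name__)
   ```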



##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:

Review Comment:
   https://github.com/apache/spark/blob/5a03f70aa1c6f709bee6c7d91e4e74bf38498b5a/python/pyspark/sql/functions.py#L3622
   
   Since 3.1 there is a `nth_value` function in Spark, but considering negative indices and that we are going to support list and slice in the future, I think using `row_number` is right here. Just FYI in case you have another idea.
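   
   For context, a rough standalone sketch (illustrative data) of how `nth_value` behaves: it takes a 1-based, strictly positive offset and tags every row of the group with the same value rather than filtering rows, so negative indices and future list/slice support would still need extra handling on top of it:
   
   ```python
   from pyspark.sql import SparkSession, Window
   from pyspark.sql import functions as F
   
   spark = SparkSession.builder.getOrCreate()
   sdf = spark.createDataFrame(
       [(1, None), (1, 2.0), (2, 3.0), (1, 4.0), (2, 5.0)], ["A", "B"]
   )
   
   # Whole-group frame so every row of a group sees the same "second value of B".
   w = (
       Window.partitionBy("A")
       .orderBy("B")
       .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
   )
   sdf.withColumn("second_b", F.nth_value("B", 2).over(w)).show()
   ```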



##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+

Review Comment:
   ```suggestion
           Returns
           -------
   ```



##########
python/pyspark/pandas/groupby.py:
##########
@@ -895,6 +895,89 @@ def sem(col: Column) -> Column:
             bool_to_numeric=True,
         )
 
+    # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter
+    def nth(self, n: int) -> FrameLike:
+        """
+        Take the nth row from each group.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int
+            A single nth value for the row
+
+        Examples
+        --------
+
+        >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2],
+        ...                    'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B'])
+        >>> g = df.groupby('A')
+        >>> g.nth(0)
+             B
+        A
+        1  NaN
+        2  3.0
+        >>> g.nth(1)
+             B
+        A
+        1  2.0
+        2  5.0
+        >>> g.nth(-1)
+             B
+        A
+        1  4.0
+        2  5.0
+
+        See Also
+        --------
+        pyspark.pandas.Series.groupby
+        pyspark.pandas.DataFrame.groupby
+        """
+        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=None,
+            bool_to_numeric=False,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        if len(psdf._internal.column_labels) > 0:
+            window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME)
+            tmp_row_number_col = "__tmp_row_number_col__"
+            if n >= 0:
+                sdf = (
+                    psdf._internal.spark_frame.withColumn(
+                        tmp_row_number_col, F.row_number().over(window1)
+                    )
+                    .where(F.col(tmp_row_number_col) == n + 1)
+                    .drop(tmp_row_number_col)
+                )
+            else:
+                window2 = Window.partitionBy(*groupkey_names).rowsBetween(
+                    Window.unboundedPreceding, Window.unboundedFollowing
+                )
+                tmp_group_size_col = "__tmp_group_size_col__"
+                sdf = (
+                    psdf._internal.spark_frame.withColumn(
+                        tmp_group_size_col, F.count(F.lit(0)).over(window2)
+                    )
+                    .withColumn(tmp_row_number_col, F.row_number().over(window1))
+                    .where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n)
+                    .drop(tmp_group_size_col, tmp_row_number_col)
+                )
+        else:
+            sdf = sdf.select(*groupkey_names).distinct()

Review Comment:
   Add a test to cover this? I'm a little fuzzy about this branch.
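   
   Something like this might hit it, if I read it right that grouping by the only column leaves no aggregation columns (expected output should be checked against pandas):
   
   ```python
   import pandas as pd
   import pyspark.pandas as ps
   
   pdf = pd.DataFrame({"A": [1, 1, 2, 1, 2]})
   psdf = ps.from_pandas(pdf)
   
   # With only the group key selected there are no aggregation columns left, so the
   # `len(...column_labels) > 0` check in `nth` is False and the distinct-keys path runs.
   print(psdf.groupby("A").nth(0))
   print(pdf.groupby("A").nth(0))  # pandas result as the expected baseline
   ```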



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

