Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/05/23 03:41:29 UTC

[GitHub] [spark] itholic commented on a diff in pull request #36599: [SPARK-39228][PYTHON][PS] Implement `skipna` of `Series.argmax`

itholic commented on code in PR #36599:
URL: https://github.com/apache/spark/pull/36599#discussion_r878998074


##########
python/pyspark/pandas/tests/test_series.py:
##########
@@ -2987,9 +2987,9 @@ def test_argmin_argmax(self):
             name="Koalas",
         )
         psser = ps.from_pandas(pser)
-
         self.assert_eq(pser.argmin(), psser.argmin())
         self.assert_eq(pser.argmax(), psser.argmax())
+        self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))

Review Comment:
   Can we add one more test for a chained operation while we're here?
   
   e.g.
   ```python
   (pser + 1).argmax(skipna=False)
   ```
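   
   For example, a sketch of the full assertion in the style of the surrounding test, with `pser`/`psser` as defined above:
   
   ```python
   self.assert_eq((pser + 1).argmax(skipna=False), (psser + 1).argmax(skipna=False))
   ```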



##########
python/pyspark/pandas/series.py:
##########
@@ -6255,36 +6261,47 @@ def argmax(self) -> int:
         --------
         Consider dataset containing cereal calories
 
-        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
+        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
         ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
-        >>> s  # doctest: +SKIP
+        >>> s
         Corn Flakes              100.0
         Almond Delight           110.0
+        Unknown                    NaN
         Cinnamon Toast Crunch    120.0
         Cocoa Puff               110.0
         dtype: float64
 
-        >>> s.argmax()  # doctest: +SKIP
-        2
+        >>> s.argmax()
+        3
+
+        >>> s.argmax(skipna=False)
+        -1
         """
         sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
+        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
+        sdf = InternalFrame.attach_distributed_sequence_column(
+            sdf,
+            seq_col_name,
+        )
+        scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+        if skipna:
+            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+        else:
+            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
+
         max_value = sdf.select(
-            F.max(scol_for(sdf, self._internal.data_spark_column_names[0])),
+            F.first(scol),
             F.first(NATURAL_ORDER_COLUMN_NAME),
         ).head()
+
         if max_value[1] is None:
             raise ValueError("attempt to get argmax of an empty sequence")
         elif max_value[0] is None:
             return -1
-        # We should remember the natural sequence started from 0
-        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
-        sdf = InternalFrame.attach_distributed_sequence_column(
-            sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
-        )
+
         # If the maximum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(
-            scol_for(sdf, self._internal.data_spark_column_names[0]) == max_value[0]
-        ).head()[0]
+        return sdf.filter(scol == max_value[0]).head()[0]
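
For context, a minimal standalone sketch (not part of the PR) of how the `desc_nulls_last` / `desc_nulls_first` ordering above drives the `skipna` behavior; the data and names are illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(100.0,), (110.0,), (None,), (120.0,)], ["value"])

# skipna=True path: nulls sort last, so F.first sees the true maximum.
# (Mirrors the select(F.first(...)) pattern in the diff.)
print(df.orderBy(F.col("value").desc_nulls_last()).select(F.first("value")).head()[0])   # 120.0

# skipna=False path: nulls sort first, so a null wins and argmax maps it to -1.
print(df.orderBy(F.col("value").desc_nulls_first()).select(F.first("value")).head()[0])  # None
```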

Review Comment:
   Yeah, I think maybe we can add a util such as `max_by`, if we only need the first field of `max_value`
   
   something like:
   
   ```python
   max_value = max_by(sdf, scol)
   ```
   
   But in this scenario, we may also need to check the second field of `max_value` for validation:
   ```python
           if max_value[1] is None:
               raise ValueError("attempt to get argmax of an empty sequence")
   ```
   
   Or we could explicitly use another name, `max_row` or something, instead of `max_value` for the row we fetch first, to avoid confusion.
   
   e.g.
   
   ```python
           max_row = sdf.select(
               F.first(scol),
               F.first(NATURAL_ORDER_COLUMN_NAME),
           ).head()
   
           max_value = max_row[0]
           if max_row[1] is None:
               raise ValueError("attempt to get argmax of an empty sequence")
           elif max_value is None:
               return -1
   
           # If the maximum is achieved in multiple locations, the first row position is returned.
           return sdf.filter(scol == max_value).head()[0]
   ```
   
   WDYT??
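   
   For illustration, a rough sketch of what such a `max_by` helper could look like; the name, signature, and placement are assumptions, not an existing util:
   
   ```python
   from pyspark.sql import Column, DataFrame, Row
   from pyspark.sql import functions as F
   
   
   def max_by(sdf: DataFrame, scol: Column, order_col_name: str) -> Row:
       # Assumes `sdf` is already ordered so that the desired "max" row sorts
       # first (desc_nulls_last / desc_nulls_first, as in argmax above).
       # The global F.first aggregation always yields exactly one row; on an
       # empty frame both fields come back null, which is why the caller checks
       # the natural-order field to distinguish "empty" from "max is NaN".
       return sdf.select(F.first(scol), F.first(order_col_name)).head()
   ```
   
   `argmax` would then check `row[1]` for the empty-sequence case and `row[0]` for the all-NaN case, exactly as in the `max_row` snippet above.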



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org