You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/08/01 01:47:58 UTC

[spark] branch master updated: [SPARK-39907][PS] Implement axis and skipna of Series.argmin

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3021eb197f4 [SPARK-39907][PS] Implement axis and skipna of Series.argmin
3021eb197f4 is described below

commit 3021eb197f488bc5342549a720c6348e65d160ee
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Aug 1 10:47:34 2022 +0900

    [SPARK-39907][PS] Implement axis and skipna of Series.argmin
    
    ### What changes were proposed in this pull request?
    
    1. Implement the `axis` and `skipna` parameters of `Series.argmin`.
    2. Compute the argmin in a single pass, as `argmax` already does.
    
    ### Why are the changes needed?
    
    To add the missing parameters.
    After this change, the underlying implementations of `argmax` and `argmin` are almost the same.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `Series.argmin` gains the new parameters `axis` and `skipna`.
    
    ### How was this patch tested?
    Added tests.
    
    Closes #37328 from zhengruifeng/ps_update_argmin.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/series.py            | 48 +++++++++++++++++++-----------
 python/pyspark/pandas/tests/test_series.py | 13 ++++++++
 2 files changed, 43 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 9bc1d7675e1..405dbbf23bf 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6266,11 +6266,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
 
         Parameters
         ----------
-        axis : {{None}}
+        axis : None
             Dummy argument for consistency with Series.
         skipna : bool, default True
-            Exclude NA/null values. If the entire Series is NA, the result
-            will be NA.
+            Exclude NA/null values.
 
         Returns
         -------
@@ -6322,13 +6321,20 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             # If the maximum is achieved in multiple locations, the first row position is returned.
             return -1 if max_value[0] is None else max_value[1]
 
-    def argmin(self) -> int:
+    def argmin(self, axis: Axis = None, skipna: bool = True) -> int:
         """
         Return int position of the smallest value in the Series.
 
         If the minimum is achieved in multiple locations,
         the first row position is returned.
 
+        Parameters
+        ----------
+        axis : None
+            Dummy argument for consistency with Series.
+        skipna : bool, default True
+            Exclude NA/null values.
+
         Returns
         -------
         int
@@ -6350,24 +6356,30 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         >>> s.argmin()  # doctest: +SKIP
         0
         """
+        axis = validate_axis(axis, none_axis=0)
+        if axis == 1:
+            raise ValueError("axis can only be 0 or 'index'")
         sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
-        min_value = sdf.select(
-            F.min(scol_for(sdf, self._internal.data_spark_column_names[0])),
-            F.first(NATURAL_ORDER_COLUMN_NAME),
-        ).head()
-        if min_value[1] is None:
-            raise ValueError("attempt to get argmin of an empty sequence")
-        elif min_value[0] is None:
-            return -1
-        # We should remember the natural sequence started from 0
         seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
         sdf = InternalFrame.attach_distributed_sequence_column(
-            sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
+            sdf,
+            seq_col_name,
         )
-        # If the minimum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(
-            scol_for(sdf, self._internal.data_spark_column_names[0]) == min_value[0]
-        ).head()[0]
+        scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+        if skipna:
+            sdf = sdf.orderBy(scol.asc_nulls_last(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
+        else:
+            sdf = sdf.orderBy(scol.asc_nulls_first(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
+
+        results = sdf.select(scol, seq_col_name).take(1)
+
+        if len(results) == 0:
+            raise ValueError("attempt to get argmin of an empty sequence")
+        else:
+            min_value = results[0]
+            # If the maximum is achieved in multiple locations, the first row position is returned.
+            return -1 if min_value[0] is None else min_value[1]
 
     def compare(
         self, other: "Series", keep_shape: bool = False, keep_equal: bool = False
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 5012dee785e..fe088bf54eb 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3012,8 +3012,12 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser = ps.from_pandas(pser)
         self.assert_eq(pser.argmin(), psser.argmin())
         self.assert_eq(pser.argmax(), psser.argmax())
+        self.assert_eq(pser.argmin(skipna=False), psser.argmin(skipna=False))
+        self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
         self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
         self.assert_eq((pser + 1).argmax(skipna=False), (psser + 1).argmax(skipna=False))
+        self.assert_eq(pser.argmin(skipna=False), psser.argmin(skipna=False))
+        self.assert_eq((pser + 1).argmin(skipna=False), (psser + 1).argmin(skipna=False))
 
         # MultiIndex
         pser.index = pd.MultiIndex.from_tuples(
@@ -3024,6 +3028,13 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(pser.argmax(), psser.argmax())
         self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
 
+        pser2 = pd.Series([np.NaN, 1.0, 2.0, np.NaN])
+        psser2 = ps.from_pandas(pser2)
+        self.assert_eq(pser2.argmin(), psser2.argmin())
+        self.assert_eq(pser2.argmax(), psser2.argmax())
+        self.assert_eq(pser2.argmin(skipna=False), psser2.argmin(skipna=False))
+        self.assert_eq(pser2.argmax(skipna=False), psser2.argmax(skipna=False))
+
         # Null Series
         self.assert_eq(pd.Series([np.nan]).argmin(), ps.Series([np.nan]).argmin())
         self.assert_eq(pd.Series([np.nan]).argmax(), ps.Series([np.nan]).argmax())
@@ -3037,6 +3048,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
             ps.Series([]).argmax()
         with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"):
             psser.argmax(axis=1)
+        with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"):
+            psser.argmin(axis=1)
 
     def test_backfill(self):
         pdf = pd.DataFrame({"x": [np.nan, 2, 3, 4, np.nan, 6]})


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org