Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/20 23:57:38 UTC

[spark] branch master updated: [SPARK-39534][PS] Series.argmax only needs single pass

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 95fba869169 [SPARK-39534][PS] Series.argmax only needs single pass
95fba869169 is described below

commit 95fba8691696f6c4c00927cbcd8fde81765f0252
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Jun 21 08:57:27 2022 +0900

    [SPARK-39534][PS] Series.argmax only needs single pass
    
    ### What changes were proposed in this pull request?
    Compute `Series.argmax` in a single pass.
    
    ### Why are the changes needed?
    The existing implementation of `Series.argmax` needs two passes over the dataset: the first computes the maximum value, and the second finds its index.
    However, both can be computed in a single pass.
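To make the two-pass versus single-pass distinction concrete, here is a minimal plain-Python sketch, not the pandas-on-Spark code. The function names are hypothetical, `None` stands in for a Spark null, and values are assumed to be comparable numbers without NaN. The null handling follows the diff: with `skipna=False` a null sorts first and `-1` is returned, and an all-null input also yields `-1`.

```python
from typing import Optional, Sequence


def argmax_two_pass(values: Sequence[Optional[float]], skipna: bool = True) -> int:
    """Two scans: one to find the maximum, one to find where it first occurs."""
    if len(values) == 0:
        raise ValueError("attempt to get argmax of an empty sequence")
    # Pass 1: compute the maximum, treating None as a Spark null.
    max_value = None
    for v in values:
        if v is None:
            if not skipna:
                return -1  # a null sorts first when nulls are not skipped
        elif max_value is None or v > max_value:
            max_value = v
    if max_value is None:
        return -1  # every value was null
    # Pass 2: scan again for the first position holding the maximum.
    for pos, v in enumerate(values):
        if v == max_value:
            return pos
    raise AssertionError("unreachable: the maximum must occur somewhere")


def argmax_single_pass(values: Sequence[Optional[float]], skipna: bool = True) -> int:
    """One scan that carries the best (value, position) pair seen so far."""
    if len(values) == 0:
        raise ValueError("attempt to get argmax of an empty sequence")
    best_value: Optional[float] = None
    best_pos = -1
    for pos, v in enumerate(values):
        if v is None:
            if not skipna:
                return -1  # same null handling as the two-pass version
            continue
        # Strict '>' keeps the earliest position when the maximum repeats.
        if best_value is None or v > best_value:
            best_value, best_pos = v, pos
    return best_pos  # still -1 if every value was null
```

The single-pass version never re-reads the data because it records the position at the moment a new maximum is found; the strict `>` keeps the first position when the maximum repeats, matching the comment in the diff.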
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #36927 from zhengruifeng/ps_series_argmax_opt.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/series.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 813d27709e4..352e7dd750b 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6301,22 +6301,18 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         scol = scol_for(sdf, self._internal.data_spark_column_names[0])
 
         if skipna:
-            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
         else:
-            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
+            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
 
-        max_value = sdf.select(
-            F.first(scol),
-            F.first(NATURAL_ORDER_COLUMN_NAME),
-        ).head()
+        results = sdf.select(scol, seq_col_name).take(1)
 
-        if max_value[1] is None:
+        if len(results) == 0:
             raise ValueError("attempt to get argmax of an empty sequence")
-        elif max_value[0] is None:
-            return -1
-
-        # If the maximum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(scol == max_value[0]).head()[0]
+        else:
+            max_value = results[0]
+            # If the maximum is achieved in multiple locations, the first row position is returned.
+            return -1 if max_value[0] is None else max_value[1]
 
     def argmin(self) -> int:
         """

