You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/08/01 01:47:58 UTC
[spark] branch master updated: [SPARK-39907][PS] Implement axis and skipna of Series.argmin
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3021eb197f4 [SPARK-39907][PS] Implement axis and skipna of Series.argmin
3021eb197f4 is described below
commit 3021eb197f488bc5342549a720c6348e65d160ee
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Aug 1 10:47:34 2022 +0900
[SPARK-39907][PS] Implement axis and skipna of Series.argmin
### What changes were proposed in this pull request?
1. Implement the `axis` and `skipna` parameters of `Series.argmin`.
2. Compute the argmin in a single pass, like `argmax`.
### Why are the changes needed?
To add the missing parameters.
After this change, the underlying implementations of `argmax` and `argmin` are almost the same.
### Does this PR introduce _any_ user-facing change?
Yes, new parameters (`axis` and `skipna`) for `Series.argmin`.
### How was this patch tested?
Added tests.
Closes #37328 from zhengruifeng/ps_update_argmin.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/series.py | 48 +++++++++++++++++++-----------
python/pyspark/pandas/tests/test_series.py | 13 ++++++++
2 files changed, 43 insertions(+), 18 deletions(-)
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 9bc1d7675e1..405dbbf23bf 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6266,11 +6266,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
Parameters
----------
- axis : {{None}}
+ axis : None
Dummy argument for consistency with Series.
skipna : bool, default True
- Exclude NA/null values. If the entire Series is NA, the result
- will be NA.
+ Exclude NA/null values.
Returns
-------
@@ -6322,13 +6321,20 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
# If the maximum is achieved in multiple locations, the first row position is returned.
return -1 if max_value[0] is None else max_value[1]
- def argmin(self) -> int:
+ def argmin(self, axis: Axis = None, skipna: bool = True) -> int:
"""
Return int position of the smallest value in the Series.
If the minimum is achieved in multiple locations,
the first row position is returned.
+ Parameters
+ ----------
+ axis : None
+ Dummy argument for consistency with Series.
+ skipna : bool, default True
+ Exclude NA/null values.
+
Returns
-------
int
@@ -6350,24 +6356,30 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
>>> s.argmin() # doctest: +SKIP
0
"""
+ axis = validate_axis(axis, none_axis=0)
+ if axis == 1:
+ raise ValueError("axis can only be 0 or 'index'")
sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
- min_value = sdf.select(
- F.min(scol_for(sdf, self._internal.data_spark_column_names[0])),
- F.first(NATURAL_ORDER_COLUMN_NAME),
- ).head()
- if min_value[1] is None:
- raise ValueError("attempt to get argmin of an empty sequence")
- elif min_value[0] is None:
- return -1
- # We should remember the natural sequence started from 0
seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
sdf = InternalFrame.attach_distributed_sequence_column(
- sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
+ sdf,
+ seq_col_name,
)
- # If the minimum is achieved in multiple locations, the first row position is returned.
- return sdf.filter(
- scol_for(sdf, self._internal.data_spark_column_names[0]) == min_value[0]
- ).head()[0]
+ scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+ if skipna:
+ sdf = sdf.orderBy(scol.asc_nulls_last(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
+ else:
+ sdf = sdf.orderBy(scol.asc_nulls_first(), NATURAL_ORDER_COLUMN_NAME, seq_col_name)
+
+ results = sdf.select(scol, seq_col_name).take(1)
+
+ if len(results) == 0:
+ raise ValueError("attempt to get argmin of an empty sequence")
+ else:
+ min_value = results[0]
+ # If the minimum is achieved in multiple locations, the first row position is returned.
+ return -1 if min_value[0] is None else min_value[1]
def compare(
self, other: "Series", keep_shape: bool = False, keep_equal: bool = False
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 5012dee785e..fe088bf54eb 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3012,8 +3012,12 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
psser = ps.from_pandas(pser)
self.assert_eq(pser.argmin(), psser.argmin())
self.assert_eq(pser.argmax(), psser.argmax())
+ self.assert_eq(pser.argmin(skipna=False), psser.argmin(skipna=False))
+ self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
self.assert_eq((pser + 1).argmax(skipna=False), (psser + 1).argmax(skipna=False))
+ self.assert_eq(pser.argmin(skipna=False), psser.argmin(skipna=False))
+ self.assert_eq((pser + 1).argmin(skipna=False), (psser + 1).argmin(skipna=False))
# MultiIndex
pser.index = pd.MultiIndex.from_tuples(
@@ -3024,6 +3028,13 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
self.assert_eq(pser.argmax(), psser.argmax())
self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
+ pser2 = pd.Series([np.NaN, 1.0, 2.0, np.NaN])
+ psser2 = ps.from_pandas(pser2)
+ self.assert_eq(pser2.argmin(), psser2.argmin())
+ self.assert_eq(pser2.argmax(), psser2.argmax())
+ self.assert_eq(pser2.argmin(skipna=False), psser2.argmin(skipna=False))
+ self.assert_eq(pser2.argmax(skipna=False), psser2.argmax(skipna=False))
+
# Null Series
self.assert_eq(pd.Series([np.nan]).argmin(), ps.Series([np.nan]).argmin())
self.assert_eq(pd.Series([np.nan]).argmax(), ps.Series([np.nan]).argmax())
@@ -3037,6 +3048,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
ps.Series([]).argmax()
with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"):
psser.argmax(axis=1)
+ with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"):
+ psser.argmin(axis=1)
def test_backfill(self):
pdf = pd.DataFrame({"x": [np.nan, 2, 3, 4, np.nan, 6]})
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org