You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/10/13 09:32:14 UTC
[spark] branch master updated: [SPARK-36438][PYTHON] Support
list-like Python objects for Series comparison
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 46bcef7 [SPARK-36438][PYTHON] Support list-like Python objects for Series comparison
46bcef7 is described below
commit 46bcef7472edd40c23afd9ac74cffe13c6a608ad
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Oct 13 18:31:26 2021 +0900
[SPARK-36438][PYTHON] Support list-like Python objects for Series comparison
### What changes were proposed in this pull request?
This PR proposes to implement `Series` comparison with list-like Python objects.
Currently, `Series` does not support comparison with list-like Python objects such as `list`, `tuple`, `dict`, and `set`.
**Before**
```python
>>> psser
0 1
1 2
2 3
dtype: int64
>>> psser == [3, 2, 1]
Traceback (most recent call last):
...
TypeError: The operation can not be applied to list.
...
```
**After**
```python
>>> psser
0 1
1 2
2 3
dtype: int64
>>> psser == [3, 2, 1]
0 False
1 True
2 False
dtype: bool
```
This was originally proposed in https://github.com/databricks/koalas/pull/2022, and all review comments in the original PR have been resolved.
### Why are the changes needed?
To follow pandas' behavior.
### Does this PR introduce _any_ user-facing change?
Yes, `Series` comparison with list-like Python objects is now possible.
### How was this patch tested?
Unit tests
Closes #34114 from itholic/SPARK-36438.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/base.py | 6 +-
python/pyspark/pandas/data_type_ops/base.py | 95 +++++++++++++++++++++-
python/pyspark/pandas/series.py | 2 +-
.../pandas/tests/test_ops_on_diff_frames.py | 37 +++++++++
python/pyspark/pandas/tests/test_series.py | 50 ++++++++++++
5 files changed, 184 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 86f97ef..69e751e 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -394,7 +394,11 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
# comparison operators
def __eq__(self, other: Any) -> SeriesOrIndex: # type: ignore[override]
- return self._dtype_op.eq(self, other)
+ # pandas always returns False for all items with dict and set.
+ if isinstance(other, (dict, set)):
+ return self != self
+ else:
+ return self._dtype_op.eq(self, other)
def __ne__(self, other: Any) -> SeriesOrIndex: # type: ignore[override]
return self._dtype_op.ne(self, other)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 7900432..47a6671 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -376,11 +376,98 @@ class DataTypeOps(object, metaclass=ABCMeta):
raise TypeError(">= can not be applied to %s." % self.pretty_name)
def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
- from pyspark.pandas.base import column_op
-
- _sanitize_list_like(right)
+ if isinstance(right, (list, tuple)):
+ from pyspark.pandas.series import first_series, scol_for
+ from pyspark.pandas.frame import DataFrame
+ from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField
+
+ len_right = len(right)
+ if len(left) != len(right):
+ raise ValueError("Lengths must be equal")
+
+ sdf = left._internal.spark_frame
+ structed_scol = F.struct(
+ sdf[NATURAL_ORDER_COLUMN_NAME],
+ *left._internal.index_spark_columns,
+ left.spark.column
+ )
+ # The size of the list is expected to be small.
+ collected_structed_scol = F.collect_list(structed_scol)
+ # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
+ collected_structed_scol = F.array_sort(collected_structed_scol)
+ right_values_scol = F.array([F.lit(x) for x in right]) # type: ignore
+ index_scol_names = left._internal.index_spark_column_names
+ scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
+ # Compare the values of left and right by using zip_with function.
+ cond = F.zip_with(
+ collected_structed_scol,
+ right_values_scol,
+ lambda x, y: F.struct(
+ *[
+ x[index_scol_name].alias(index_scol_name)
+ for index_scol_name in index_scol_names
+ ],
+ F.when(x[scol_name].isNull() | y.isNull(), False)
+ .otherwise(
+ x[scol_name] == y,
+ )
+ .alias(scol_name)
+ ),
+ ).alias(scol_name)
+ # 1. `sdf_new` here looks like the below (the first field of each set is Index):
+ # +----------------------------------------------------------+
+ # |0 |
+ # +----------------------------------------------------------+
+ # |[{0, false}, {1, true}, {2, false}, {3, true}, {4, false}]|
+ # +----------------------------------------------------------+
+ sdf_new = sdf.select(cond)
+ # 2. `sdf_new` after the explode looks like the below:
+ # +----------+
+ # | col|
+ # +----------+
+ # |{0, false}|
+ # | {1, true}|
+ # |{2, false}|
+ # | {3, true}|
+ # |{4, false}|
+ # +----------+
+ sdf_new = sdf_new.select(F.explode(scol_name))
+ # 3. Here, the final `sdf_new` looks like the below:
+ # +-----------------+-----+
+ # |__index_level_0__| 0|
+ # +-----------------+-----+
+ # | 0|false|
+ # | 1| true|
+ # | 2|false|
+ # | 3| true|
+ # | 4|false|
+ # +-----------------+-----+
+ sdf_new = sdf_new.select("col.*")
+
+ index_spark_columns = [
+ scol_for(sdf_new, index_scol_name) for index_scol_name in index_scol_names
+ ]
+ data_spark_columns = [scol_for(sdf_new, scol_name)]
+
+ internal = left._internal.copy(
+ spark_frame=sdf_new,
+ index_spark_columns=index_spark_columns,
+ data_spark_columns=data_spark_columns,
+ index_fields=[
+ InternalField.from_struct_field(index_field)
+ for index_field in sdf_new.select(index_spark_columns).schema.fields
+ ],
+ data_fields=[
+ InternalField.from_struct_field(
+ sdf_new.select(data_spark_columns).schema.fields[0]
+ )
+ ],
+ )
+ return first_series(DataFrame(internal))
+ else:
+ from pyspark.pandas.base import column_op
- return column_op(Column.__eq__)(left, right)
+ return column_op(Column.__eq__)(left, right)
def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 9e20525..f6defe4 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -675,7 +675,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
koalas = CachedAccessor("koalas", PandasOnSparkSeriesMethods)
# Comparison Operators
- def eq(self, other: Any) -> bool:
+ def eq(self, other: Any) -> "Series":
"""
Compare if the current value is equal to the other.
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index cd5d834..ba7f88e 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -1845,6 +1845,29 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
pscov = psser1.cov(psser2, min_periods=3)
self.assert_eq(pcov, pscov, almost=True)
+ def test_series_eq(self):
+ pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+ psser = ps.from_pandas(pser)
+
+ # other = Series
+ pandas_other = pd.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")
+ pandas_on_spark_other = ps.from_pandas(pandas_other)
+ self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+ self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
+ # other = Series with different Index
+ pandas_other = pd.Series(
+ [np.nan, 1, 3, 4, np.nan, 6], index=[10, 20, 30, 40, 50, 60], name="x"
+ )
+ pandas_on_spark_other = ps.from_pandas(pandas_other)
+ self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+
+ # other = Index
+ pandas_other = pd.Index([np.nan, 1, 3, 4, np.nan, 6], name="x")
+ pandas_on_spark_other = ps.from_pandas(pandas_other)
+ self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+ self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
@classmethod
@@ -2039,6 +2062,20 @@ class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
psdf1.combine_first(psdf2)
+ def test_series_eq(self):
+ pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+ psser = ps.from_pandas(pser)
+
+ others = (
+ ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+ ps.Index([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+ )
+ for other in others:
+ with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+ psser.eq(other)
+ with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+ psser == other
+
if __name__ == "__main__":
from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index aba27fa..0ec8d71 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3071,6 +3071,56 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
pscov = psdf["s1"].cov(psdf["s2"], min_periods=4)
self.assert_eq(pcov, pscov, almost=True)
+ def test_eq(self):
+ pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+ psser = ps.from_pandas(pser)
+
+ # other = Series
+ self.assert_eq(pser.eq(pser), psser.eq(psser))
+ self.assert_eq(pser == pser, psser == psser)
+
+ # other = dict
+ other = {1: None, 2: None, 3: None, 4: None, np.nan: None, 6: None}
+ self.assert_eq(pser.eq(other), psser.eq(other))
+ self.assert_eq(pser == other, psser == other)
+
+ # other = set
+ other = {1, 2, 3, 4, np.nan, 6}
+ self.assert_eq(pser.eq(other), psser.eq(other))
+ self.assert_eq(pser == other, psser == other)
+
+ # other = list
+ other = [np.nan, 1, 3, 4, np.nan, 6]
+ if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+ self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+ self.assert_eq(pser == other, (psser == other).sort_index())
+ else:
+ self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+ self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+ # other = tuple
+ other = (np.nan, 1, 3, 4, np.nan, 6)
+ if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+ self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+ self.assert_eq(pser == other, (psser == other).sort_index())
+ else:
+ self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+ self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+ # other = list with the different length
+ other = [np.nan, 1, 3, 4, np.nan]
+ with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+ psser.eq(other)
+ with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+ psser == other
+
+ # other = tuple with the different length
+ other = (np.nan, 1, 3, 4, np.nan)
+ with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+ psser.eq(other)
+ with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+ psser == other
+
if __name__ == "__main__":
from pyspark.pandas.tests.test_series import * # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org