Posted to commits@spark.apache.org by gu...@apache.org on 2021/10/13 09:32:14 UTC

[spark] branch master updated: [SPARK-36438][PYTHON] Support list-like Python objects for Series comparison

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 46bcef7  [SPARK-36438][PYTHON] Support list-like Python objects for Series comparison
46bcef7 is described below

commit 46bcef7472edd40c23afd9ac74cffe13c6a608ad
Author: itholic <ha...@databricks.com>
AuthorDate: Wed Oct 13 18:31:26 2021 +0900

    [SPARK-36438][PYTHON] Support list-like Python objects for Series comparison
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement `Series` comparison with list-like Python objects.
    
    Currently `Series` doesn't support comparison with list-like Python objects such as `list`, `tuple`, `dict`, and `set`.
    
    **Before**
    
    ```python
    >>> psser
    0    1
    1    2
    2    3
    dtype: int64
    
    >>> psser == [3, 2, 1]
    Traceback (most recent call last):
    ...
    TypeError: The operation can not be applied to list.
    ...
    ```
    
    **After**
    
    ```python
    >>> psser
    0    1
    1    2
    2    3
    dtype: int64
    
    >>> psser == [3, 2, 1]
    0    False
    1     True
    2    False
    dtype: bool
    ```
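    
    Note that, following pandas, comparing against a `dict` or `set` returns False for every item (see the `__eq__` change in `base.py` below); a quick illustration with the same `psser` as above:
    
    ```python
    >>> psser == {1, 2, 3}
    0    False
    1    False
    2    False
    dtype: bool
    ```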
    
    This was originally proposed in https://github.com/databricks/koalas/pull/2022, and all review comments in the original PR have been resolved.
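    
    Internally, the list/tuple path collects the column into an array of structs tagged with the natural-order and index columns, sorts that array to restore row order, compares it element-wise against the literal values with `zip_with`, and explodes the result back into rows. Below is a minimal standalone sketch of that technique in plain PySpark; the DataFrame and column names here are illustrative, not the committed code:
    
    ```python
    from pyspark.sql import SparkSession, functions as F
    
    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame([(0, 1), (1, 2), (2, 3)], ["idx", "value"])
    right = [3, 2, 1]
    
    # Collect (idx, value) structs into a single array; sorting by the
    # leading `idx` field restores the original row order.
    collected = F.array_sort(F.collect_list(F.struct("idx", "value")))
    
    # Pair each collected struct with the corresponding literal value and
    # compare, keeping the index so the result can be turned back into rows.
    zipped = F.zip_with(
        collected,
        F.array([F.lit(x) for x in right]),
        lambda x, y: F.struct(x["idx"].alias("idx"), (x["value"] == y).alias("eq")),
    )
    
    result = (
        sdf.select(zipped.alias("zipped"))
        .select(F.explode("zipped").alias("col"))
        .select("col.*")
    )
    result.show()
    # +---+-----+
    # |idx|   eq|
    # +---+-----+
    # |  0|false|
    # |  1| true|
    # |  2|false|
    # +---+-----+
    ```
    
    As in pandas, comparing a `Series` against a `list` or `tuple` of a different length raises `ValueError("Lengths must be equal")`.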
    
    ### Why are the changes needed?
    
    To follow pandas' behavior.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `Series` comparison with list-like Python objects is now possible.
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #34114 from itholic/SPARK-36438.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/base.py                      |  6 +-
 python/pyspark/pandas/data_type_ops/base.py        | 95 +++++++++++++++++++++-
 python/pyspark/pandas/series.py                    |  2 +-
 .../pandas/tests/test_ops_on_diff_frames.py        | 37 +++++++++
 python/pyspark/pandas/tests/test_series.py         | 50 ++++++++++++
 5 files changed, 184 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 86f97ef..69e751e 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -394,7 +394,11 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
 
     # comparison operators
     def __eq__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
-        return self._dtype_op.eq(self, other)
+        # pandas always returns False for every item when compared with a dict or set.
+        if isinstance(other, (dict, set)):
+            return self != self
+        else:
+            return self._dtype_op.eq(self, other)
 
     def __ne__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
         return self._dtype_op.ne(self, other)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 7900432..47a6671 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -376,11 +376,98 @@ class DataTypeOps(object, metaclass=ABCMeta):
         raise TypeError(">= can not be applied to %s." % self.pretty_name)
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
-        _sanitize_list_like(right)
+        if isinstance(right, (list, tuple)):
+            from pyspark.pandas.series import first_series, scol_for
+            from pyspark.pandas.frame import DataFrame
+            from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField
+
+            len_right = len(right)
+            if len(left) != len_right:
+                raise ValueError("Lengths must be equal")
+
+            sdf = left._internal.spark_frame
+            structed_scol = F.struct(
+                sdf[NATURAL_ORDER_COLUMN_NAME],
+                *left._internal.index_spark_columns,
+                left.spark.column
+            )
+            # The size of the list is expected to be small.
+            collected_structed_scol = F.collect_list(structed_scol)
+            # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
+            collected_structed_scol = F.array_sort(collected_structed_scol)
+            right_values_scol = F.array([F.lit(x) for x in right])  # type: ignore
+            index_scol_names = left._internal.index_spark_column_names
+            scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
+            # Compare the values of left and right by using zip_with function.
+            cond = F.zip_with(
+                collected_structed_scol,
+                right_values_scol,
+                lambda x, y: F.struct(
+                    *[
+                        x[index_scol_name].alias(index_scol_name)
+                        for index_scol_name in index_scol_names
+                    ],
+                    F.when(x[scol_name].isNull() | y.isNull(), False)
+                    .otherwise(
+                        x[scol_name] == y,
+                    )
+                    .alias(scol_name)
+                ),
+            ).alias(scol_name)
+            # 1. `sdf_new` here looks like the below (the first field of each struct is the index):
+            # +----------------------------------------------------------+
+            # |0                                                         |
+            # +----------------------------------------------------------+
+            # |[{0, false}, {1, true}, {2, false}, {3, true}, {4, false}]|
+            # +----------------------------------------------------------+
+            sdf_new = sdf.select(cond)
+            # 2. `sdf_new` after the explode looks like the below:
+            # +----------+
+            # |       col|
+            # +----------+
+            # |{0, false}|
+            # | {1, true}|
+            # |{2, false}|
+            # | {3, true}|
+            # |{4, false}|
+            # +----------+
+            sdf_new = sdf_new.select(F.explode(scol_name))
+            # 3. Here, the final `sdf_new` looks like the below:
+            # +-----------------+-----+
+            # |__index_level_0__|    0|
+            # +-----------------+-----+
+            # |                0|false|
+            # |                1| true|
+            # |                2|false|
+            # |                3| true|
+            # |                4|false|
+            # +-----------------+-----+
+            sdf_new = sdf_new.select("col.*")
+
+            index_spark_columns = [
+                scol_for(sdf_new, index_scol_name) for index_scol_name in index_scol_names
+            ]
+            data_spark_columns = [scol_for(sdf_new, scol_name)]
+
+            internal = left._internal.copy(
+                spark_frame=sdf_new,
+                index_spark_columns=index_spark_columns,
+                data_spark_columns=data_spark_columns,
+                index_fields=[
+                    InternalField.from_struct_field(index_field)
+                    for index_field in sdf_new.select(index_spark_columns).schema.fields
+                ],
+                data_fields=[
+                    InternalField.from_struct_field(
+                        sdf_new.select(data_spark_columns).schema.fields[0]
+                    )
+                ],
+            )
+            return first_series(DataFrame(internal))
+        else:
+            from pyspark.pandas.base import column_op
 
-        return column_op(Column.__eq__)(left, right)
+            return column_op(Column.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 9e20525..f6defe4 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -675,7 +675,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
     koalas = CachedAccessor("koalas", PandasOnSparkSeriesMethods)
 
     # Comparison Operators
-    def eq(self, other: Any) -> bool:
+    def eq(self, other: Any) -> "Series":
         """
         Compare if the current value is equal to the other.
 
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index cd5d834..ba7f88e 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -1845,6 +1845,29 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
         pscov = psser1.cov(psser2, min_periods=3)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        pandas_other = pd.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
+        # other = Series with different Index
+        pandas_other = pd.Series(
+            [np.nan, 1, 3, 4, np.nan, 6], index=[10, 20, 30, 40, 50, 60], name="x"
+        )
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+
+        # other = Index
+        pandas_other = pd.Index([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
 
 class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
     @classmethod
@@ -2039,6 +2062,20 @@ class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
             psdf1.combine_first(psdf2)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        others = (
+            ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+            ps.Index([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+        )
+        for other in others:
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser.eq(other)
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_ops_on_diff_frames import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index aba27fa..0ec8d71 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3071,6 +3071,56 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         pscov = psdf["s1"].cov(psdf["s2"], min_periods=4)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        self.assert_eq(pser.eq(pser), psser.eq(psser))
+        self.assert_eq(pser == pser, psser == psser)
+
+        # other = dict
+        other = {1: None, 2: None, 3: None, 4: None, np.nan: None, 6: None}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = set
+        other = {1, 2, 3, 4, np.nan, 6}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = list
+        other = [np.nan, 1, 3, 4, np.nan, 6]
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = tuple
+        other = (np.nan, 1, 3, 4, np.nan, 6)
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = list with the different length
+        other = [np.nan, 1, 3, 4, np.nan]
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
+        # other = tuple with the different length
+        other = (np.nan, 1, 3, 4, np.nan)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_series import *  # noqa: F401
