You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/08/09 02:10:48 UTC

[spark] branch master updated: [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a9f371c  [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
a9f371c is described below

commit a9f371c2470ce28251012dea7428ff9be80bf3e5
Author: itholic <ha...@databricks.com>
AuthorDate: Mon Aug 9 11:10:01 2021 +0900

    [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
    
    ### What changes were proposed in this pull request?
    
    This PR proposes fixing `Index.union` to follow the behavior of pandas 1.3.
    
    Before:
    ```python
    >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
    >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
    >>> ps_idx1.union(ps_idx2)
    Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64')
    ```
    
    After:
    ```python
    >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
    >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
    >>> ps_idx1.union(ps_idx2)
    Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64')
    ```
    
    This bug (union dropping duplicated values) was fixed in pandas 1.3; see https://github.com/pandas-dev/pandas/issues/36289.
    
    ### Why are the changes needed?
    
    We should follow the behavior of pandas as much as possible.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The result will change in cases where the indexes contain duplicate values.
    
    ### How was this patch tested?
    
    Updated the unit tests in `python/pyspark/pandas/tests/indexes/test_base.py`.
    
    Closes #33634 from itholic/SPARK-36369.
    
    Authored-by: itholic <ha...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/indexes/base.py            |   4 +-
 python/pyspark/pandas/tests/indexes/test_base.py | 130 ++++++++++++-----------
 2 files changed, 68 insertions(+), 66 deletions(-)

diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 6c842bc..ccdb54c 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -2292,9 +2292,7 @@ class Index(IndexOpsMixin):
 
         sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
         sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
-        sdf = sdf_self.union(sdf_other.subtract(sdf_self))
-        if isinstance(self, MultiIndex):
-            sdf = sdf.drop_duplicates()
+        sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
         if sort:
             sdf = sdf.sort(*self._internal.index_spark_column_names)
 
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index 93aae86..99bfbaa 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -1527,21 +1527,20 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
                 almost=True,
             )
 
-            if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-                # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-                pass
-            else:
-                self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
-                self.assert_eq(
-                    psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
-                    pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
-                    almost=True,
-                )
-                self.assert_eq(
-                    psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
-                    pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
-                    almost=True,
-                )
+            # Manually create the expected result here since there is a bug in Index.union
+            # dropping duplicated values in pandas < 1.3.
+            expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
+            self.assert_eq(psidx2.union(psidx1), expected)
+            self.assert_eq(
+                psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
+                expected,
+                almost=True,
+            )
+            self.assert_eq(
+                psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
+                expected,
+                almost=True,
+            )
 
         # MultiIndex
         pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
@@ -1553,80 +1552,85 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
         psmidx3 = ps.from_pandas(pmidx3)
         psmidx4 = ps.from_pandas(pmidx4)
 
-        if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
-        else:
-            self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
-            self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
-            self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
-            self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
-            self.assert_eq(
-                psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
-                pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
-            )
-            self.assert_eq(
-                psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
-                pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
-            )
-            self.assert_eq(
-                psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
-                pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
-            )
-            self.assert_eq(
-                psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
-                pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
-            )
+        # Manually create the expected result here since there is a bug in MultiIndex.union
+        # dropping duplicated values in pandas < 1.3.
+        expected = pd.MultiIndex.from_tuples(
+            [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
+        )
+        self.assert_eq(psmidx1.union(psmidx2), expected)
+        self.assert_eq(psmidx2.union(psmidx1), expected)
+        self.assert_eq(
+            psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
+            expected,
+        )
+        self.assert_eq(
+            psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
+            expected,
+        )
+
+        expected = pd.MultiIndex.from_tuples(
+            [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
+        )
+        self.assert_eq(psmidx3.union(psmidx4), expected)
+        self.assert_eq(psmidx4.union(psmidx3), expected)
+        self.assert_eq(
+            psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
+            expected,
+        )
+        self.assert_eq(
+            psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
+            expected,
+        )
 
-        if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
         # Testing if the result is correct after sort=False.
         # The `sort` argument is added in pandas 0.24.
-        elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
+        if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
+            # Manually create the expected result here since there is a bug in MultiIndex.union
+            # dropping duplicated values in pandas < 1.3.
+            expected = pd.MultiIndex.from_tuples(
+                [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
+            )
             self.assert_eq(
                 psmidx1.union(psmidx2, sort=False).sort_values(),
-                pmidx1.union(pmidx2, sort=False).sort_values(),
+                expected,
             )
             self.assert_eq(
                 psmidx2.union(psmidx1, sort=False).sort_values(),
-                pmidx2.union(pmidx1, sort=False).sort_values(),
-            )
-            self.assert_eq(
-                psmidx3.union(psmidx4, sort=False).sort_values(),
-                pmidx3.union(pmidx4, sort=False).sort_values(),
-            )
-            self.assert_eq(
-                psmidx4.union(psmidx3, sort=False).sort_values(),
-                pmidx4.union(pmidx3, sort=False).sort_values(),
+                expected,
             )
             self.assert_eq(
                 psmidx1.union(
                     [("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
                 ).sort_values(),
-                pmidx1.union(
-                    [("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
-                ).sort_values(),
+                expected,
             )
             self.assert_eq(
                 psmidx2.union(
                     [("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
                 ).sort_values(),
-                pmidx2.union(
-                    [("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
-                ).sort_values(),
+                expected,
+            )
+
+            expected = pd.MultiIndex.from_tuples(
+                [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
+            )
+            self.assert_eq(
+                psmidx3.union(psmidx4, sort=False).sort_values(),
+                expected,
+            )
+            self.assert_eq(
+                psmidx4.union(psmidx3, sort=False).sort_values(),
+                expected,
             )
             self.assert_eq(
                 psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
-                pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
+                expected,
             )
             self.assert_eq(
                 psmidx4.union(
                     [(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
                 ).sort_values(),
-                pmidx4.union(
-                    [(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
-                ).sort_values(),
+                expected,
             )
 
         self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org