You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/08/09 02:10:48 UTC
[spark] branch master updated: [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a9f371c [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
a9f371c is described below
commit a9f371c2470ce28251012dea7428ff9be80bf3e5
Author: itholic <ha...@databricks.com>
AuthorDate: Mon Aug 9 11:10:01 2021 +0900
[SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
### What changes were proposed in this pull request?
This PR proposes fixing the `Index.union` to follow the behavior of pandas 1.3.
Before:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64')
```
After:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64')
```
The corresponding pandas bug, where `Index.union` dropped duplicated values in pandas < 1.3, was fixed upstream in https://github.com/pandas-dev/pandas/issues/36289.
### Why are the changes needed?
The pandas API on Spark should follow the behavior of pandas as closely as possible.
### Does this PR introduce _any_ user-facing change?
Yes. The result of `Index.union` will change in cases where the indexes contain duplicate values.
### How was this patch tested?
Updated the existing unit tests in `python/pyspark/pandas/tests/indexes/test_base.py`.
Closes #33634 from itholic/SPARK-36369.
Authored-by: itholic <ha...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/indexes/base.py | 4 +-
python/pyspark/pandas/tests/indexes/test_base.py | 130 ++++++++++++-----------
2 files changed, 68 insertions(+), 66 deletions(-)
diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py
index 6c842bc..ccdb54c 100644
--- a/python/pyspark/pandas/indexes/base.py
+++ b/python/pyspark/pandas/indexes/base.py
@@ -2292,9 +2292,7 @@ class Index(IndexOpsMixin):
sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
- sdf = sdf_self.union(sdf_other.subtract(sdf_self))
- if isinstance(self, MultiIndex):
- sdf = sdf.drop_duplicates()
+ sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
if sort:
sdf = sdf.sort(*self._internal.index_spark_column_names)
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index 93aae86..99bfbaa 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -1527,21 +1527,20 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
almost=True,
)
- if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
- else:
- self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
- self.assert_eq(
- psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
- pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
- almost=True,
- )
- self.assert_eq(
- psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
- pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
- almost=True,
- )
+ # Manually create the expected result here since there is a bug in Index.union
+ # dropping duplicated values in pandas < 1.3.
+ expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
+ self.assert_eq(psidx2.union(psidx1), expected)
+ self.assert_eq(
+ psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
+ expected,
+ almost=True,
+ )
+ self.assert_eq(
+ psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
+ expected,
+ almost=True,
+ )
# MultiIndex
pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
@@ -1553,80 +1552,85 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
psmidx3 = ps.from_pandas(pmidx3)
psmidx4 = ps.from_pandas(pmidx4)
- if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
- else:
- self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
- self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
- self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
- self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
- self.assert_eq(
- psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
- pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
- )
- self.assert_eq(
- psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
- pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
- )
- self.assert_eq(
- psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
- pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
- )
- self.assert_eq(
- psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
- pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
- )
+ # Manually create the expected result here since there is a bug in MultiIndex.union
+ # dropping duplicated values in pandas < 1.3.
+ expected = pd.MultiIndex.from_tuples(
+ [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
+ )
+ self.assert_eq(psmidx1.union(psmidx2), expected)
+ self.assert_eq(psmidx2.union(psmidx1), expected)
+ self.assert_eq(
+ psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
+ expected,
+ )
+ self.assert_eq(
+ psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
+ expected,
+ )
+
+ expected = pd.MultiIndex.from_tuples(
+ [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
+ )
+ self.assert_eq(psmidx3.union(psmidx4), expected)
+ self.assert_eq(psmidx4.union(psmidx3), expected)
+ self.assert_eq(
+ psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
+ expected,
+ )
+ self.assert_eq(
+ psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
+ expected,
+ )
- if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
- # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
- pass
# Testing if the result is correct after sort=False.
# The `sort` argument is added in pandas 0.24.
- elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
+ if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
+ # Manually create the expected result here since there is a bug in MultiIndex.union
+ # dropping duplicated values in pandas < 1.3.
+ expected = pd.MultiIndex.from_tuples(
+ [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
+ )
self.assert_eq(
psmidx1.union(psmidx2, sort=False).sort_values(),
- pmidx1.union(pmidx2, sort=False).sort_values(),
+ expected,
)
self.assert_eq(
psmidx2.union(psmidx1, sort=False).sort_values(),
- pmidx2.union(pmidx1, sort=False).sort_values(),
- )
- self.assert_eq(
- psmidx3.union(psmidx4, sort=False).sort_values(),
- pmidx3.union(pmidx4, sort=False).sort_values(),
- )
- self.assert_eq(
- psmidx4.union(psmidx3, sort=False).sort_values(),
- pmidx4.union(pmidx3, sort=False).sort_values(),
+ expected,
)
self.assert_eq(
psmidx1.union(
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
).sort_values(),
- pmidx1.union(
- [("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
- ).sort_values(),
+ expected,
)
self.assert_eq(
psmidx2.union(
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
).sort_values(),
- pmidx2.union(
- [("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
- ).sort_values(),
+ expected,
+ )
+
+ expected = pd.MultiIndex.from_tuples(
+ [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
+ )
+ self.assert_eq(
+ psmidx3.union(psmidx4, sort=False).sort_values(),
+ expected,
+ )
+ self.assert_eq(
+ psmidx4.union(psmidx3, sort=False).sort_values(),
+ expected,
)
self.assert_eq(
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
- pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
+ expected,
)
self.assert_eq(
psmidx4.union(
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
).sort_values(),
- pmidx4.union(
- [(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
- ).sort_values(),
+ expected,
)
self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org