You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/02 03:36:55 UTC
[spark] branch master updated: [SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c61da89eddc [SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior
c61da89eddc is described below
commit c61da89eddcd62d00b27531a1e7ea03548b73fc8
Author: Yikun Jiang <yi...@gmail.com>
AuthorDate: Thu Jun 2 12:36:43 2022 +0900
[SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior
### What changes were proposed in this pull request?
Respect ps.concat sort parameter to follow pandas behavior:
- Remove the multi-index special sort process case and add ut.
- Still keep `num_series != 1` for now to follow pandas behavior
### Why are the changes needed?
Since pandas 1.4+ (https://github.com/pandas-dev/pandas/commit/01b8d2a77e5109adda2504b1cb4b1daeab3c74df), the ps.concat method respects the sort parameter. We need to follow pandas behavior.
### Does this PR introduce _any_ user-facing change?
Yes, but it follows pandas 1.4 behavior.
### How was this patch tested?
test_concat_index_axis, test_concat_multiindex_sort, concat doctest
passed with 1.3/1.4
Closes #36711 from Yikun/SPARK-39314.
Authored-by: Yikun Jiang <yi...@gmail.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/pandas/namespace.py | 13 +++-----
python/pyspark/pandas/tests/test_namespace.py | 46 +++++++++++++++++++++++++--
2 files changed, 48 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index 340e270ace5..0f5a979df79 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -2608,9 +2608,8 @@ def concat(
label for label in column_labels_of_psdfs[0] if label in interested_columns
]
- # When multi-index column, although pandas is flaky if `join="inner" and sort=False`,
- # always sort to follow the `join="outer"` case behavior.
- if (len(merged_columns) > 0 and len(merged_columns[0]) > 1) or sort:
+ # If sort is True, sort to follow pandas 1.4+ behavior.
+ if sort:
# FIXME: better ordering
merged_columns = sorted(merged_columns, key=name_like_string)
@@ -2622,11 +2621,9 @@ def concat(
assert len(merged_columns) > 0
- # Always sort when multi-index columns or there are more than two Series,
- # and if there is only one Series, never sort.
- sort = len(merged_columns[0]) > 1 or num_series > 1 or (num_series != 1 and sort)
-
- if sort:
+ # If sort is True, always sort when there are more than two Series,
+ # and if there is only one Series, never sort to follow pandas 1.4+ behavior.
+ if sort and num_series != 1:
# FIXME: better ordering
merged_columns = sorted(merged_columns, key=name_like_string)
diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py
index 8c5adb9bae5..4db756c6e66 100644
--- a/python/pyspark/pandas/tests/test_namespace.py
+++ b/python/pyspark/pandas/tests/test_namespace.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+from distutils.version import LooseVersion
import itertools
import inspect
@@ -295,6 +296,28 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
AssertionError, lambda: ps.timedelta_range(start="1 day", periods=3, freq="ns")
)
+ def test_concat_multiindex_sort(self):
+ # SPARK-39314: Respect ps.concat sort parameter to follow pandas behavior
+ idx = pd.MultiIndex.from_tuples([("Y", "A"), ("Y", "B"), ("X", "C"), ("X", "D")])
+ pdf = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
+ psdf = ps.from_pandas(pdf)
+
+ ignore_indexes = [True, False]
+ joins = ["inner", "outer"]
+ sorts = [True]
+ if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+ sorts += [False]
+ objs = [
+ ([psdf, psdf.reset_index()], [pdf, pdf.reset_index()]),
+ ([psdf.reset_index(), psdf], [pdf.reset_index(), pdf]),
+ ]
+ for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
+ for i, (psdfs, pdfs) in enumerate(objs):
+ self.assert_eq(
+ ps.concat(psdfs, ignore_index=ignore_index, join=join, sort=sort),
+ pd.concat(pdfs, ignore_index=ignore_index, join=join, sort=sort),
+ )
+
def test_concat_index_axis(self):
pdf = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5], "C": [6, 7, 8]})
# TODO: pdf.columns.names = ["ABC"]
@@ -306,16 +329,29 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
objs = [
([psdf, psdf], [pdf, pdf]),
+ # no Series
([psdf, psdf.reset_index()], [pdf, pdf.reset_index()]),
([psdf.reset_index(), psdf], [pdf.reset_index(), pdf]),
([psdf, psdf[["C", "A"]]], [pdf, pdf[["C", "A"]]]),
([psdf[["C", "A"]], psdf], [pdf[["C", "A"]], pdf]),
+ # only one Series
([psdf, psdf["C"]], [pdf, pdf["C"]]),
([psdf["C"], psdf], [pdf["C"], pdf]),
+ # more than two Series
([psdf["C"], psdf, psdf["A"]], [pdf["C"], pdf, pdf["A"]]),
- ([psdf, psdf["C"], psdf["A"]], [pdf, pdf["C"], pdf["A"]]),
]
+ if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+ # more than two Series
+ psdfs, pdfs = ([psdf, psdf["C"], psdf["A"]], [pdf, pdf["C"], pdf["A"]])
+ for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
+ # See also https://github.com/pandas-dev/pandas/issues/47127
+ if (join, sort) != ("outer", True):
+ self.assert_eq(
+ ps.concat(psdfs, ignore_index=ignore_index, join=join, sort=sort),
+ pd.concat(pdfs, ignore_index=ignore_index, join=join, sort=sort),
+ )
+
for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
for i, (psdfs, pdfs) in enumerate(objs):
with self.subTest(
@@ -350,11 +386,15 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
objs = [
([psdf3, psdf3], [pdf3, pdf3]),
([psdf3, psdf3.reset_index()], [pdf3, pdf3.reset_index()]),
- ([psdf3.reset_index(), psdf3], [pdf3.reset_index(), pdf3]),
([psdf3, psdf3[[("Y", "C"), ("X", "A")]]], [pdf3, pdf3[[("Y", "C"), ("X", "A")]]]),
- ([psdf3[[("Y", "C"), ("X", "A")]], psdf3], [pdf3[[("Y", "C"), ("X", "A")]], pdf3]),
]
+ if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+ objs += [
+ ([psdf3.reset_index(), psdf3], [pdf3.reset_index(), pdf3]),
+ ([psdf3[[("Y", "C"), ("X", "A")]], psdf3], [pdf3[[("Y", "C"), ("X", "A")]], pdf3]),
+ ]
+
for ignore_index, sort in itertools.product(ignore_indexes, sorts):
for i, (psdfs, pdfs) in enumerate(objs):
with self.subTest(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org