You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/06/02 03:36:55 UTC

[spark] branch master updated: [SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c61da89eddc [SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior
c61da89eddc is described below

commit c61da89eddcd62d00b27531a1e7ea03548b73fc8
Author: Yikun Jiang <yi...@gmail.com>
AuthorDate: Thu Jun 2 12:36:43 2022 +0900

    [SPARK-39314][PS] Respect ps.concat sort parameter to follow pandas behavior
    
    ### What changes were proposed in this pull request?
    Respect ps.concat sort parameter to follow pandas behavior:
    - Remove the multi-index special sort process case and add ut.
    - Still keep `num_series != 1` for now to follow pandas behavior
    
    ### Why are the changes needed?
    
    Since pandas 1.4+ (https://github.com/pandas-dev/pandas/commit/01b8d2a77e5109adda2504b1cb4b1daeab3c74df), the `ps.concat` method respects the `sort` parameter. We need to follow pandas behavior.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, but it follows pandas 1.4 behavior.
    
    ### How was this patch tested?
    test_concat_index_axis, test_concat_multiindex_sort, concat doctest
    passed with 1.3/1.4
    
    Closes #36711 from Yikun/SPARK-39314.
    
    Authored-by: Yikun Jiang <yi...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/pandas/namespace.py            | 13 +++-----
 python/pyspark/pandas/tests/test_namespace.py | 46 +++++++++++++++++++++++++--
 2 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index 340e270ace5..0f5a979df79 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -2608,9 +2608,8 @@ def concat(
                 label for label in column_labels_of_psdfs[0] if label in interested_columns
             ]
 
-            # When multi-index column, although pandas is flaky if `join="inner" and sort=False`,
-            # always sort to follow the `join="outer"` case behavior.
-            if (len(merged_columns) > 0 and len(merged_columns[0]) > 1) or sort:
+            # If sort is True, sort to follow pandas 1.4+ behavior.
+            if sort:
                 # FIXME: better ordering
                 merged_columns = sorted(merged_columns, key=name_like_string)
 
@@ -2622,11 +2621,9 @@ def concat(
 
             assert len(merged_columns) > 0
 
-            # Always sort when multi-index columns or there are more than two Series,
-            # and if there is only one Series, never sort.
-            sort = len(merged_columns[0]) > 1 or num_series > 1 or (num_series != 1 and sort)
-
-            if sort:
+            # If sort is True, always sort when there are more than two Series,
+            # and if there is only one Series, never sort to follow pandas 1.4+ behavior.
+            if sort and num_series != 1:
                 # FIXME: better ordering
                 merged_columns = sorted(merged_columns, key=name_like_string)
 
diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py
index 8c5adb9bae5..4db756c6e66 100644
--- a/python/pyspark/pandas/tests/test_namespace.py
+++ b/python/pyspark/pandas/tests/test_namespace.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+from distutils.version import LooseVersion
 import itertools
 import inspect
 
@@ -295,6 +296,28 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
             AssertionError, lambda: ps.timedelta_range(start="1 day", periods=3, freq="ns")
         )
 
+    def test_concat_multiindex_sort(self):
+        # SPARK-39314: Respect ps.concat sort parameter to follow pandas behavior
+        idx = pd.MultiIndex.from_tuples([("Y", "A"), ("Y", "B"), ("X", "C"), ("X", "D")])
+        pdf = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], columns=idx)
+        psdf = ps.from_pandas(pdf)
+
+        ignore_indexes = [True, False]
+        joins = ["inner", "outer"]
+        sorts = [True]
+        if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+            sorts += [False]
+        objs = [
+            ([psdf, psdf.reset_index()], [pdf, pdf.reset_index()]),
+            ([psdf.reset_index(), psdf], [pdf.reset_index(), pdf]),
+        ]
+        for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
+            for i, (psdfs, pdfs) in enumerate(objs):
+                self.assert_eq(
+                    ps.concat(psdfs, ignore_index=ignore_index, join=join, sort=sort),
+                    pd.concat(pdfs, ignore_index=ignore_index, join=join, sort=sort),
+                )
+
     def test_concat_index_axis(self):
         pdf = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5], "C": [6, 7, 8]})
         # TODO: pdf.columns.names = ["ABC"]
@@ -306,16 +329,29 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
 
         objs = [
             ([psdf, psdf], [pdf, pdf]),
+            # no Series
             ([psdf, psdf.reset_index()], [pdf, pdf.reset_index()]),
             ([psdf.reset_index(), psdf], [pdf.reset_index(), pdf]),
             ([psdf, psdf[["C", "A"]]], [pdf, pdf[["C", "A"]]]),
             ([psdf[["C", "A"]], psdf], [pdf[["C", "A"]], pdf]),
+            # only one Series
             ([psdf, psdf["C"]], [pdf, pdf["C"]]),
             ([psdf["C"], psdf], [pdf["C"], pdf]),
+            # more than two Series
             ([psdf["C"], psdf, psdf["A"]], [pdf["C"], pdf, pdf["A"]]),
-            ([psdf, psdf["C"], psdf["A"]], [pdf, pdf["C"], pdf["A"]]),
         ]
 
+        if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+            # more than two Series
+            psdfs, pdfs = ([psdf, psdf["C"], psdf["A"]], [pdf, pdf["C"], pdf["A"]])
+            for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
+                # See also https://github.com/pandas-dev/pandas/issues/47127
+                if (join, sort) != ("outer", True):
+                    self.assert_eq(
+                        ps.concat(psdfs, ignore_index=ignore_index, join=join, sort=sort),
+                        pd.concat(pdfs, ignore_index=ignore_index, join=join, sort=sort),
+                    )
+
         for ignore_index, join, sort in itertools.product(ignore_indexes, joins, sorts):
             for i, (psdfs, pdfs) in enumerate(objs):
                 with self.subTest(
@@ -350,11 +386,15 @@ class NamespaceTest(PandasOnSparkTestCase, SQLTestUtils):
         objs = [
             ([psdf3, psdf3], [pdf3, pdf3]),
             ([psdf3, psdf3.reset_index()], [pdf3, pdf3.reset_index()]),
-            ([psdf3.reset_index(), psdf3], [pdf3.reset_index(), pdf3]),
             ([psdf3, psdf3[[("Y", "C"), ("X", "A")]]], [pdf3, pdf3[[("Y", "C"), ("X", "A")]]]),
-            ([psdf3[[("Y", "C"), ("X", "A")]], psdf3], [pdf3[[("Y", "C"), ("X", "A")]], pdf3]),
         ]
 
+        if LooseVersion(pd.__version__) >= LooseVersion("1.4"):
+            objs += [
+                ([psdf3.reset_index(), psdf3], [pdf3.reset_index(), pdf3]),
+                ([psdf3[[("Y", "C"), ("X", "A")]], psdf3], [pdf3[[("Y", "C"), ("X", "A")]], pdf3]),
+            ]
+
         for ignore_index, sort in itertools.product(ignore_indexes, sorts):
             for i, (psdfs, pdfs) in enumerate(objs):
                 with self.subTest(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org