You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/12/15 09:32:10 UTC
(spark) branch master updated: [SPARK-46418][PS][TESTS] Reorganize `ReshapeTests`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d5d2c580df4d [SPARK-46418][PS][TESTS] Reorganize `ReshapeTests`
d5d2c580df4d is described below
commit d5d2c580df4dc4d30c8719d8d94a34fcedf0f802
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Dec 15 17:31:46 2023 +0800
[SPARK-46418][PS][TESTS] Reorganize `ReshapeTests`
### What changes were proposed in this pull request?
break `ReshapeTests` into multiple small tests
### Why are the changes needed?
1, its parity test is slow, sometimes taking >5 mins;
2, to be consistent with pandas' tests https://github.com/pandas-dev/pandas/tree/main/pandas/tests/reshape
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44365 from zhengruifeng/ps_test_reshape.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
dev/sparktestsupport/modules.py | 14 +-
.../__init__.py} | 21 -
.../test_parity_get_dummies.py} | 10 +-
.../test_parity_get_dummies_kwargs.py} | 10 +-
.../test_parity_get_dummies_multiindex.py} | 10 +-
.../test_parity_get_dummies_object.py} | 10 +-
.../test_parity_get_dummies_prefix.py} | 10 +-
.../test_parity_merge_asof.py} | 10 +-
.../test_parity_reshape.py => reshape/__init__.py} | 21 -
.../pandas/tests/reshape/test_get_dummies.py | 126 ++++++
.../tests/reshape/test_get_dummies_kwargs.py | 69 +++
.../tests/reshape/test_get_dummies_multiindex.py | 108 +++++
.../tests/reshape/test_get_dummies_object.py | 84 ++++
.../tests/reshape/test_get_dummies_prefix.py | 93 ++++
.../pandas/tests/reshape/test_merge_asof.py | 220 ++++++++++
python/pyspark/pandas/tests/test_reshape.py | 478 ---------------------
16 files changed, 754 insertions(+), 540 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 7f4e331228dc..8de69e65d741 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -754,7 +754,12 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.resample.test_missing",
"pyspark.pandas.tests.resample.test_series",
"pyspark.pandas.tests.resample.test_timezone",
- "pyspark.pandas.tests.test_reshape",
+ "pyspark.pandas.tests.reshape.test_get_dummies",
+ "pyspark.pandas.tests.reshape.test_get_dummies_kwargs",
+ "pyspark.pandas.tests.reshape.test_get_dummies_multiindex",
+ "pyspark.pandas.tests.reshape.test_get_dummies_object",
+ "pyspark.pandas.tests.reshape.test_get_dummies_prefix",
+ "pyspark.pandas.tests.reshape.test_merge_asof",
"pyspark.pandas.tests.window.test_rolling",
"pyspark.pandas.tests.window.test_rolling_adv",
"pyspark.pandas.tests.window.test_rolling_count",
@@ -1104,7 +1109,12 @@ pyspark_pandas_connect_part1 = Module(
"pyspark.pandas.tests.connect.series.test_parity_sort",
"pyspark.pandas.tests.connect.series.test_parity_stat",
"pyspark.pandas.tests.connect.data_type_ops.test_parity_num_arithmetic",
- "pyspark.pandas.tests.connect.test_parity_reshape",
+ "pyspark.pandas.tests.connect.reshape.test_parity_get_dummies",
+ "pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_kwargs",
+ "pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_multiindex",
+ "pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_object",
+ "pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_prefix",
+ "pyspark.pandas.tests.connect.reshape.test_parity_merge_asof",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_expanding",
],
excluded_python_implementations=[
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/__init__.py
similarity index 53%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/__init__.py
index 356baaff5ba0..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/__init__.py
@@ -14,24 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-
-
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies.py
index 356baaff5ba0..732329ef40dd 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_get_dummies import GetDummiesMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GetDummiesParityTests(
+ GetDummiesMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_get_dummies import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_kwargs.py
similarity index 80%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_kwargs.py
index 356baaff5ba0..3c0e6b5cb664 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_kwargs.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_get_dummies_kwargs import GetDummiesKWArgsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GetDummiesKWArgsParityTests(
+ GetDummiesKWArgsMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_kwargs import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_multiindex.py
similarity index 79%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_multiindex.py
index 356baaff5ba0..77917814e712 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_multiindex.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_get_dummies_multiindex import GetDummiesMultiIndexMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GetDummiesMultiIndexParityTests(
+ GetDummiesMultiIndexMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_multiindex import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_object.py
similarity index 80%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_object.py
index 356baaff5ba0..fc74c8f0a5bd 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_object.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_get_dummies_object import GetDummiesObjectMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GetDummiesObjectParityTests(
+ GetDummiesObjectMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_object import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_prefix.py
similarity index 80%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_prefix.py
index 356baaff5ba0..85189f6cebd6 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_get_dummies_prefix.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_get_dummies_prefix import GetDummiesPrefixMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class GetDummiesPrefixParityTests(
+ GetDummiesPrefixMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_get_dummies_prefix import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/connect/reshape/test_parity_merge_asof.py
similarity index 82%
copy from python/pyspark/pandas/tests/connect/test_parity_reshape.py
copy to python/pyspark/pandas/tests/connect/reshape/test_parity_merge_asof.py
index 356baaff5ba0..651e73728b67 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/connect/reshape/test_parity_merge_asof.py
@@ -16,17 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
+from pyspark.pandas.tests.reshape.test_merge_asof import MergeAsOfMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
+class MergeAsOfParityTests(
+ MergeAsOfMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
+ from pyspark.pandas.tests.connect.reshape.test_parity_merge_asof import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_reshape.py b/python/pyspark/pandas/tests/reshape/__init__.py
similarity index 53%
rename from python/pyspark/pandas/tests/connect/test_parity_reshape.py
rename to python/pyspark/pandas/tests/reshape/__init__.py
index 356baaff5ba0..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_reshape.py
+++ b/python/pyspark/pandas/tests/reshape/__init__.py
@@ -14,24 +14,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-
-from pyspark.pandas.tests.test_reshape import ReshapeTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-
-
-class ReshapeParityTests(ReshapeTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_reshape import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_get_dummies.py b/python/pyspark/pandas/tests/reshape/test_get_dummies.py
new file mode 100644
index 000000000000..25e3c5b878bc
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_get_dummies.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+from decimal import Decimal
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class GetDummiesMixin:
+ def test_get_dummies(self):
+ for pdf_or_ps in [
+ pd.Series([1, 1, 1, 2, 2, 1, 3, 4]),
+ # pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category'),
+ # pd.Series(pd.Categorical([1, 1, 1, 2, 2, 1, 3, 4],
+ # categories=[4, 3, 2, 1])),
+ pd.DataFrame(
+ {
+ "a": [1, 2, 3, 4, 4, 3, 2, 1],
+ # 'b': pd.Categorical(list('abcdabcd')),
+ "b": list("abcdabcd"),
+ }
+ ),
+ pd.DataFrame({10: [1, 2, 3, 4, 4, 3, 2, 1], 20: list("abcdabcd")}),
+ ]:
+ psdf_or_psser = ps.from_pandas(pdf_or_ps)
+
+ self.assert_eq(ps.get_dummies(psdf_or_psser), pd.get_dummies(pdf_or_ps, dtype=np.int8))
+
+ psser = ps.Series([1, 1, 1, 2, 2, 1, 3, 4])
+ with self.assertRaisesRegex(
+ NotImplementedError, "get_dummies currently does not support sparse"
+ ):
+ ps.get_dummies(psser, sparse=True)
+ with self.assertRaisesRegex(NotImplementedError, "get_dummies currently only accept"):
+ ps.get_dummies(ps.Series([b"1"]))
+ with self.assertRaisesRegex(NotImplementedError, "get_dummies currently only accept"):
+ ps.get_dummies(ps.Series([None]))
+
+ def test_get_dummies_date_datetime(self):
+ pdf = pd.DataFrame(
+ {
+ "d": [
+ datetime.date(2019, 1, 1),
+ datetime.date(2019, 1, 2),
+ datetime.date(2019, 1, 1),
+ ],
+ "dt": [
+ datetime.datetime(2019, 1, 1, 0, 0, 0),
+ datetime.datetime(2019, 1, 1, 0, 0, 1),
+ datetime.datetime(2019, 1, 1, 0, 0, 0),
+ ],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(psdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(psdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
+
+ def test_get_dummies_boolean(self):
+ pdf = pd.DataFrame({"b": [True, False, True]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(psdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
+
+ def test_get_dummies_decimal(self):
+ pdf = pd.DataFrame({"d": [Decimal(1.0), Decimal(2.0), Decimal(1)]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(psdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
+
+ def test_get_dummies_dtype(self):
+ pdf = pd.DataFrame(
+ {
+ # "A": pd.Categorical(['a', 'b', 'a'], categories=['a', 'b', 'c']),
+ "A": ["a", "b", "a"],
+ "B": [0, 0, 1],
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ exp = pd.get_dummies(pdf)
+ exp = exp.astype({"A_a": "float64", "A_b": "float64"})
+ res = ps.get_dummies(psdf, dtype="float64")
+ self.assert_eq(res, exp)
+
+
+class GetDummiesTests(
+ GetDummiesMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_get_dummies import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_get_dummies_kwargs.py b/python/pyspark/pandas/tests/reshape/test_get_dummies_kwargs.py
new file mode 100644
index 000000000000..fcb4504ab9f4
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_get_dummies_kwargs.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class GetDummiesKWArgsMixin:
+ def test_get_dummies_kwargs(self):
+ # pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category')
+ pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4])
+ psser = ps.from_pandas(pser)
+ self.assert_eq(
+ ps.get_dummies(psser, prefix="X", prefix_sep="-"),
+ pd.get_dummies(pser, prefix="X", prefix_sep="-", dtype=np.int8),
+ )
+
+ self.assert_eq(
+ ps.get_dummies(psser, drop_first=True),
+ pd.get_dummies(pser, drop_first=True, dtype=np.int8),
+ )
+
+ # nan
+ # pser = pd.Series([1, 1, 1, 2, np.nan, 3, np.nan, 5], dtype='category')
+ pser = pd.Series([1, 1, 1, 2, np.nan, 3, np.nan, 5])
+ psser = ps.from_pandas(pser)
+ self.assert_eq(ps.get_dummies(psser), pd.get_dummies(pser, dtype=np.int8), almost=True)
+
+ # dummy_na
+ self.assert_eq(
+ ps.get_dummies(psser, dummy_na=True), pd.get_dummies(pser, dummy_na=True, dtype=np.int8)
+ )
+
+
+class GetDummiesKWArgsTests(
+ GetDummiesKWArgsMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_get_dummies_kwargs import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_get_dummies_multiindex.py b/python/pyspark/pandas/tests/reshape/test_get_dummies_multiindex.py
new file mode 100644
index 000000000000..35baaffb93c6
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_get_dummies_multiindex.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.pandas.utils import name_like_string
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class GetDummiesMultiIndexMixin:
+ def test_get_dummies_multiindex_columns(self):
+ pdf = pd.DataFrame(
+ {
+ ("x", "a", "1"): [1, 2, 3, 4, 4, 3, 2, 1],
+ ("x", "b", "2"): list("abcdabcd"),
+ ("y", "c", "3"): list("abcdabcd"),
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ ps.get_dummies(psdf),
+ pd.get_dummies(pdf, dtype=np.int8).rename(columns=name_like_string),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=[("y", "c", "3"), ("x", "a", "1")]),
+ pd.get_dummies(pdf, columns=[("y", "c", "3"), ("x", "a", "1")], dtype=np.int8).rename(
+ columns=name_like_string
+ ),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=["x"]),
+ pd.get_dummies(pdf, columns=["x"], dtype=np.int8).rename(columns=name_like_string),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=("x", "a")),
+ pd.get_dummies(pdf, columns=("x", "a"), dtype=np.int8).rename(columns=name_like_string),
+ )
+
+ self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=["z"]))
+ self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=("x", "c")))
+ self.assertRaises(ValueError, lambda: ps.get_dummies(psdf, columns=[("x",), "c"]))
+ self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns="x"))
+
+ # non-string names
+ pdf = pd.DataFrame(
+ {
+ ("x", 1, "a"): [1, 2, 3, 4, 4, 3, 2, 1],
+ ("x", 2, "b"): list("abcdabcd"),
+ ("y", 3, "c"): list("abcdabcd"),
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ ps.get_dummies(psdf),
+ pd.get_dummies(pdf, dtype=np.int8).rename(columns=name_like_string),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=[("y", 3, "c"), ("x", 1, "a")]),
+ pd.get_dummies(pdf, columns=[("y", 3, "c"), ("x", 1, "a")], dtype=np.int8).rename(
+ columns=name_like_string
+ ),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=["x"]),
+ pd.get_dummies(pdf, columns=["x"], dtype=np.int8).rename(columns=name_like_string),
+ )
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=("x", 1)),
+ pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string),
+ )
+
+
+class GetDummiesMultiIndexTests(
+ GetDummiesMultiIndexMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_get_dummies_multiindex import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_get_dummies_object.py b/python/pyspark/pandas/tests/reshape/test_get_dummies_object.py
new file mode 100644
index 000000000000..36dbaec2cea8
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_get_dummies_object.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class GetDummiesObjectMixin:
+ def test_get_dummies_object(self):
+ pdf = pd.DataFrame(
+ {
+ "a": [1, 2, 3, 4, 4, 3, 2, 1],
+ # 'a': pd.Categorical([1, 2, 3, 4, 4, 3, 2, 1]),
+ "b": list("abcdabcd"),
+ # 'c': pd.Categorical(list('abcdabcd')),
+ "c": list("abcdabcd"),
+ }
+ )
+ psdf = ps.from_pandas(pdf)
+
+ # Explicitly exclude object columns
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=["a", "c"]),
+ pd.get_dummies(pdf, columns=["a", "c"], dtype=np.int8),
+ )
+
+ self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
+ self.assert_eq(ps.get_dummies(psdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=["b"]), pd.get_dummies(pdf, columns=["b"], dtype=np.int8)
+ )
+
+ self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=("a", "c")))
+ self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns="b"))
+
+ # non-string names
+ pdf = pd.DataFrame(
+ {10: [1, 2, 3, 4, 4, 3, 2, 1], 20: list("abcdabcd"), 30: list("abcdabcd")}
+ )
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ ps.get_dummies(psdf, columns=[10, 30]),
+ pd.get_dummies(pdf, columns=[10, 30], dtype=np.int8),
+ )
+
+ self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns=10))
+
+
+class GetDummiesObjectTests(
+ GetDummiesObjectMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_get_dummies_object import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_get_dummies_prefix.py b/python/pyspark/pandas/tests/reshape/test_get_dummies_prefix.py
new file mode 100644
index 000000000000..95be2f6a5941
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_get_dummies_prefix.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class GetDummiesPrefixMixin:
+ def test_get_dummies_prefix(self):
+ pdf = pd.DataFrame({"A": ["a", "b", "a"], "B": ["b", "a", "c"], "D": [0, 0, 1]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ ps.get_dummies(psdf, prefix=["foo", "bar"]),
+ pd.get_dummies(pdf, prefix=["foo", "bar"], dtype=np.int8),
+ )
+
+ self.assert_eq(
+ ps.get_dummies(psdf, prefix=["foo"], columns=["B"]),
+ pd.get_dummies(pdf, prefix=["foo"], columns=["B"], dtype=np.int8),
+ )
+
+ self.assert_eq(
+ ps.get_dummies(psdf, prefix={"A": "foo", "B": "bar"}),
+ pd.get_dummies(pdf, prefix={"A": "foo", "B": "bar"}, dtype=np.int8),
+ )
+
+ self.assert_eq(
+ ps.get_dummies(psdf, prefix={"B": "foo", "A": "bar"}),
+ pd.get_dummies(pdf, prefix={"B": "foo", "A": "bar"}, dtype=np.int8),
+ )
+
+ self.assert_eq(
+ ps.get_dummies(psdf, prefix={"A": "foo", "B": "bar"}, columns=["A", "B"]),
+ pd.get_dummies(pdf, prefix={"A": "foo", "B": "bar"}, columns=["A", "B"], dtype=np.int8),
+ )
+
+ with self.assertRaisesRegex(NotImplementedError, "string types"):
+ ps.get_dummies(psdf, prefix="foo")
+ with self.assertRaisesRegex(ValueError, "Length of 'prefix' \\(1\\) .* \\(2\\)"):
+ ps.get_dummies(psdf, prefix=["foo"])
+ with self.assertRaisesRegex(ValueError, "Length of 'prefix' \\(2\\) .* \\(1\\)"):
+ ps.get_dummies(psdf, prefix=["foo", "bar"], columns=["B"])
+
+ pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], name="A")
+ psser = ps.from_pandas(pser)
+
+ self.assert_eq(
+ ps.get_dummies(psser, prefix="foo"), pd.get_dummies(pser, prefix="foo", dtype=np.int8)
+ )
+
+ # columns are ignored.
+ self.assert_eq(
+ ps.get_dummies(psser, prefix=["foo"], columns=["B"]),
+ pd.get_dummies(pser, prefix=["foo"], columns=["B"], dtype=np.int8),
+ )
+
+
+class GetDummiesPrefixTests(
+ GetDummiesPrefixMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_get_dummies_prefix import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/reshape/test_merge_asof.py b/python/pyspark/pandas/tests/reshape/test_merge_asof.py
new file mode 100644
index 000000000000..4d70f55b9ac0
--- /dev/null
+++ b/python/pyspark/pandas/tests/reshape/test_merge_asof.py
@@ -0,0 +1,220 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.errors import AnalysisException
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class MergeAsOfMixin:
+ def test_merge_asof(self):
+ pdf_left = pd.DataFrame(
+ {"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30]
+ )
+ pdf_right = pd.DataFrame(
+ {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
+ index=[100, 101, 102, 103, 104],
+ )
+ psdf_left = ps.from_pandas(pdf_left)
+ psdf_right = ps.from_pandas(pdf_right)
+
+ self.assert_eq(
+ pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True),
+ ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+
+ self.assert_eq(
+ pd.merge_asof(
+ pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a"
+ ).sort_index(),
+ ps.merge_asof(
+ psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
+ ).sort_index(),
+ )
+
+ self.assert_eq(
+ pd.merge_asof(
+ pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True
+ ).sort_index(),
+ ps.merge_asof(
+ psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True
+ ).sort_index(),
+ )
+ self.assert_eq(
+ pd.merge_asof(
+ pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True
+ ).sort_index(),
+ ps.merge_asof(
+ psdf_left.set_index("a"),
+ psdf_right.set_index("a"),
+ left_index=True,
+ right_index=True,
+ ).sort_index(),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, on="a", by="b")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, on="a", by="b")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1)
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1)
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False)
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False)
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+ self.assert_eq(
+ (
+ pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ (
+ ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest")
+ .sort_values("a")
+ .reset_index(drop=True)
+ ),
+ )
+ # Including Series
+ self.assert_eq(
+ pd.merge_asof(pdf_left["a"], pdf_right, on="a").sort_values("a").reset_index(drop=True),
+ ps.merge_asof(psdf_left["a"], psdf_right, on="a")
+ .sort_values("a")
+ .reset_index(drop=True),
+ )
+ self.assert_eq(
+ pd.merge_asof(pdf_left, pdf_right["a"], on="a").sort_values("a").reset_index(drop=True),
+ ps.merge_asof(psdf_left, psdf_right["a"], on="a")
+ .sort_values("a")
+ .reset_index(drop=True),
+ )
+ self.assert_eq(
+ pd.merge_asof(pdf_left["a"], pdf_right["a"], on="a")
+ .sort_values("a")
+ .reset_index(drop=True),
+ ps.merge_asof(psdf_left["a"], psdf_right["a"], on="a")
+ .sort_values("a")
+ .reset_index(drop=True),
+ )
+
+ self.assertRaises(
+ AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1)
+ )
+ with self.assertRaisesRegex(
+ ValueError,
+ 'Can only pass argument "on" OR "left_on" and "right_on", not a combination of both.',
+ ):
+ ps.merge_asof(psdf_left, psdf_right, on="a", left_on="a")
+ psdf_multi_index = ps.DataFrame(
+ {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
+ index=pd.MultiIndex.from_tuples([(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]),
+ )
+ with self.assertRaisesRegex(ValueError, "right can only have one index"):
+ ps.merge_asof(psdf_left, psdf_multi_index, right_index=True)
+ with self.assertRaisesRegex(ValueError, "left can only have one index"):
+ ps.merge_asof(psdf_multi_index, psdf_right, left_index=True)
+ with self.assertRaisesRegex(ValueError, "Must pass right_on or right_index=True"):
+ ps.merge_asof(psdf_left, psdf_right, left_index=True)
+ with self.assertRaisesRegex(ValueError, "Must pass left_on or left_index=True"):
+ ps.merge_asof(psdf_left, psdf_right, right_index=True)
+ with self.assertRaisesRegex(ValueError, "can only asof on a key for left"):
+ ps.merge_asof(psdf_left, psdf_right, right_on="a", left_on=["a", "b"])
+ with self.assertRaisesRegex(ValueError, "can only asof on a key for right"):
+ ps.merge_asof(psdf_left, psdf_right, right_on=["a", "b"], left_on="a")
+ with self.assertRaisesRegex(
+ ValueError, 'Can only pass argument "by" OR "left_by" and "right_by".'
+ ):
+ ps.merge_asof(psdf_left, psdf_right, on="a", by="b", left_by="a")
+ with self.assertRaisesRegex(ValueError, "missing right_by"):
+ ps.merge_asof(psdf_left, psdf_right, on="a", left_by="b")
+ with self.assertRaisesRegex(ValueError, "missing left_by"):
+ ps.merge_asof(psdf_left, psdf_right, on="a", right_by="b")
+ with self.assertRaisesRegex(ValueError, "left_by and right_by must be same length"):
+ ps.merge_asof(psdf_left, psdf_right, on="a", left_by="b", right_by=["a", "b"])
+ psdf_right.columns = ["A", "B", "C"]
+ with self.assertRaisesRegex(ValueError, "No common columns to perform merge on."):
+ ps.merge_asof(psdf_left, psdf_right)
+
+
+class MergeAsOfTests(
+ MergeAsOfMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.reshape.test_merge_asof import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py
deleted file mode 100644
index 2e2615a5661b..000000000000
--- a/python/pyspark/pandas/tests/test_reshape.py
+++ /dev/null
@@ -1,478 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import datetime
-from decimal import Decimal
-
-import numpy as np
-import pandas as pd
-
-from pyspark import pandas as ps
-from pyspark.pandas.utils import name_like_string
-from pyspark.errors import AnalysisException
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
-
-
-class ReshapeTestsMixin:
- def test_get_dummies(self):
- for pdf_or_ps in [
- pd.Series([1, 1, 1, 2, 2, 1, 3, 4]),
- # pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category'),
- # pd.Series(pd.Categorical([1, 1, 1, 2, 2, 1, 3, 4],
- # categories=[4, 3, 2, 1])),
- pd.DataFrame(
- {
- "a": [1, 2, 3, 4, 4, 3, 2, 1],
- # 'b': pd.Categorical(list('abcdabcd')),
- "b": list("abcdabcd"),
- }
- ),
- pd.DataFrame({10: [1, 2, 3, 4, 4, 3, 2, 1], 20: list("abcdabcd")}),
- ]:
- psdf_or_psser = ps.from_pandas(pdf_or_ps)
-
- self.assert_eq(ps.get_dummies(psdf_or_psser), pd.get_dummies(pdf_or_ps, dtype=np.int8))
-
- psser = ps.Series([1, 1, 1, 2, 2, 1, 3, 4])
- with self.assertRaisesRegex(
- NotImplementedError, "get_dummies currently does not support sparse"
- ):
- ps.get_dummies(psser, sparse=True)
- with self.assertRaisesRegex(NotImplementedError, "get_dummies currently only accept"):
- ps.get_dummies(ps.Series([b"1"]))
- with self.assertRaisesRegex(NotImplementedError, "get_dummies currently only accept"):
- ps.get_dummies(ps.Series([None]))
-
- def test_get_dummies_object(self):
- pdf = pd.DataFrame(
- {
- "a": [1, 2, 3, 4, 4, 3, 2, 1],
- # 'a': pd.Categorical([1, 2, 3, 4, 4, 3, 2, 1]),
- "b": list("abcdabcd"),
- # 'c': pd.Categorical(list('abcdabcd')),
- "c": list("abcdabcd"),
- }
- )
- psdf = ps.from_pandas(pdf)
-
- # Explicitly exclude object columns
- self.assert_eq(
- ps.get_dummies(psdf, columns=["a", "c"]),
- pd.get_dummies(pdf, columns=["a", "c"], dtype=np.int8),
- )
-
- self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(psdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
- self.assert_eq(
- ps.get_dummies(psdf, columns=["b"]), pd.get_dummies(pdf, columns=["b"], dtype=np.int8)
- )
-
- self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=("a", "c")))
- self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns="b"))
-
- # non-string names
- pdf = pd.DataFrame(
- {10: [1, 2, 3, 4, 4, 3, 2, 1], 20: list("abcdabcd"), 30: list("abcdabcd")}
- )
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- ps.get_dummies(psdf, columns=[10, 30]),
- pd.get_dummies(pdf, columns=[10, 30], dtype=np.int8),
- )
-
- self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns=10))
-
- def test_get_dummies_date_datetime(self):
- pdf = pd.DataFrame(
- {
- "d": [
- datetime.date(2019, 1, 1),
- datetime.date(2019, 1, 2),
- datetime.date(2019, 1, 1),
- ],
- "dt": [
- datetime.datetime(2019, 1, 1, 0, 0, 0),
- datetime.datetime(2019, 1, 1, 0, 0, 1),
- datetime.datetime(2019, 1, 1, 0, 0, 0),
- ],
- }
- )
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(psdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
- self.assert_eq(ps.get_dummies(psdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
-
- def test_get_dummies_boolean(self):
- pdf = pd.DataFrame({"b": [True, False, True]})
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(psdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
-
- def test_get_dummies_decimal(self):
- pdf = pd.DataFrame({"d": [Decimal(1.0), Decimal(2.0), Decimal(1)]})
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(ps.get_dummies(psdf), pd.get_dummies(pdf, dtype=np.int8))
- self.assert_eq(ps.get_dummies(psdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
-
- def test_get_dummies_kwargs(self):
- # pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category')
- pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4])
- psser = ps.from_pandas(pser)
- self.assert_eq(
- ps.get_dummies(psser, prefix="X", prefix_sep="-"),
- pd.get_dummies(pser, prefix="X", prefix_sep="-", dtype=np.int8),
- )
-
- self.assert_eq(
- ps.get_dummies(psser, drop_first=True),
- pd.get_dummies(pser, drop_first=True, dtype=np.int8),
- )
-
- # nan
- # pser = pd.Series([1, 1, 1, 2, np.nan, 3, np.nan, 5], dtype='category')
- pser = pd.Series([1, 1, 1, 2, np.nan, 3, np.nan, 5])
- psser = ps.from_pandas(pser)
- self.assert_eq(ps.get_dummies(psser), pd.get_dummies(pser, dtype=np.int8), almost=True)
-
- # dummy_na
- self.assert_eq(
- ps.get_dummies(psser, dummy_na=True), pd.get_dummies(pser, dummy_na=True, dtype=np.int8)
- )
-
- def test_get_dummies_prefix(self):
- pdf = pd.DataFrame({"A": ["a", "b", "a"], "B": ["b", "a", "c"], "D": [0, 0, 1]})
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- ps.get_dummies(psdf, prefix=["foo", "bar"]),
- pd.get_dummies(pdf, prefix=["foo", "bar"], dtype=np.int8),
- )
-
- self.assert_eq(
- ps.get_dummies(psdf, prefix=["foo"], columns=["B"]),
- pd.get_dummies(pdf, prefix=["foo"], columns=["B"], dtype=np.int8),
- )
-
- self.assert_eq(
- ps.get_dummies(psdf, prefix={"A": "foo", "B": "bar"}),
- pd.get_dummies(pdf, prefix={"A": "foo", "B": "bar"}, dtype=np.int8),
- )
-
- self.assert_eq(
- ps.get_dummies(psdf, prefix={"B": "foo", "A": "bar"}),
- pd.get_dummies(pdf, prefix={"B": "foo", "A": "bar"}, dtype=np.int8),
- )
-
- self.assert_eq(
- ps.get_dummies(psdf, prefix={"A": "foo", "B": "bar"}, columns=["A", "B"]),
- pd.get_dummies(pdf, prefix={"A": "foo", "B": "bar"}, columns=["A", "B"], dtype=np.int8),
- )
-
- with self.assertRaisesRegex(NotImplementedError, "string types"):
- ps.get_dummies(psdf, prefix="foo")
- with self.assertRaisesRegex(ValueError, "Length of 'prefix' \\(1\\) .* \\(2\\)"):
- ps.get_dummies(psdf, prefix=["foo"])
- with self.assertRaisesRegex(ValueError, "Length of 'prefix' \\(2\\) .* \\(1\\)"):
- ps.get_dummies(psdf, prefix=["foo", "bar"], columns=["B"])
-
- pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], name="A")
- psser = ps.from_pandas(pser)
-
- self.assert_eq(
- ps.get_dummies(psser, prefix="foo"), pd.get_dummies(pser, prefix="foo", dtype=np.int8)
- )
-
- # columns are ignored.
- self.assert_eq(
- ps.get_dummies(psser, prefix=["foo"], columns=["B"]),
- pd.get_dummies(pser, prefix=["foo"], columns=["B"], dtype=np.int8),
- )
-
- def test_get_dummies_dtype(self):
- pdf = pd.DataFrame(
- {
- # "A": pd.Categorical(['a', 'b', 'a'], categories=['a', 'b', 'c']),
- "A": ["a", "b", "a"],
- "B": [0, 0, 1],
- }
- )
- psdf = ps.from_pandas(pdf)
-
- exp = pd.get_dummies(pdf)
- exp = exp.astype({"A_a": "float64", "A_b": "float64"})
- res = ps.get_dummies(psdf, dtype="float64")
- self.assert_eq(res, exp)
-
- def test_get_dummies_multiindex_columns(self):
- pdf = pd.DataFrame(
- {
- ("x", "a", "1"): [1, 2, 3, 4, 4, 3, 2, 1],
- ("x", "b", "2"): list("abcdabcd"),
- ("y", "c", "3"): list("abcdabcd"),
- }
- )
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- ps.get_dummies(psdf),
- pd.get_dummies(pdf, dtype=np.int8).rename(columns=name_like_string),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=[("y", "c", "3"), ("x", "a", "1")]),
- pd.get_dummies(pdf, columns=[("y", "c", "3"), ("x", "a", "1")], dtype=np.int8).rename(
- columns=name_like_string
- ),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=["x"]),
- pd.get_dummies(pdf, columns=["x"], dtype=np.int8).rename(columns=name_like_string),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=("x", "a")),
- pd.get_dummies(pdf, columns=("x", "a"), dtype=np.int8).rename(columns=name_like_string),
- )
-
- self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=["z"]))
- self.assertRaises(KeyError, lambda: ps.get_dummies(psdf, columns=("x", "c")))
- self.assertRaises(ValueError, lambda: ps.get_dummies(psdf, columns=[("x",), "c"]))
- self.assertRaises(TypeError, lambda: ps.get_dummies(psdf, columns="x"))
-
- # non-string names
- pdf = pd.DataFrame(
- {
- ("x", 1, "a"): [1, 2, 3, 4, 4, 3, 2, 1],
- ("x", 2, "b"): list("abcdabcd"),
- ("y", 3, "c"): list("abcdabcd"),
- }
- )
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- ps.get_dummies(psdf),
- pd.get_dummies(pdf, dtype=np.int8).rename(columns=name_like_string),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=[("y", 3, "c"), ("x", 1, "a")]),
- pd.get_dummies(pdf, columns=[("y", 3, "c"), ("x", 1, "a")], dtype=np.int8).rename(
- columns=name_like_string
- ),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=["x"]),
- pd.get_dummies(pdf, columns=["x"], dtype=np.int8).rename(columns=name_like_string),
- )
- self.assert_eq(
- ps.get_dummies(psdf, columns=("x", 1)),
- pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string),
- )
-
- def test_merge_asof(self):
- pdf_left = pd.DataFrame(
- {"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30]
- )
- pdf_right = pd.DataFrame(
- {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
- index=[100, 101, 102, 103, 104],
- )
- psdf_left = ps.from_pandas(pdf_left)
- psdf_right = ps.from_pandas(pdf_right)
-
- self.assert_eq(
- pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True),
- ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
-
- self.assert_eq(
- pd.merge_asof(
- pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a"
- ).sort_index(),
- ps.merge_asof(
- psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
- ).sort_index(),
- )
-
- self.assert_eq(
- pd.merge_asof(
- pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True
- ).sort_index(),
- ps.merge_asof(
- psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True
- ).sort_index(),
- )
- self.assert_eq(
- pd.merge_asof(
- pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True
- ).sort_index(),
- ps.merge_asof(
- psdf_left.set_index("a"),
- psdf_right.set_index("a"),
- left_index=True,
- right_index=True,
- ).sort_index(),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, on="a", by="b")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, on="a", by="b")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1)
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1)
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False)
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False)
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
- self.assert_eq(
- (
- pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- (
- ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest")
- .sort_values("a")
- .reset_index(drop=True)
- ),
- )
- # Including Series
- self.assert_eq(
- pd.merge_asof(pdf_left["a"], pdf_right, on="a").sort_values("a").reset_index(drop=True),
- ps.merge_asof(psdf_left["a"], psdf_right, on="a")
- .sort_values("a")
- .reset_index(drop=True),
- )
- self.assert_eq(
- pd.merge_asof(pdf_left, pdf_right["a"], on="a").sort_values("a").reset_index(drop=True),
- ps.merge_asof(psdf_left, psdf_right["a"], on="a")
- .sort_values("a")
- .reset_index(drop=True),
- )
- self.assert_eq(
- pd.merge_asof(pdf_left["a"], pdf_right["a"], on="a")
- .sort_values("a")
- .reset_index(drop=True),
- ps.merge_asof(psdf_left["a"], psdf_right["a"], on="a")
- .sort_values("a")
- .reset_index(drop=True),
- )
-
- self.assertRaises(
- AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1)
- )
- with self.assertRaisesRegex(
- ValueError,
- 'Can only pass argument "on" OR "left_on" and "right_on", not a combination of both.',
- ):
- ps.merge_asof(psdf_left, psdf_right, on="a", left_on="a")
- psdf_multi_index = ps.DataFrame(
- {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
- index=pd.MultiIndex.from_tuples([(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]),
- )
- with self.assertRaisesRegex(ValueError, "right can only have one index"):
- ps.merge_asof(psdf_left, psdf_multi_index, right_index=True)
- with self.assertRaisesRegex(ValueError, "left can only have one index"):
- ps.merge_asof(psdf_multi_index, psdf_right, left_index=True)
- with self.assertRaisesRegex(ValueError, "Must pass right_on or right_index=True"):
- ps.merge_asof(psdf_left, psdf_right, left_index=True)
- with self.assertRaisesRegex(ValueError, "Must pass left_on or left_index=True"):
- ps.merge_asof(psdf_left, psdf_right, right_index=True)
- with self.assertRaisesRegex(ValueError, "can only asof on a key for left"):
- ps.merge_asof(psdf_left, psdf_right, right_on="a", left_on=["a", "b"])
- with self.assertRaisesRegex(ValueError, "can only asof on a key for right"):
- ps.merge_asof(psdf_left, psdf_right, right_on=["a", "b"], left_on="a")
- with self.assertRaisesRegex(
- ValueError, 'Can only pass argument "by" OR "left_by" and "right_by".'
- ):
- ps.merge_asof(psdf_left, psdf_right, on="a", by="b", left_by="a")
- with self.assertRaisesRegex(ValueError, "missing right_by"):
- ps.merge_asof(psdf_left, psdf_right, on="a", left_by="b")
- with self.assertRaisesRegex(ValueError, "missing left_by"):
- ps.merge_asof(psdf_left, psdf_right, on="a", right_by="b")
- with self.assertRaisesRegex(ValueError, "left_by and right_by must be same length"):
- ps.merge_asof(psdf_left, psdf_right, on="a", left_by="b", right_by=["a", "b"])
- psdf_right.columns = ["A", "B", "C"]
- with self.assertRaisesRegex(ValueError, "No common columns to perform merge on."):
- ps.merge_asof(psdf_left, psdf_right)
-
-
-class ReshapeTests(ReshapeTestsMixin, PandasOnSparkTestCase):
- pass
-
-
-if __name__ == "__main__":
- import unittest
- from pyspark.pandas.tests.test_reshape import * # noqa: F401
-
- try:
- import xmlrunner
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org