You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/12/11 18:30:02 UTC
(spark) branch master updated: [SPARK-46347][PS][TESTS] Reorganize `RollingTests`
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d5241ff2689 [SPARK-46347][PS][TESTS] Reorganize `RollingTests`
d5241ff2689 is described below
commit d5241ff26892fa615b27ae39b0be1b8907f59f29
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Dec 11 10:29:49 2023 -0800
[SPARK-46347][PS][TESTS] Reorganize `RollingTests`
### What changes were proposed in this pull request?
Reorganize `RollingTests` by breaking it into multiple smaller files.
### Why are the changes needed?
To be consistent with the organization of pandas' own tests.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44281 from zhengruifeng/ps_test_rolling.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
dev/sparktestsupport/modules.py | 16 +-
.../test_parity_groupby_rolling.py} | 12 +-
.../test_parity_groupby_rolling_adv.py} | 12 +-
.../test_parity_groupby_rolling_count.py} | 12 +-
.../connect/{ => window}/test_parity_rolling.py | 10 +-
.../test_parity_rolling_adv.py} | 12 +-
.../test_parity_rolling_count.py} | 12 +-
.../test_parity_rolling_error.py} | 12 +-
python/pyspark/pandas/tests/test_rolling.py | 317 ---------------------
.../pandas/tests/window/test_groupby_rolling.py | 132 +++++++++
.../tests/window/test_groupby_rolling_adv.py | 60 ++++
.../tests/window/test_groupby_rolling_count.py | 113 ++++++++
python/pyspark/pandas/tests/window/test_rolling.py | 91 ++++++
.../test_rolling_adv.py} | 33 ++-
.../pandas/tests/window/test_rolling_count.py | 72 +++++
.../test_rolling_error.py} | 31 +-
16 files changed, 578 insertions(+), 369 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 68e9ed8101d..c77a34f1d22 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -750,7 +750,13 @@ pyspark_pandas = Module(
"pyspark.pandas.tests.resample.test_series",
"pyspark.pandas.tests.resample.test_timezone",
"pyspark.pandas.tests.test_reshape",
- "pyspark.pandas.tests.test_rolling",
+ "pyspark.pandas.tests.window.test_rolling",
+ "pyspark.pandas.tests.window.test_rolling_adv",
+ "pyspark.pandas.tests.window.test_rolling_count",
+ "pyspark.pandas.tests.window.test_rolling_error",
+ "pyspark.pandas.tests.window.test_groupby_rolling",
+ "pyspark.pandas.tests.window.test_groupby_rolling_adv",
+ "pyspark.pandas.tests.window.test_groupby_rolling_count",
"pyspark.pandas.tests.test_scalars",
"pyspark.pandas.tests.test_series_conversion",
"pyspark.pandas.tests.test_series_datetime",
@@ -1120,7 +1126,13 @@ pyspark_pandas_connect_part2 = Module(
"pyspark.pandas.tests.connect.window.test_parity_ewm_error",
"pyspark.pandas.tests.connect.window.test_parity_ewm_mean",
"pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean",
- "pyspark.pandas.tests.connect.test_parity_rolling",
+ "pyspark.pandas.tests.connect.window.test_parity_rolling",
+ "pyspark.pandas.tests.connect.window.test_parity_rolling_adv",
+ "pyspark.pandas.tests.connect.window.test_parity_rolling_count",
+ "pyspark.pandas.tests.connect.window.test_parity_rolling_error",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv",
+ "pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count",
"pyspark.pandas.tests.connect.test_parity_expanding",
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
"pyspark.pandas.tests.connect.computation.test_parity_missing_data",
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py
similarity index 76%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py
index 8318bed24f0..0a3e0b1358f 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_groupby_rolling import GroupByRollingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityGroupTests(
+ GroupByRollingMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py
similarity index 75%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py
index 8318bed24f0..774f8dd9e75 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_adv.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_groupby_rolling_adv import GroupByRollingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityGroupAdvTests(
+ GroupByRollingAdvMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_adv import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py
similarity index 75%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py
index 8318bed24f0..89dc851b32c 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_rolling_count.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_groupby_rolling_count import GroupByRollingCountMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityGroupCountTests(
+ GroupByRollingCountMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_groupby_rolling_count import * # noqa
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling.py
similarity index 79%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling.py
index 8318bed24f0..9dc3d9dcd4c 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_rolling import RollingMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+ RollingMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_rolling import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py
similarity index 77%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py
index 8318bed24f0..ae0d9e0ba11 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_adv.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_rolling_adv import RollingAdvMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityAdvTests(
+ RollingAdvMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_rolling_adv import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py
similarity index 77%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py
index 8318bed24f0..7bbe31bc303 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_count.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_rolling_count import RollingCountMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityCountTests(
+ RollingCountMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_rolling_count import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py
similarity index 77%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py
index 8318bed24f0..dc4ecb321d7 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_rolling_error.py
@@ -16,19 +16,21 @@
#
import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
+from pyspark.pandas.tests.window.test_rolling_error import RollingErrorMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingParityErrorTests(
+ RollingErrorMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ from pyspark.pandas.tests.connect.window.test_parity_rolling_error import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py
deleted file mode 100644
index c7e49eab5bb..00000000000
--- a/python/pyspark/pandas/tests/test_rolling.py
+++ /dev/null
@@ -1,317 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-import unittest
-
-import numpy as np
-import pandas as pd
-
-import pyspark.pandas as ps
-from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-from pyspark.pandas.window import Rolling
-
-
-class RollingTestsMixin:
- def test_rolling_error(self):
- with self.assertRaisesRegex(ValueError, "window must be >= 0"):
- ps.range(10).rolling(window=-1)
- with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
- ps.range(10).rolling(window=1, min_periods=-1)
-
- with self.assertRaisesRegex(
- TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int"
- ):
- Rolling(1, 2)
-
- def _test_rolling_func(self, ps_func, pd_func=None):
- if not pd_func:
- pd_func = ps_func
- if isinstance(pd_func, str):
- pd_func = self.convert_str_to_lambda(pd_func)
- if isinstance(ps_func, str):
- ps_func = self.convert_str_to_lambda(ps_func)
- pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
- psser = ps.from_pandas(pser)
- self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
- self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum())
-
- # Multiindex
- pser = pd.Series(
- [1, 2, 3],
- index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]),
- name="a",
- )
- psser = ps.from_pandas(pser)
- self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
-
- pdf = pd.DataFrame(
- {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
- self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum())
-
- # Multiindex column
- columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
- pdf.columns = columns
- psdf.columns = columns
- self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
-
- def test_rolling_min(self):
- self._test_rolling_func("min")
-
- def test_rolling_max(self):
- self._test_rolling_func("max")
-
- def test_rolling_mean(self):
- self._test_rolling_func("mean")
-
- def test_rolling_quantile(self):
- self._test_rolling_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower"))
-
- def test_rolling_sum(self):
- self._test_rolling_func("sum")
-
- def test_rolling_count(self):
- pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
- psser = ps.from_pandas(pser)
- self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count())
- self.assert_eq(psser.rolling(2).count().sum(), pser.rolling(2, min_periods=1).count().sum())
-
- # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas
- # Multiindex
- pser = pd.Series(
- [1, 2, 3],
- index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]),
- name="a",
- )
- psser = ps.from_pandas(pser)
- self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count())
-
- pdf = pd.DataFrame(
- {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
- )
- psdf = ps.from_pandas(pdf)
- self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count())
- self.assert_eq(psdf.rolling(2).count().sum(), pdf.rolling(2, min_periods=1).count().sum())
-
- # Multiindex column
- columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
- pdf.columns = columns
- psdf.columns = columns
- self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count())
-
- def test_rolling_std(self):
- self._test_rolling_func("std")
-
- def test_rolling_var(self):
- self._test_rolling_func("var")
-
- def test_rolling_skew(self):
- self._test_rolling_func("skew")
-
- def test_rolling_kurt(self):
- self._test_rolling_func("kurt")
-
- def _test_groupby_rolling_func(self, ps_func, pd_func=None):
- if not pd_func:
- pd_func = ps_func
- if isinstance(pd_func, str):
- pd_func = self.convert_str_to_lambda(pd_func)
- if isinstance(ps_func, str):
- ps_func = self.convert_str_to_lambda(ps_func)
- pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
- psser = ps.from_pandas(pser)
- self.assert_eq(
- ps_func(psser.groupby(psser).rolling(2)).sort_index(),
- pd_func(pser.groupby(pser).rolling(2)).sort_index(),
- )
- self.assert_eq(
- ps_func(psser.groupby(psser).rolling(2)).sum(),
- pd_func(pser.groupby(pser).rolling(2)).sum(),
- )
-
- # Multiindex
- pser = pd.Series(
- [1, 2, 3, 2],
- index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]),
- name="a",
- )
- psser = ps.from_pandas(pser)
- self.assert_eq(
- ps_func(psser.groupby(psser).rolling(2)).sort_index(),
- pd_func(pser.groupby(pser).rolling(2)).sort_index(),
- )
-
- pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
- pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(),
- )
- self.assert_eq(
- ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
- pd_func(pdf.groupby(pdf.a).rolling(2)).sum(),
- )
- self.assert_eq(
- ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
- pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(),
- )
-
- self.assert_eq(
- ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(),
- pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(),
- )
- self.assert_eq(
- ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(),
- pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(),
- )
- self.assert_eq(
- ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(),
- pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(),
- )
-
- # Multiindex column
- columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
- pdf.columns = columns
- psdf.columns = columns
-
- self.assert_eq(
- ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
- pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(),
- )
-
- self.assert_eq(
- ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
- pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
- )
-
- def test_groupby_rolling_count(self):
- pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
- psser = ps.from_pandas(pser)
- # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas
- self.assert_eq(
- psser.groupby(psser).rolling(2).count().sort_index(),
- pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(),
- )
- self.assert_eq(
- psser.groupby(psser).rolling(2).count().sum(),
- pser.groupby(pser).rolling(2, min_periods=1).count().sum(),
- )
-
- # Multiindex
- pser = pd.Series(
- [1, 2, 3, 2],
- index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]),
- name="a",
- )
- psser = ps.from_pandas(pser)
- self.assert_eq(
- psser.groupby(psser).rolling(2).count().sort_index(),
- pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(),
- )
-
- pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
- psdf = ps.from_pandas(pdf)
-
- self.assert_eq(
- psdf.groupby(psdf.a).rolling(2).count().sort_index(),
- pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(),
- )
- self.assert_eq(
- psdf.groupby(psdf.a).rolling(2).count().sum(),
- pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sum(),
- )
- self.assert_eq(
- psdf.groupby(psdf.a + 1).rolling(2).count().sort_index(),
- pdf.groupby(pdf.a + 1).rolling(2, min_periods=1).count().sort_index(),
- )
-
- self.assert_eq(
- psdf.b.groupby(psdf.a).rolling(2).count().sort_index(),
- pdf.b.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(),
- )
- self.assert_eq(
- psdf.groupby(psdf.a)["b"].rolling(2).count().sort_index(),
- pdf.groupby(pdf.a)["b"].rolling(2, min_periods=1).count().sort_index(),
- )
- self.assert_eq(
- psdf.groupby(psdf.a)[["b"]].rolling(2).count().sort_index(),
- pdf.groupby(pdf.a)[["b"]].rolling(2, min_periods=1).count().sort_index(),
- )
-
- # Multiindex column
- columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
- pdf.columns = columns
- psdf.columns = columns
-
- self.assert_eq(
- psdf.groupby(("a", "x")).rolling(2).count().sort_index(),
- pdf.groupby(("a", "x")).rolling(2, min_periods=1).count().sort_index(),
- )
-
- self.assert_eq(
- psdf.groupby([("a", "x"), ("a", "y")]).rolling(2).count().sort_index(),
- pdf.groupby([("a", "x"), ("a", "y")]).rolling(2, min_periods=1).count().sort_index(),
- )
-
- def test_groupby_rolling_min(self):
- self._test_groupby_rolling_func("min")
-
- def test_groupby_rolling_max(self):
- self._test_groupby_rolling_func("max")
-
- def test_groupby_rolling_mean(self):
- self._test_groupby_rolling_func("mean")
-
- def test_groupby_rolling_quantile(self):
- self._test_groupby_rolling_func(
- lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")
- )
-
- def test_groupby_rolling_sum(self):
- self._test_groupby_rolling_func("sum")
-
- def test_groupby_rolling_std(self):
- # TODO: `std` now raise error in pandas 1.0.0
- self._test_groupby_rolling_func("std")
-
- def test_groupby_rolling_var(self):
- self._test_groupby_rolling_func("var")
-
- def test_groupby_rolling_skew(self):
- self._test_groupby_rolling_func("skew")
-
- def test_groupby_rolling_kurt(self):
- self._test_groupby_rolling_func("kurt")
-
-
-class RollingTests(RollingTestsMixin, PandasOnSparkTestCase, TestUtils):
- pass
-
-
-if __name__ == "__main__":
- import unittest
- from pyspark.pandas.tests.test_rolling import * # noqa: F401
-
- try:
- import xmlrunner
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling.py b/python/pyspark/pandas/tests/window/test_groupby_rolling.py
new file mode 100644
index 00000000000..a5bced6a8bf
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_groupby_rolling.py
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class GroupByRollingTestingFuncMixin:
+ def _test_groupby_rolling_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
+ pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(
+ ps_func(psser.groupby(psser).rolling(2)).sort_index(),
+ pd_func(pser.groupby(pser).rolling(2)).sort_index(),
+ )
+ self.assert_eq(
+ ps_func(psser.groupby(psser).rolling(2)).sum(),
+ pd_func(pser.groupby(pser).rolling(2)).sum(),
+ )
+
+ # Multiindex
+ pser = pd.Series(
+ [1, 2, 3, 2],
+ index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]),
+ name="a",
+ )
+ psser = ps.from_pandas(pser)
+ self.assert_eq(
+ ps_func(psser.groupby(psser).rolling(2)).sort_index(),
+ pd_func(pser.groupby(pser).rolling(2)).sort_index(),
+ )
+
+ pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).sort_index(),
+ )
+ self.assert_eq(
+ ps_func(psdf.groupby(psdf.a).rolling(2)).sum(),
+ pd_func(pdf.groupby(pdf.a).rolling(2)).sum(),
+ )
+ self.assert_eq(
+ ps_func(psdf.groupby(psdf.a + 1).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a + 1).rolling(2)).sort_index(),
+ )
+
+ self.assert_eq(
+ ps_func(psdf.b.groupby(psdf.a).rolling(2)).sort_index(),
+ pd_func(pdf.b.groupby(pdf.a).rolling(2)).sort_index(),
+ )
+ self.assert_eq(
+ ps_func(psdf.groupby(psdf.a)["b"].rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)["b"].rolling(2)).sort_index(),
+ )
+ self.assert_eq(
+ ps_func(psdf.groupby(psdf.a)[["b"]].rolling(2)).sort_index(),
+ pd_func(pdf.groupby(pdf.a)[["b"]].rolling(2)).sort_index(),
+ )
+
+ # Multiindex column
+ columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
+ pdf.columns = columns
+ psdf.columns = columns
+
+ self.assert_eq(
+ ps_func(psdf.groupby(("a", "x")).rolling(2)).sort_index(),
+ pd_func(pdf.groupby(("a", "x")).rolling(2)).sort_index(),
+ )
+
+ self.assert_eq(
+ ps_func(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
+ pd_func(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2)).sort_index(),
+ )
+
+
+class GroupByRollingMixin(GroupByRollingTestingFuncMixin):
+ def test_groupby_rolling_min(self):
+ self._test_groupby_rolling_func("min")
+
+ def test_groupby_rolling_max(self):
+ self._test_groupby_rolling_func("max")
+
+ def test_groupby_rolling_mean(self):
+ self._test_groupby_rolling_func("mean")
+
+ def test_groupby_rolling_sum(self):
+ self._test_groupby_rolling_func("sum")
+
+
+class GroupByRollingTests(
+ GroupByRollingMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.window.test_groupby_rolling import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py b/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py
new file mode 100644
index 00000000000..13fa5902d2a
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_groupby_rolling_adv.py
@@ -0,0 +1,60 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.tests.window.test_groupby_rolling import GroupByRollingTestingFuncMixin
+
+
+class GroupByRollingAdvMixin(GroupByRollingTestingFuncMixin):
+ def test_groupby_rolling_quantile(self):
+ self._test_groupby_rolling_func(
+ lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")
+ )
+
+ def test_groupby_rolling_std(self):
+ # TODO: `std` now raise error in pandas 1.0.0
+ self._test_groupby_rolling_func("std")
+
+ def test_groupby_rolling_var(self):
+ self._test_groupby_rolling_func("var")
+
+ def test_groupby_rolling_skew(self):
+ self._test_groupby_rolling_func("skew")
+
+ def test_groupby_rolling_kurt(self):
+ self._test_groupby_rolling_func("kurt")
+
+
+class GroupByRollingAdvTests(
+ GroupByRollingAdvMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.pandas.tests.window.test_groupby_rolling_adv import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py b/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py
new file mode 100644
index 00000000000..7499e2f821a
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_groupby_rolling_count.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class GroupByRollingCountMixin:
+ def test_groupby_rolling_count(self):
+ pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
+ psser = ps.from_pandas(pser)
+ # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas
+ self.assert_eq(
+ psser.groupby(psser).rolling(2).count().sort_index(),
+ pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(),
+ )
+ self.assert_eq(
+ psser.groupby(psser).rolling(2).count().sum(),
+ pser.groupby(pser).rolling(2, min_periods=1).count().sum(),
+ )
+
+ # Multiindex
+ pser = pd.Series(
+ [1, 2, 3, 2],
+ index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z"), ("c", "z")]),
+ name="a",
+ )
+ psser = ps.from_pandas(pser)
+ self.assert_eq(
+ psser.groupby(psser).rolling(2).count().sort_index(),
+ pser.groupby(pser).rolling(2, min_periods=1).count().sort_index(),
+ )
+
+ pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
+ psdf = ps.from_pandas(pdf)
+
+ self.assert_eq(
+ psdf.groupby(psdf.a).rolling(2).count().sort_index(),
+ pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(),
+ )
+ self.assert_eq(
+ psdf.groupby(psdf.a).rolling(2).count().sum(),
+ pdf.groupby(pdf.a).rolling(2, min_periods=1).count().sum(),
+ )
+ self.assert_eq(
+ psdf.groupby(psdf.a + 1).rolling(2).count().sort_index(),
+ pdf.groupby(pdf.a + 1).rolling(2, min_periods=1).count().sort_index(),
+ )
+
+ self.assert_eq(
+ psdf.b.groupby(psdf.a).rolling(2).count().sort_index(),
+ pdf.b.groupby(pdf.a).rolling(2, min_periods=1).count().sort_index(),
+ )
+ self.assert_eq(
+ psdf.groupby(psdf.a)["b"].rolling(2).count().sort_index(),
+ pdf.groupby(pdf.a)["b"].rolling(2, min_periods=1).count().sort_index(),
+ )
+ self.assert_eq(
+ psdf.groupby(psdf.a)[["b"]].rolling(2).count().sort_index(),
+ pdf.groupby(pdf.a)[["b"]].rolling(2, min_periods=1).count().sort_index(),
+ )
+
+ # Multiindex column
+ columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
+ pdf.columns = columns
+ psdf.columns = columns
+
+ self.assert_eq(
+ psdf.groupby(("a", "x")).rolling(2).count().sort_index(),
+ pdf.groupby(("a", "x")).rolling(2, min_periods=1).count().sort_index(),
+ )
+
+ self.assert_eq(
+ psdf.groupby([("a", "x"), ("a", "y")]).rolling(2).count().sort_index(),
+ pdf.groupby([("a", "x"), ("a", "y")]).rolling(2, min_periods=1).count().sort_index(),
+ )
+
+
+class GroupByRollingCountTests(
+ GroupByRollingCountMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
+):
+ pass
+
+
+# Script entry point: run this module's tests directly.
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.window.test_groupby_rolling_count import *  # noqa: F401
+
+    # Use the XML runner (reports written under target/test-reports) when `xmlrunner`
+    # can be imported; otherwise pass None so unittest.main uses its default runner.
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_rolling.py b/python/pyspark/pandas/tests/window/test_rolling.py
new file mode 100644
index 00000000000..cf6903afe7c
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_rolling.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class RollingTestingFuncMixin:
+ def _test_rolling_func(self, ps_func, pd_func=None):
+ if not pd_func:
+ pd_func = ps_func
+ if isinstance(pd_func, str):
+ pd_func = self.convert_str_to_lambda(pd_func)
+ if isinstance(ps_func, str):
+ ps_func = self.convert_str_to_lambda(ps_func)
+ pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
+ self.assert_eq(ps_func(psser.rolling(2)).sum(), pd_func(pser.rolling(2)).sum())
+
+ # Multiindex
+ pser = pd.Series(
+ [1, 2, 3],
+ index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]),
+ name="a",
+ )
+ psser = ps.from_pandas(pser)
+ self.assert_eq(ps_func(psser.rolling(2)), pd_func(pser.rolling(2)))
+
+ pdf = pd.DataFrame(
+ {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
+ self.assert_eq(ps_func(psdf.rolling(2)).sum(), pd_func(pdf.rolling(2)).sum())
+
+ # Multiindex column
+ columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
+ pdf.columns = columns
+ psdf.columns = columns
+ self.assert_eq(ps_func(psdf.rolling(2)), pd_func(pdf.rolling(2)))
+
+
+class RollingMixin(RollingTestingFuncMixin):
+ def test_rolling_min(self):
+ self._test_rolling_func("min")
+
+ def test_rolling_max(self):
+ self._test_rolling_func("max")
+
+ def test_rolling_mean(self):
+ self._test_rolling_func("mean")
+
+ def test_rolling_sum(self):
+ self._test_rolling_func("sum")
+
+
+class RollingTests(
+ RollingMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+# Script entry point: run this module's tests directly.
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.window.test_rolling import *  # noqa: F401
+
+    # Use the XML runner (reports written under target/test-reports) when `xmlrunner`
+    # can be imported; otherwise pass None so unittest.main uses its default runner.
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/window/test_rolling_adv.py
similarity index 55%
copy from python/pyspark/pandas/tests/connect/test_parity_rolling.py
copy to python/pyspark/pandas/tests/window/test_rolling_adv.py
index 8318bed24f0..6ae48dfa76d 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/window/test_rolling_adv.py
@@ -14,24 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.tests.window.test_rolling import RollingTestingFuncMixin
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingAdvMixin(RollingTestingFuncMixin):
+ def test_rolling_quantile(self):
+ self._test_rolling_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower"))
+
+ def test_rolling_std(self):
+ self._test_rolling_func("std")
+
+ def test_rolling_var(self):
+ self._test_rolling_func("var")
+
+ def test_rolling_skew(self):
+ self._test_rolling_func("skew")
+
+ def test_rolling_kurt(self):
+ self._test_rolling_func("kurt")
+
+
+class RollingAdvTests(
+ RollingAdvMixin,
+ PandasOnSparkTestCase,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ import unittest
+ from pyspark.pandas.tests.window.test_rolling_adv import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/window/test_rolling_count.py b/python/pyspark/pandas/tests/window/test_rolling_count.py
new file mode 100644
index 00000000000..36ec8cb056a
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_rolling_count.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+
+
+class RollingCountMixin:
+ def test_rolling_count(self):
+ pser = pd.Series([1, 2, 3, 7, 9, 8], index=np.random.rand(6), name="a")
+ psser = ps.from_pandas(pser)
+ self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count())
+ self.assert_eq(psser.rolling(2).count().sum(), pser.rolling(2, min_periods=1).count().sum())
+
+ # TODO(SPARK-43432): Fix `min_periods` for Rolling.count() to work same as pandas
+ # Multiindex
+ pser = pd.Series(
+ [1, 2, 3],
+ index=pd.MultiIndex.from_tuples([("a", "x"), ("a", "y"), ("b", "z")]),
+ name="a",
+ )
+ psser = ps.from_pandas(pser)
+ self.assert_eq(psser.rolling(2).count(), pser.rolling(2, min_periods=1).count())
+
+ pdf = pd.DataFrame(
+ {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4)
+ )
+ psdf = ps.from_pandas(pdf)
+ self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count())
+ self.assert_eq(psdf.rolling(2).count().sum(), pdf.rolling(2, min_periods=1).count().sum())
+
+ # Multiindex column
+ columns = pd.MultiIndex.from_tuples([("a", "x"), ("a", "y")])
+ pdf.columns = columns
+ psdf.columns = columns
+ self.assert_eq(psdf.rolling(2).count(), pdf.rolling(2, min_periods=1).count())
+
+
+class RollingCountTests(
+ RollingCountMixin,
+ PandasOnSparkTestCase,
+):
+ pass
+
+
+# Script entry point: run this module's tests directly.
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.window.test_rolling_count import *  # noqa: F401
+
+    # Use the XML runner (reports written under target/test-reports) when `xmlrunner`
+    # can be imported; otherwise pass None so unittest.main uses its default runner.
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_rolling.py b/python/pyspark/pandas/tests/window/test_rolling_error.py
similarity index 55%
rename from python/pyspark/pandas/tests/connect/test_parity_rolling.py
rename to python/pyspark/pandas/tests/window/test_rolling_error.py
index 8318bed24f0..485eeb78c13 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_rolling.py
+++ b/python/pyspark/pandas/tests/window/test_rolling_error.py
@@ -14,24 +14,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import unittest
-from pyspark.pandas.tests.test_rolling import RollingTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.window import Rolling
-class RollingParityTests(
- RollingTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
+class RollingErrorMixin:
+ def test_rolling_error(self):
+ with self.assertRaisesRegex(ValueError, "window must be >= 0"):
+ ps.range(10).rolling(window=-1)
+ with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+ ps.range(10).rolling(window=1, min_periods=-1)
+
+ with self.assertRaisesRegex(
+ TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int"
+ ):
+ Rolling(1, 2)
+
+
+class RollingErrorTests(
+ RollingErrorMixin,
+ PandasOnSparkTestCase,
+ TestUtils,
):
pass
if __name__ == "__main__":
- from pyspark.pandas.tests.connect.test_parity_rolling import * # noqa: F401
+ import unittest
+ from pyspark.pandas.tests.window.test_rolling_error import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org