Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/03 10:40:58 UTC

[spark] branch master updated: [SPARK-39939][PYTHON][PS] return self.copy during calling shift with period == 0

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c9eba123990 [SPARK-39939][PYTHON][PS] return self.copy during calling shift with period == 0
c9eba123990 is described below

commit c9eba123990e2e9d7d3a7f2f5c67e55115aa1731
Author: bzhaoop <bz...@gmail.com>
AuthorDate: Wed Aug 3 18:40:26 2022 +0800

    [SPARK-39939][PYTHON][PS] return self.copy during calling shift with period == 0
    
    PySpark raises an AnalysisException when `shift` is called with `periods=0`.
    
    pandas, by contrast, returns an unchanged copy of the same object.
    
    ### What changes were proposed in this pull request?
    Return `self.copy()` when `periods == 0`, matching the pandas behavior.
    
    ### Why are the changes needed?
    The behaviors of PySpark and pandas differ when `periods=0`:
    
    PySpark:
    ```
    >>> df = ps.DataFrame({'Col1': [10, 20, 15, 30, 45], 'Col2': [13, 23, 18, 33, 48],'Col3': [17, 27, 22, 37, 52]},columns=['Col1', 'Col2', 'Col3'])
    >>> df.Col1.shift(periods=3)
    0     NaN
    1     NaN
    2     NaN
    3    10.0
    4    20.0
    Name: Col1, dtype: float64
    >>> df.Col1.shift(periods=0)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/spark/spark/python/pyspark/pandas/base.py", line 1170, in shift
        return self._shift(periods, fill_value).spark.analyzed
      File "/home/spark/spark/python/pyspark/pandas/spark/accessors.py", line 256, in analyzed
        return first_series(DataFrame(self._data._internal.resolved_copy))
      File "/home/spark/spark/python/pyspark/pandas/utils.py", line 589, in wrapped_lazy_property
        setattr(self, attr_name, fn(self))
      File "/home/spark/spark/python/pyspark/pandas/internal.py", line 1173, in resolved_copy
        sdf = self.spark_frame.select(self.spark_columns + list(HIDDEN_COLUMNS))
      File "/home/spark/spark/python/pyspark/sql/dataframe.py", line 2073, in select
        jdf = self._jdf.select(self._jcols(*cols))
      File "/home/spark/.pyenv/versions/3.8.13/lib/python3.8/site-packages/py4j/java_gateway.py", line 1321, in __call__
        return_value = get_return_value(
      File "/home/spark/spark/python/pyspark/sql/utils.py", line 196, in deco
        raise converted from None
    pyspark.sql.utils.AnalysisException: Cannot specify window frame for lag function
    ```
    
    pandas:
    ```
    >>> pdf = pd.DataFrame({'Col1': [10, 20, 15, 30, 45], 'Col2': [13, 23, 18, 33, 48],'Col3': [17, 27, 22, 37, 52]},columns=['Col1', 'Col2', 'Col3'])
    >>> pdf.Col1.shift(periods=3)
    0     NaN
    1     NaN
    2     NaN
    3    10.0
    4    20.0
    Name: Col1, dtype: float64
    >>> pdf.Col1.shift(periods=0)
    0    10
    1    20
    2    15
    3    30
    4    45
    Name: Col1, dtype: int64
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests that call `shift` with `periods == 0` for both DataFrame and Series.
    
    Closes #37366 from bzhaoopenstack/period.
    
    Authored-by: bzhaoop <bz...@gmail.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/pandas/base.py                 | 3 +++
 python/pyspark/pandas/tests/test_dataframe.py | 1 +
 python/pyspark/pandas/tests/test_series.py    | 2 ++
 3 files changed, 6 insertions(+)

diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index 3430f5efa93..bf7149e6b23 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -1179,6 +1179,9 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         if not isinstance(periods, int):
             raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__)
 
+        if periods == 0:
+            return self.copy()
+
         col = self.spark.column
         window = (
             Window.partitionBy(*part_cols)
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1361c44404a..add93faba0c 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -4249,6 +4249,7 @@ class DataFrameTest(ComparisonTestBase, SQLTestUtils):
         psdf.columns = columns
         self.assert_eq(pdf.shift(3), psdf.shift(3))
         self.assert_eq(pdf.shift().shift(-1), psdf.shift().shift(-1))
+        self.assert_eq(pdf.shift(0), psdf.shift(0))
 
     def test_diff(self):
         pdf = pd.DataFrame(
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 144df0f986a..6bc07def712 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -1549,6 +1549,8 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaisesRegex(TypeError, "periods should be an int; however"):
             psser.shift(periods=1.5)
 
+        self.assert_eq(psser.shift(periods=0), pser.shift(periods=0))
+
     def test_diff(self):
         pser = pd.Series([10, 20, 15, 30, 45], name="x")
         psser = ps.Series(pser)

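The guard added in `base.py` can be illustrated with a minimal, Spark-free sketch. The `shift` helper below is hypothetical (it operates on plain lists, not the actual PySpark implementation); it shows why `periods == 0` is special-cased with an early copy instead of going through the lag-window path that Spark cannot express:

```python
def shift(values, periods, fill_value=None):
    """Shift a list by `periods`, padding vacated slots with `fill_value`.

    Hypothetical helper mirroring the patch: for periods == 0 we return an
    unchanged copy up front, the same early return as `self.copy()` in the
    commit, rather than building a lag/lead expression.
    """
    if periods == 0:
        return list(values)  # early return added by this commit
    n = len(values)
    if periods > 0:
        # shift forward: pad the front, drop the tail
        return [fill_value] * min(periods, n) + values[: max(n - periods, 0)]
    # shift backward: drop the front, pad the tail
    k = -periods
    return values[min(k, n):] + [fill_value] * min(k, n)

print(shift([10, 20, 15, 30, 45], 3))  # [None, None, None, 10, 20]
print(shift([10, 20, 15, 30, 45], 0))  # [10, 20, 15, 30, 45]
```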

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org