Posted to reviews@spark.apache.org by "itholic (via GitHub)" <gi...@apache.org> on 2023/07/27 04:04:29 UTC

[GitHub] [spark] itholic commented on a diff in pull request #42158: [SPARK-44548][PYTHON] Add support for pandas-on-Spark DataFrame assertDataFrameEqual

itholic commented on code in PR #42158:
URL: https://github.com/apache/spark/pull/42158#discussion_r1275697965


##########
python/pyspark/errors/error_classes.py:
##########
@@ -233,6 +238,12 @@
       "NumPy array input should be of <dimensions> dimensions."
     ]
   },
+  "INVALID_PANDAS_ON_SPARK_COMPARISON" : {
+    "message" : [
+      "Expected two pandas-on-Spark DataFrames",
+      "but got actual: <actual_type> and expected: <expected_type>"

Review Comment:
   Maybe `expected_type` here is always a pandas-on-Spark DataFrame, so we don't need to mention it?



##########
python/pyspark/pandas/tests/test_utils.py:
##########
@@ -105,6 +107,72 @@ def test_validate_index_loc(self):
         with self.assertRaisesRegex(IndexError, err_msg):
             validate_index_loc(psidx, -4)
 
+    def test_assert_df_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=True)

Review Comment:
   Can we add a negative case for `checkRowOrder` as well?
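
   A negative case could look roughly like the sketch below. To keep it runnable without a Spark session it uses a hypothetical stand-in, `assert_rows_equal`, over plain lists of row tuples; it only mirrors the PR's behavior of sorting by the first column unless `checkRowOrder=True`. In the real test the same shape would presumably apply to `assertPandasOnSparkEqual` with `ps.DataFrame` inputs.

```python
import unittest


def assert_rows_equal(actual, expected, checkRowOrder=False):
    """Hypothetical stand-in for assertPandasOnSparkEqual over lists of row tuples."""
    if not checkRowOrder:
        # Mirror the PR: sort both sides by the first column before comparing.
        actual = sorted(actual, key=lambda row: row[0])
        expected = sorted(expected, key=lambda row: row[0])
    if actual != expected:
        raise AssertionError(f"Rows differ:\nactual: {actual}\nexpected: {expected}")


class CheckRowOrderTest(unittest.TestCase):
    def test_check_row_order_negative(self):
        rows1 = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]
        rows2 = [(3, 6, 9), (2, 5, 8), (1, 4, 7)]  # same rows, reversed order

        # Row order is ignored by default, so this comparison passes.
        assert_rows_equal(rows1, rows2)

        # With checkRowOrder=True the same inputs must fail.
        with self.assertRaises(AssertionError):
            assert_rows_equal(rows1, rows2, checkRowOrder=True)
```

   The point is that the positive-only test above would still pass if `checkRowOrder` were silently ignored; the reversed-order input is what actually exercises the flag.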



##########
python/pyspark/pandas/tests/test_utils.py:
##########
@@ -105,6 +107,72 @@ def test_validate_index_loc(self):
         with self.assertRaisesRegex(IndexError, err_msg):
             validate_index_loc(psidx, -4)
 
+    def test_assert_df_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=True)

Review Comment:
   After a rough pass over the whole change, it seems `assertDataFrameEqual` can already cover `assertPandasOnSparkEqual`?
   
   Could we just use `assertDataFrameEqual` for all DataFrame comparison testing, and not expose `assertPandasOnSparkEqual` as a user-facing API?
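
   To illustrate the shape I have in mind: one public entry point that dispatches on the input type to internal helpers. This is only a sketch under assumptions; `SparkDF`, `PandasOnSparkDF`, `assert_dataframe_equal_sketch` and both `_`-prefixed helpers are hypothetical stand-ins, not the actual PySpark classes or functions.

```python
class SparkDF:
    """Hypothetical stand-in for a Spark DataFrame (a list of row tuples)."""

    def __init__(self, rows):
        self.rows = rows


class PandasOnSparkDF:
    """Hypothetical stand-in for a pandas-on-Spark DataFrame."""

    def __init__(self, rows):
        self.rows = rows


def _assert_spark_equal(actual, expected, checkRowOrder):
    """Internal helper: compare stand-in Spark DataFrames."""
    left, right = actual.rows, expected.rows
    if not checkRowOrder:
        left, right = sorted(left), sorted(right)
    if left != right:
        raise AssertionError(f"Spark rows differ: {left} != {right}")


def _assert_pandas_on_spark_equal(actual, expected, checkRowOrder):
    """Internal helper: compare stand-in pandas-on-Spark DataFrames."""
    left, right = actual.rows, expected.rows
    if not checkRowOrder:
        # Mirror the PR: sort by the first column only.
        left = sorted(left, key=lambda row: row[0])
        right = sorted(right, key=lambda row: row[0])
    if left != right:
        raise AssertionError(f"pandas-on-Spark rows differ: {left} != {right}")


def assert_dataframe_equal_sketch(actual, expected, checkRowOrder=False):
    """Single public entry point; callers never touch the helpers directly."""
    if isinstance(actual, PandasOnSparkDF) and isinstance(expected, PandasOnSparkDF):
        _assert_pandas_on_spark_equal(actual, expected, checkRowOrder)
    elif isinstance(actual, SparkDF) and isinstance(expected, SparkDF):
        _assert_spark_equal(actual, expected, checkRowOrder)
    else:
        raise TypeError(
            f"Unsupported combination: {type(actual).__name__}, {type(expected).__name__}"
        )


# Usage: the caller only needs to know about the one public function.
assert_dataframe_equal_sketch(
    PandasOnSparkDF([(1, 4), (2, 5)]),
    PandasOnSparkDF([(2, 5), (1, 4)]),
)
```

   With this layout users have a single function to learn, and the pandas-on-Spark path can change freely behind it.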



##########
python/pyspark/testing/utils.py:
##########
@@ -395,24 +404,37 @@ def assertDataFrameEqual(
     elif actual is None or expected is None:
         return False
 
+    import pyspark.pandas as ps
+    from pyspark.testing.pandasutils import assertPandasOnSparkEqual

Review Comment:
   IIUC, we want to consolidate all DataFrame comparison testing utils into `assertDataFrameEqual`?
   
   If so, we should also mark `assertPandasOnSparkEqual` as an internal function by adding a `_` prefix.



##########
python/pyspark/testing/pandasutils.py:
##########
@@ -54,153 +58,350 @@
 have_plotly = plotly_requirement_message is None
 
 
-class PandasOnSparkTestUtils:
-    def convert_str_to_lambda(self, func):
-        """
-        This function coverts `func` str to lambda call
-        """
-        return lambda x: getattr(x, func)()
+__all__ = ["assertPandasOnSparkEqual"]
 
-    def assertPandasEqual(self, left, right, check_exact=True):
-        import pandas as pd
-        from pandas.core.dtypes.common import is_numeric_dtype
-        from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
-
-        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
-            try:
-                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
-                    kwargs = dict(check_freq=False)
-                else:
-                    kwargs = dict()
-
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and all([is_numeric_dtype(dtype) for dtype in left.dtypes])
-                        and all([is_numeric_dtype(dtype) for dtype in right.dtypes])
-                    )
 
-                assert_frame_equal(
-                    left,
-                    right,
-                    check_index_type=("equiv" if len(left.index) > 0 else False),
-                    check_column_type=("equiv" if len(left.columns) > 0 else False),
-                    check_exact=check_exact,
-                    **kwargs,
-                )
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
-                )
-                raise AssertionError(msg) from e
-        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
-            try:
-                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
-                    kwargs = dict(check_freq=False)
-                else:
-                    kwargs = dict()
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
-                    )
-                assert_series_equal(
-                    left,
-                    right,
-                    check_index_type=("equiv" if len(left.index) > 0 else False),
-                    check_exact=check_exact,
-                    **kwargs,
-                )
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
-                )
-                raise AssertionError(msg) from e
-        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
-            try:
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
-                    )
-                assert_index_equal(left, right, check_exact=check_exact)
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
-                )
-                raise AssertionError(msg) from e
-        else:
-            raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+def assertPandasDFEqual(
+    left: Union[pd.DataFrame, pd.Series, pd.Index],
+    right: Union[pd.DataFrame, pd.Series, pd.Index],
+    checkExact: bool,
+):
+    from pandas.core.dtypes.common import is_numeric_dtype
+    from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
 
-    def assertPandasAlmostEqual(self, left, right):
-        """
-        This function checks if given pandas objects approximately same,
-        which means the conditions below:
-          - Both objects are nullable
-          - Compare floats rounding to the number of decimal places, 7 after
-            dropping missing values (NaN, NaT, None)
-        """
-        import pandas as pd
+    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+        try:
+            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                kwargs = dict(check_freq=False)
+            else:
+                kwargs = dict()
+
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact
+                    and all([is_numeric_dtype(dtype) for dtype in left.dtypes])
+                    and all([is_numeric_dtype(dtype) for dtype in right.dtypes])
+                )
 
-        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+            assert_frame_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_column_type=("equiv" if len(left.columns) > 0 else False),
+                check_exact=checkExact,
+                **kwargs,
+            )
+        except AssertionError as e:
             msg = (
-                "DataFrames are not almost equal: "
+                str(e)
                 + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
                 + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
             )
-            self.assertEqual(left.shape, right.shape, msg=msg)
-            for lcol, rcol in zip(left.columns, right.columns):
-                self.assertEqual(lcol, rcol, msg=msg)
-                for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
-                    self.assertEqual(lnull, rnull, msg=msg)
-                for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
-                    self.assertAlmostEqual(lval, rval, msg=msg)
-            self.assertEqual(left.columns.names, right.columns.names, msg=msg)
-        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
-            msg = (
-                "Series are not almost equal: "
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            raise AssertionError(msg) from e
+    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+        try:
+            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                kwargs = dict(check_freq=False)
+            else:
+                kwargs = dict()
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
+                )
+            assert_series_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_exact=checkExact,
+                **kwargs,
             )
-            self.assertEqual(left.name, right.name, msg=msg)
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lnull, rnull in zip(left.isnull(), right.isnull()):
-                self.assertEqual(lnull, rnull, msg=msg)
-            for lval, rval in zip(left.dropna(), right.dropna()):
-                self.assertAlmostEqual(lval, rval, msg=msg)
-        elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
+        except AssertionError as e:
             msg = (
-                "MultiIndices are not almost equal: "
+                str(e)
                 + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                 + "\n\nRight:\n%s\n%s" % (right, right.dtype)
             )
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lval, rval in zip(left, right):
-                self.assertAlmostEqual(lval, rval, msg=msg)
-        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+            raise AssertionError(msg) from e
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        try:
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
+                )
+            assert_index_equal(left, right, check_exact=checkExact)
+        except AssertionError as e:
             msg = (
-                "Indices are not almost equal: "
+                str(e)
                 + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                 + "\n\nRight:\n%s\n%s" % (right, right.dtype)
             )
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lnull, rnull in zip(left.isnull(), right.isnull()):
-                self.assertEqual(lnull, rnull, msg=msg)
-            for lval, rval in zip(left.dropna(), right.dropna()):
-                self.assertAlmostEqual(lval, rval, msg=msg)
+            raise AssertionError(msg) from e
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def assertPandasDFAlmostEqual(
+    left: Union[pd.DataFrame, pd.Series, pd.Index], right: Union[pd.DataFrame, pd.Series, pd.Index]
+):
+    """
+    This function checks if given pandas objects approximately same,
+    which means the conditions below:
+      - Both objects are nullable
+      - Compare floats rounding to the number of decimal places, 7 after
+        dropping missing values (NaN, NaT, None)
+    """
+    # following pandas convention, rtol=1e-5 and atol=1e-8
+    rtol = 1e-5
+    atol = 1e-8
+
+    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+        msg = (
+            "DataFrames are not almost equal: "
+            + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
+            + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+        )
+        if left.shape != right.shape:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "msg": msg,
+                },
+            )
+        for lcol, rcol in zip(left.columns, right.columns):
+            if lcol != rcol:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_DATAFRAME",
+                    message_parameters={
+                        "msg": msg,
+                    },
+                )
+            for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
+                if lnull != rnull:
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_DATAFRAME",
+                        message_parameters={
+                            "msg": msg,
+                        },
+                    )
+            for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
+                if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                    isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+                ):
+                    if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                        raise PySparkAssertionError(
+                            error_class="DIFFERENT_PANDAS_DATAFRAME",
+                            message_parameters={
+                                "msg": msg,
+                            },
+                        )
+        if left.columns.names != right.columns.names:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "msg": msg,
+                },
+            )
+    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+        msg = (
+            "Series are not almost equal: "
+            + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+            + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+        )
+        if left.name != right.name or len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "msg": msg,
+                },
+            )
+        for lnull, rnull in zip(left.isnull(), right.isnull()):
+            if lnull != rnull:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_DATAFRAME",
+                    message_parameters={
+                        "msg": msg,
+                    },
+                )
+        for lval, rval in zip(left.dropna(), right.dropna()):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_DATAFRAME",
+                        message_parameters={
+                            "msg": msg,
+                        },
+                    )
+    elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
+        msg = (
+            "MultiIndices are not almost equal: "
+            + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+            + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+        )
+        if len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "msg": msg,
+                },
+            )
+        for lval, rval in zip(left, right):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_DATAFRAME",
+                        message_parameters={
+                            "msg": msg,
+                        },
+                    )
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        msg = (
+            "Indices are not almost equal: "
+            + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+            + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+        )
+        if len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "msg": msg,
+                },
+            )
+        for lnull, rnull in zip(left.isnull(), right.isnull()):
+            if lnull != rnull:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_DATAFRAME",
+                    message_parameters={
+                        "msg": msg,
+                    },
+                )
+        for lval, rval in zip(left.dropna(), right.dropna()):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_DATAFRAME",
+                        message_parameters={
+                            "msg": msg,
+                        },
+                    )
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def assertPandasOnSparkEqual(
+    actual: Union[DataFrame, Series, Index],
+    expected: Union[DataFrame, pd.DataFrame, Series, Index],
+    checkExact: bool = True,
+    almost: bool = False,
+    checkRowOrder: bool = False,
+):
+    r"""
+    A util function to assert equality between actual (pandas-on-Spark DataFrame) and expected
+    (pandas-on-Spark or pandas DataFrame).
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual: pandas-on-Spark DataFrame
+        The DataFrame that is being compared or tested.
+    expected: pandas-on-Spark or pandas DataFrame
+        The expected DataFrame, for comparison with the actual result.
+    checkExact: bool, optional
+        A flag indicating whether to compare exact equality.
+        If set to 'True' (default), the data is compared exactly.
+        If set to 'False', the data is compared less precisely, following pandas assert_frame_equal
+        approximate comparison (see documentation for more details).
+    almost: bool, optional
+        A flag indicating whether to use unittest `assertAlmostEqual` or `assertEqual`.
+        If set to 'True', the comparison is delegated to `unittest`'s `assertAlmostEqual`
+        (see documentation for more details).
+        If set to 'False' (default), the data is compared exactly with `unittest`'s
+        `assertEqual`.
+    checkRowOrder : bool, optional
+        A flag indicating whether the order of rows should be considered in the comparison.
+        If set to `False` (default), the row order is not taken into account.
+        If set to `True`, the order of rows is important and will be checked during comparison.
+        (See Notes)
+
+    Notes
+    -----
+    For `checkRowOrder`, note that pandas-on-Spark DataFrame ordering is non-deterministic, unless
+    explicitly sorted.
+
+    Examples
+    --------
+    >>> import pyspark.pandas as ps
+    >>> psdf1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> psdf2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertPandasOnSparkEqual(psdf1, psdf2)  # pass, ps.DataFrames are equal
+    >>> s1 = ps.Series([212.32, 100.0001])
+    >>> s2 = ps.Series([212.32, 100.0])
+    >>> assertPandasOnSparkEqual(s1, s2, checkExact=False)  # pass, ps.Series are approx equal
+    >>> s1 = ps.Index([212.300001, 100.000])
+    >>> s2 = ps.Index([212.3, 100.0001])
+    >>> assertPandasOnSparkEqual(s1, s2, almost=True)  # pass, ps.Index obj are almost equal
+    """
+    if actual is None and expected is None:
+        return True
+    elif actual is None or expected is None:
+        return False
+
+    if not isinstance(actual, (DataFrame, Series, Index)):
+        raise PySparkAssertionError(
+            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            message_parameters={
+                "expected_type": Union[DataFrame, Series, Index],
+                "arg_name": "actual",
+                "actual_type": type(actual),
+            },
+        )
+    elif not isinstance(expected, (DataFrame, pd.DataFrame, Series, Index)):
+        raise PySparkAssertionError(
+            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            message_parameters={
+                "expected_type": Union[DataFrame, pd.DataFrame, Series, Index],
+                "arg_name": "expected",
+                "actual_type": type(expected),
+            },
+        )
+    else:
+        actual = actual.to_pandas()
+        if not isinstance(expected, pd.DataFrame):
+            expected = expected.to_pandas()
+
+        if not checkRowOrder:
+            if isinstance(actual, pd.DataFrame) and len(actual.columns) > 0:
+                actual = actual.sort_values(by=actual.columns[0], ignore_index=True)
+            if isinstance(expected, pd.DataFrame) and len(expected.columns) > 0:
+                expected = expected.sort_values(by=expected.columns[0], ignore_index=True)
+
+        if almost:
+            assertPandasDFAlmostEqual(actual, expected)

Review Comment:
   If `assertPandasDFAlmostEqual` and `assertPandasDFEqual` are only used internally, let's explicitly mark them as internal functions by adding a `_` prefix.
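
   For reference, a small self-contained sketch of how `__all__` and the `_` prefix affect what a module exports. The module name `pandasutils_sketch` and the function bodies are hypothetical; only the export behavior is the point.

```python
import sys
import types

# Source of a throwaway in-memory module (hypothetical name and bodies).
source = '''
__all__ = ["assertPandasOnSparkEqual"]


def _assertPandasDFEqual(left, right):
    # Internal helper: the leading underscore signals "not public API".
    if left != right:
        raise AssertionError(f"{left!r} != {right!r}")


def assertPandasOnSparkEqual(actual, expected):
    # Public function: the only name listed in __all__.
    _assertPandasDFEqual(actual, expected)
'''

module = types.ModuleType("pandasutils_sketch")
exec(source, module.__dict__)
sys.modules["pandasutils_sketch"] = module

# A star import copies only the names listed in __all__.
namespace = {}
exec("from pandasutils_sketch import *", namespace)
exported = sorted(name for name in namespace if not name.startswith("__"))
print(exported)  # ['assertPandasOnSparkEqual']
```

   The helper is still reachable via `module._assertPandasDFEqual`, so the prefix is a convention rather than enforcement, but together with `__all__` it makes the intended public surface explicit.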



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

