Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/15 14:26:51 UTC

[GitHub] [spark] Yikun commented on a change in pull request #32177: [WIP][SPARK-34999][PYTHON] Consolidate PySpark testing utils

Yikun commented on a change in pull request #32177:
URL: https://github.com/apache/spark/pull/32177#discussion_r614085656



##########
File path: python/pyspark/testing/utils.py
##########
@@ -39,6 +54,29 @@
     # No NumPy, but that's okay, we'll skip those tests
     pass
 
+tabulate_requirement_message = None
+try:
+    from tabulate import tabulate  # noqa: F401
+except ImportError as e:
+    # If tabulate requirement is not satisfied, skip related tests.
+    tabulate_requirement_message = str(e)
+have_tabulate = tabulate_requirement_message is None

Review comment:
       I'm not sure why we need to introduce an extra `tabulate_requirement_message`.
   
   I would prefer the code below:
   ```suggestion
   have_tabulate = True
   try:
       from tabulate import tabulate  # noqa: F401
   except ImportError:
       # If tabulate requirement is not satisfied, skip related tests.
       have_tabulate = False
   ```
   or just keep the above style:
   ```suggestion
   have_tabulate = True
   try:
       from tabulate import tabulate  # noqa: F401
   except ImportError:
       # If tabulate requirement is not satisfied, skip related tests.
       have_tabulate = False
   ```
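
For context, a flag such as `have_tabulate` is usually consumed by a skip decorator in the test modules. The sketch below is illustrative only: `MarkdownTest` and its test method are hypothetical names, not taken from this PR, and the skip message is a plain string chosen for the example.

```python
import unittest

# Probe for the optional dependency once, at import time.
have_tabulate = True
try:
    from tabulate import tabulate  # noqa: F401
except ImportError:
    # If the tabulate requirement is not satisfied, skip related tests.
    have_tabulate = False


class MarkdownTest(unittest.TestCase):  # hypothetical test case
    @unittest.skipIf(not have_tabulate, "tabulate is not installed")
    def test_table_contains_header(self):
        # Only runs when tabulate could be imported above.
        table = tabulate([[1, 2]], headers=["a", "b"])
        self.assertIn("a", table)
```

With a boolean flag the skip reason has to be written by hand, as above. Keeping the exception text in a separate variable would allow reusing it as the skip reason, which may or may not be what the patch intended.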

##########
File path: python/pyspark/pandas/tests/test_dataframe_conversion.py
##########
@@ -26,7 +26,7 @@
 
 from pyspark import pandas as pp

Review comment:
       Unrelated: we renamed pp --> ps in https://github.com/apache/spark/pull/32108.
   
   - `from pyspark import pandas as pd`  <-- this is the pandas developer convention for naming pandas.
   - `from pyspark import pandas as ps`  <-- presumably this means `pandas on Spark`?
   
   I know it's unrelated, but would you mind also doing it in your patch? I'm also okay with submitting it in a separate patch.

##########
File path: python/pyspark/testing/utils.py
##########
@@ -39,6 +54,29 @@
     # No NumPy, but that's okay, we'll skip those tests

Review comment:
       Unrelated, but it would be better to change `except` to `except ImportError`.
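
To illustrate why the narrower clause matters, here is a minimal sketch. All function names and the missing module name are hypothetical and exist only for this example. A bare `except` hides any failure raised while importing, while `except ImportError` hides only the "module is not installed" case.

```python
def probe_bare(importer):
    """Return True if importer() succeeds; swallows every exception."""
    try:
        importer()
        return True
    except:  # noqa: E722  (also catches KeyboardInterrupt and SystemExit)
        return False


def probe_narrow(importer):
    """Return True if importer() succeeds; swallows only ImportError."""
    try:
        importer()
        return True
    except ImportError:
        return False


def missing_module():
    # Hypothetical module name that is assumed not to exist.
    import a_module_that_does_not_exist_xyz  # noqa: F401


def broken_module():
    # Stands in for an installed module that fails for another reason.
    raise RuntimeError("bug inside the optional module")
```

Both probes report `False` for `missing_module`. For `broken_module`, `probe_bare` also returns `False` and the tests would be skipped silently, whereas `probe_narrow` lets the `RuntimeError` surface.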

##########
File path: python/pyspark/testing/utils.py
##########
@@ -171,3 +209,393 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
         raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
     else:
         return jars[0]
+
+
+# Utilities below are used mainly in pyspark/pandas
+class SQLTestUtils(object):
+    """
+    This util assumes the instance of this to have 'spark' attribute, having a spark session.
+    It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
+    the implementation of this class has 'spark' attribute.
+    """
+
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test some configuration specific logic. This sets
+        `value` to the configuration `key` and then restores it back when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        with sqlc(pairs, spark=self.spark):
+            yield
+
+    @contextmanager
+    def database(self, *databases):
+        """
+        A convenient context manager to test with some specific databases. This drops the given
+        databases if it exists and sets current database to "default" when it exits.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for db in databases:
+                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
+            self.spark.catalog.setCurrentDatabase("default")
+
+    @contextmanager
+    def table(self, *tables):
+        """
+        A convenient context manager to test with some specific tables. This drops the given tables
+        if it exists.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for t in tables:
+                self.spark.sql("DROP TABLE IF EXISTS %s" % t)
+
+    @contextmanager
+    def tempView(self, *views):
+        """
+        A convenient context manager to test with some specific views. This drops the given views
+        if it exists.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for v in views:
+                self.spark.catalog.dropTempView(v)
+
+    @contextmanager
+    def function(self, *functions):
+        """
+        A convenient context manager to test with some specific functions. This drops the given
+        functions if it exists.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+        try:
+            yield
+        finally:
+            for f in functions:
+                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
+
+class ReusedSQLTestCase(unittest.TestCase, SQLTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = default_session()
+        cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
+
+    @classmethod
+    def tearDownClass(cls):
+        # We don't stop Spark session to reuse across all tests.
+        # The Spark session will be started and stopped at PyTest session level.
+        # Please see databricks/koalas/conftest.py.
+        pass
+
+    def assertPandasEqual(self, left, right, check_exact=True):
+        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+            try:
+                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                    kwargs = dict(check_freq=False)
+                else:
+                    kwargs = dict()
+
+                assert_frame_equal(
+                    left,
+                    right,
+                    check_index_type=("equiv" if len(left.index) > 0 else False),
+                    check_column_type=("equiv" if len(left.columns) > 0 else False),
+                    check_exact=check_exact,
+                    **kwargs
+                )
+            except AssertionError as e:
+                msg = (
+                    str(e)
+                    + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
+                    + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+                )
+                raise AssertionError(msg) from e
+        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+            try:
+                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                    kwargs = dict(check_freq=False)
+                else:
+                    kwargs = dict()
+
+                assert_series_equal(
+                    left,
+                    right,
+                    check_index_type=("equiv" if len(left.index) > 0 else False),
+                    check_exact=check_exact,
+                    **kwargs
+                )
+            except AssertionError as e:
+                msg = (
+                    str(e)
+                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+                )
+                raise AssertionError(msg) from e
+        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+            try:
+                assert_index_equal(left, right, check_exact=check_exact)
+            except AssertionError as e:
+                msg = (
+                    str(e)
+                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+                )
+                raise AssertionError(msg) from e
+        else:
+            raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+    def assertPandasAlmostEqual(self, left, right):
+        """
+        This function checks if given pandas objects approximately same,
+        which means the conditions below:
+          - Both objects are nullable
+          - Compare floats rounding to the number of decimal places, 7 after
+            dropping missing values (NaN, NaT, None)
+        """
+        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+            msg = (
+                "DataFrames are not almost equal: "
+                + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
+                + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+            )
+            self.assertEqual(left.shape, right.shape, msg=msg)
+            for lcol, rcol in zip(left.columns, right.columns):
+                self.assertEqual(lcol, rcol, msg=msg)
+                for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
+                    self.assertEqual(lnull, rnull, msg=msg)
+                for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
+                    self.assertAlmostEqual(lval, rval, msg=msg)
+            self.assertEqual(left.columns.names, right.columns.names, msg=msg)
+        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+            msg = (
+                "Series are not almost equal: "
+                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            )
+            self.assertEqual(left.name, right.name, msg=msg)
+            self.assertEqual(len(left), len(right), msg=msg)
+            for lnull, rnull in zip(left.isnull(), right.isnull()):
+                self.assertEqual(lnull, rnull, msg=msg)
+            for lval, rval in zip(left.dropna(), right.dropna()):
+                self.assertAlmostEqual(lval, rval, msg=msg)
+        elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
+            msg = (
+                "MultiIndices are not almost equal: "
+                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            )
+            self.assertEqual(len(left), len(right), msg=msg)
+            for lval, rval in zip(left, right):
+                self.assertAlmostEqual(lval, rval, msg=msg)
+        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+            msg = (
+                "Indices are not almost equal: "
+                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
+                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            )
+            self.assertEqual(len(left), len(right), msg=msg)
+            for lnull, rnull in zip(left.isnull(), right.isnull()):
+                self.assertEqual(lnull, rnull, msg=msg)
+            for lval, rval in zip(left.dropna(), right.dropna()):
+                self.assertAlmostEqual(lval, rval, msg=msg)
+        else:
+            raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+    def assert_eq(self, left, right, check_exact=True, almost=False):
+        """
+        Asserts if two arbitrary objects are equal or not. If given objects are Koalas DataFrame
+        or Series, they are converted into pandas' and compared.
+
+        :param left: object to compare
+        :param right: object to compare
+        :param check_exact: if this is False, the comparison is done less precisely.
+        :param almost: if this is enabled, the comparison is delegated to `unittest`'s
+                       `assertAlmostEqual`. See its documentation for more details.
+        """
+        lobj = self._to_pandas(left)
+        robj = self._to_pandas(right)
+        if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
+            if almost:
+                self.assertPandasAlmostEqual(lobj, robj)
+            else:
+                self.assertPandasEqual(lobj, robj, check_exact=check_exact)
+        elif is_list_like(lobj) and is_list_like(robj):
+            self.assertTrue(len(left) == len(right))
+            for litem, ritem in zip(left, right):
+                self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost)
+        elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)):
+            pass
+        else:
+            if almost:
+                self.assertAlmostEqual(lobj, robj)
+            else:
+                self.assertEqual(lobj, robj)
+
+    @staticmethod
+    def _to_pandas(obj):
+        if isinstance(obj, (DataFrame, Series, Index)):
+            return obj.to_pandas()
+        else:
+            return obj
+
+
+class TestUtils(object):
+    @contextmanager
+    def temp_dir(self):
+        tmp = tempfile.mkdtemp()
+        try:
+            yield tmp
+        finally:
+            shutil.rmtree(tmp)
+
+    @contextmanager
+    def temp_file(self):
+        with self.temp_dir() as tmp:
+            yield tempfile.mktemp(dir=tmp)
+
+
+class ComparisonTestBase(ReusedSQLTestCase):
+    @property
+    def kdf(self):
+        return ps.from_pandas(self.pdf)
+
+    @property
+    def pdf(self):
+        return self.kdf.to_pandas()
+
+
+def compare_both(f=None, almost=True):
+
+    if f is None:
+        return functools.partial(compare_both, almost=almost)
+    elif isinstance(f, bool):
+        return functools.partial(compare_both, almost=f)
+
+    @functools.wraps(f)
+    def wrapped(self):
+        if almost:
+            compare = self.assertPandasAlmostEqual
+        else:
+            compare = self.assertPandasEqual
+
+        for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.kdf)):
+            compare(result_pandas, result_spark.to_pandas())
+
+    return wrapped
+
+
+@contextmanager
+def assert_produces_warning(
+    expected_warning=Warning,
+    filter_level="always",
+    check_stacklevel=True,
+    raise_on_extra_warnings=True,
+):
+    """
+    Context manager for running code expected to either raise a specific
+    warning, or not raise any warnings. Verifies that the code raises the
+    expected warning, and that it does not raise any other unexpected
+    warnings. It is basically a wrapper around ``warnings.catch_warnings``.
+
+    Notes
+    -----
+    Replicated from pandas._testing.

Review comment:
       This note is outdated.
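
For readers unfamiliar with the helper under discussion, the docstring describes it as a wrapper around `warnings.catch_warnings`. The following is a much simplified sketch of that idea, not the implementation in this PR or in pandas; `expect_warning` is a hypothetical name.

```python
import warnings
from contextlib import contextmanager


@contextmanager
def expect_warning(expected=Warning):  # hypothetical, simplified helper
    """Fail unless the wrapped code emits a warning of type `expected`."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones already shown once.
        warnings.simplefilter("always")
        yield caught
    if not any(issubclass(w.category, expected) for w in caught):
        raise AssertionError(
            "Did not see expected warning of class %s" % expected.__name__
        )
```

It can be used as `with expect_warning(DeprecationWarning): ...`. The real helper additionally checks stack levels and unexpected extra warnings, which this sketch omits.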



