You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/04/09 22:50:44 UTC

[spark] branch master updated: [SPARK-27387][PYTHON][TESTS] Replace sqlutils.assertPandasEqual with Pandas assert_frame_equals

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f62f44f  [SPARK-27387][PYTHON][TESTS] Replace sqlutils.assertPandasEqual with Pandas assert_frame_equals
f62f44f is described below

commit f62f44f2a277c38d3d5b5524b287340991523236
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Wed Apr 10 07:50:25 2019 +0900

    [SPARK-27387][PYTHON][TESTS] Replace sqlutils.assertPandasEqual with Pandas assert_frame_equals
    
    ## What changes were proposed in this pull request?
    
    Running PySpark tests with Pandas 0.24.x causes a failure in `test_supported_types` of `test_pandas_udf_grouped_map`:
    `ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()`
    
    This is because one of the columns is an ArrayType, and the method `sqlutils.ReusedSQLTestCase.assertPandasEqual` does not handle array-valued columns correctly, which leads to the error above.
    
    This PR removes `assertPandasEqual` and replaces it with Pandas' own `pandas.util.testing.assert_frame_equal`, which properly handles columns of ArrayType and also prints a more informative diff between the DataFrames when an assertion fails.
    
    Additionally, the imports of pandas and pyarrow were moved to the top of the related test files to avoid repeating the same import in many test methods.
    
    ## How was this patch tested?
    
    Existing tests
    
    Closes #24306 from BryanCutler/python-pandas-assert_frame_equal-SPARK-27387.
    
    Authored-by: Bryan Cutler <cu...@gmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 python/pyspark/sql/tests/test_arrow.py             | 44 +++++++---------
 python/pyspark/sql/tests/test_dataframe.py         |  7 +--
 .../sql/tests/test_pandas_udf_grouped_agg.py       | 60 +++++++++++----------
 .../sql/tests/test_pandas_udf_grouped_map.py       | 61 +++++++++++++---------
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 33 +++---------
 python/pyspark/sql/tests/test_pandas_udf_window.py | 41 ++++++++-------
 python/pyspark/testing/sqlutils.py                 |  6 ---
 7 files changed, 115 insertions(+), 137 deletions(-)

diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 38a6402..a45c3fb 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -29,6 +29,13 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
 from pyspark.testing.utils import QuietTest
 from pyspark.util import _exception_message
 
+if have_pandas:
+    import pandas as pd
+    from pandas.util.testing import assert_frame_equal
+
+if have_pyarrow:
+    import pyarrow as pa
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -40,7 +47,6 @@ class ArrowTests(ReusedSQLTestCase):
         from datetime import date, datetime
         from decimal import Decimal
         from distutils.version import LooseVersion
-        import pyarrow as pa
         super(ArrowTests, cls).setUpClass()
         cls.warnings_lock = threading.Lock()
 
@@ -89,7 +95,6 @@ class ArrowTests(ReusedSQLTestCase):
         super(ArrowTests, cls).tearDownClass()
 
     def create_pandas_data_frame(self):
-        import pandas as pd
         import numpy as np
         data_dict = {}
         for j, name in enumerate(self.schema.names):
@@ -100,8 +105,6 @@ class ArrowTests(ReusedSQLTestCase):
         return pd.DataFrame(data=data_dict)
 
     def test_toPandas_fallback_enabled(self):
-        import pandas as pd
-
         with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
             schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
             df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
@@ -117,11 +120,10 @@ class ArrowTests(ReusedSQLTestCase):
                         self.assertTrue(len(user_warns) > 0)
                         self.assertTrue(
                             "Attempting non-optimization" in _exception_message(user_warns[-1]))
-                        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
         from distutils.version import LooseVersion
-        import pyarrow as pa
 
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -157,8 +159,8 @@ class ArrowTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
         expected = self.create_pandas_data_frame()
-        self.assertPandasEqual(expected, pdf)
-        self.assertPandasEqual(expected, pdf_arrow)
+        assert_frame_equal(expected, pdf)
+        assert_frame_equal(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -168,13 +170,13 @@ class ArrowTests(ReusedSQLTestCase):
                 "spark.sql.execution.pandas.respectSessionTimeZone": False,
                 "spark.sql.session.timeZone": timezone}):
             pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+            assert_frame_equal(pdf_arrow_la, pdf_la)
 
         with self.sql_conf({
                 "spark.sql.execution.pandas.respectSessionTimeZone": True,
                 "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
-            self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
+            assert_frame_equal(pdf_arrow_ny, pdf_ny)
 
             self.assertFalse(pdf_ny.equals(pdf_la))
 
@@ -184,13 +186,13 @@ class ArrowTests(ReusedSQLTestCase):
                 if isinstance(field.dataType, TimestampType):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
-            self.assertPandasEqual(pdf_ny, pdf_la_corrected)
+            assert_frame_equal(pdf_ny, pdf_la_corrected)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
-        self.assertPandasEqual(pdf_arrow, pdf)
+        assert_frame_equal(pdf_arrow, pdf)
 
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
@@ -245,7 +247,7 @@ class ArrowTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(pdf, schema=self.schema)
         self.assertEquals(self.schema, df.schema)
         pdf_arrow = df.toPandas()
-        self.assertPandasEqual(pdf_arrow, pdf)
+        assert_frame_equal(pdf_arrow, pdf)
 
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -267,7 +269,6 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEquals(df.schema.fieldNames(), new_names)
 
     def test_createDataFrame_column_name_encoding(self):
-        import pandas as pd
         pdf = pd.DataFrame({u'a': [1]})
         columns = self.spark.createDataFrame(pdf).columns
         self.assertTrue(isinstance(columns[0], str))
@@ -277,13 +278,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEquals(columns[0], 'b')
 
     def test_createDataFrame_with_single_data_type(self):
-        import pandas as pd
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
 
     def test_createDataFrame_does_not_modify_input(self):
-        import pandas as pd
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()
         # Use a nanosecond value to make sure it is not truncated
@@ -301,7 +300,6 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEquals(self.schema, schema_rt)
 
     def test_createDataFrame_with_array_type(self):
-        import pandas as pd
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
         df, df_arrow = self._createDataFrame_toggle(pdf)
         result = df.collect()
@@ -327,7 +325,6 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_createDataFrame_with_int_col_names(self):
         import numpy as np
-        import pandas as pd
         pdf = pd.DataFrame(np.random.rand(4, 2))
         df, df_arrow = self._createDataFrame_toggle(pdf)
         pdf_col_names = [str(c) for c in pdf.columns]
@@ -335,8 +332,6 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf_col_names, df_arrow.columns)
 
     def test_createDataFrame_fallback_enabled(self):
-        import pandas as pd
-
         with QuietTest(self.sc):
             with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
                 with warnings.catch_warnings(record=True) as warns:
@@ -354,8 +349,6 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_createDataFrame_fallback_disabled(self):
         from distutils.version import LooseVersion
-        import pandas as pd
-        import pyarrow as pa
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
@@ -371,7 +364,6 @@ class ArrowTests(ReusedSQLTestCase):
 
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
-        import pandas as pd
         # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
         dt = [datetime.datetime(2015, 11, 1, 0, 30),
               datetime.datetime(2015, 11, 1, 1, 30),
@@ -381,8 +373,8 @@ class ArrowTests(ReusedSQLTestCase):
         df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
         df_from_pandas = self.spark.createDataFrame(pdf)
 
-        self.assertPandasEqual(pdf, df_from_python.toPandas())
-        self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+        assert_frame_equal(pdf, df_from_python.toPandas())
+        assert_frame_equal(pdf, df_from_pandas.toPandas())
 
     def test_toPandas_batch_order(self):
 
@@ -398,7 +390,7 @@ class ArrowTests(ReusedSQLTestCase):
                 df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                 pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf, pdf_arrow)
+                assert_frame_equal(pdf, pdf_arrow)
 
         cases = [
             (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 65edf59..eb34bbb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -581,14 +581,15 @@ class DataFrameTests(ReusedSQLTestCase):
 
     # Regression test for SPARK-23360
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
-    def test_create_dateframe_from_pandas_with_dst(self):
+    def test_create_dataframe_from_pandas_with_dst(self):
         import pandas as pd
+        from pandas.util.testing import assert_frame_equal
         from datetime import datetime
 
         pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
 
         df = self.spark.createDataFrame(pdf)
-        self.assertPandasEqual(pdf, df.toPandas())
+        assert_frame_equal(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
         try:
@@ -597,7 +598,7 @@ class DataFrameTests(ReusedSQLTestCase):
             time.tzset()
             with self.sql_conf({'spark.sql.session.timeZone': tz}):
                 df = self.spark.createDataFrame(pdf)
-                self.assertPandasEqual(pdf, df.toPandas())
+                assert_frame_equal(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 18264ea..9eda1aa 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -26,6 +26,10 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
+if have_pandas:
+    import pandas as pd
+    from pandas.util.testing import assert_frame_equal
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -50,8 +54,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
     @property
     def pandas_scalar_plus_two(self):
-        import pandas as pd
-
         @pandas_udf('double', PandasUDFType.SCALAR)
         def plus_two(v):
             assert isinstance(v, pd.Series)
@@ -107,7 +109,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
              [9, 335.0, 33.5, [33.5]]],
             ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_basic(self):
         df = self.data
@@ -116,19 +118,19 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         # Groupby one column and aggregate one UDF with literal
         result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
         expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
         # Groupby one expression and aggregate one UDF with literal
         result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
             .sort(df.id + 1)
         expected2 = df.groupby((col('id') + 1))\
             .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
 
         # Groupby one column and aggregate one UDF without literal
         result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
         expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
 
         # Groupby one expression and aggregate one UDF without literal
         result4 = df.groupby((col('id') + 1).alias('id'))\
@@ -137,7 +139,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         expected4 = df.groupby((col('id') + 1).alias('id'))\
             .agg(mean(df.v).alias('weighted_mean(v, w)'))\
             .sort('id')
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
 
     def test_unsupported_types(self):
         with QuietTest(self.sc):
@@ -166,7 +168,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
         expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_mixed_sql(self):
         """
@@ -200,9 +202,9 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                      .agg(sum(df.v + 1) + 2)
                      .sort('id'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
 
     def test_mixed_udfs(self):
         """
@@ -262,12 +264,12 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                      .agg(plus_two(sum(plus_two(df.v))))
                      .sort('plus_two(id)'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
-        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
-        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected5.toPandas(), result5.toPandas())
+        assert_frame_equal(expected6.toPandas(), result6.toPandas())
 
     def test_multiple_udfs(self):
         """
@@ -291,7 +293,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                      .sort('id')
                      .toPandas())
 
-        self.assertPandasEqual(expected1, result1)
+        assert_frame_equal(expected1, result1)
 
     def test_complex_groupby(self):
         df = self.data
@@ -327,13 +329,13 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
         expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
-        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
-        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
-        self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected5.toPandas(), result5.toPandas())
+        assert_frame_equal(expected6.toPandas(), result6.toPandas())
+        assert_frame_equal(expected7.toPandas(), result7.toPandas())
 
     def test_complex_expressions(self):
         df = self.data
@@ -404,9 +406,9 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                      .sort('id')
                      .toPandas())
 
-        self.assertPandasEqual(expected1, result1)
-        self.assertPandasEqual(expected2, result2)
-        self.assertPandasEqual(expected3, result3)
+        assert_frame_equal(expected1, result1)
+        assert_frame_equal(expected2, result2)
+        assert_frame_equal(expected3, result3)
 
     def test_retain_group_columns(self):
         with self.sql_conf({"spark.sql.retainGroupColumns": False}):
@@ -415,7 +417,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
             result1 = df.groupby(df.id).agg(sum_udf(df.v))
             expected1 = df.groupby(df.id).agg(sum(df.v))
-            self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+            assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_array_type(self):
         df = self.data
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index f7684d3..c8bad99 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -17,6 +17,7 @@
 
 import datetime
 import unittest
+import sys
 
 from collections import OrderedDict
 from decimal import Decimal
@@ -29,6 +30,23 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
+if have_pandas:
+    import pandas as pd
+    from pandas.util.testing import assert_frame_equal
+
+if have_pyarrow:
+    import pyarrow as pa
+
+
+"""
+Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+"""
+if sys.version < '3':
+    _check_column_type = False
+else:
+    _check_column_type = True
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -42,7 +60,6 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
             .withColumn("v", explode(col('vs'))).drop('vs')
 
     def test_supported_types(self):
-        import pyarrow as pa
 
         values = [
             1, 2, 3,
@@ -127,9 +144,9 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
         expected3 = expected1
 
-        self.assertPandasEqual(expected1, result1)
-        self.assertPandasEqual(expected2, result2)
-        self.assertPandasEqual(expected3, result3)
+        assert_frame_equal(expected1, result1, check_column_type=_check_column_type)
+        assert_frame_equal(expected2, result2, check_column_type=_check_column_type)
+        assert_frame_equal(expected3, result3, check_column_type=_check_column_type)
 
     def test_array_type_correct(self):
         df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
@@ -147,7 +164,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_register_grouped_map_udf(self):
         foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
@@ -169,7 +186,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_coerce(self):
         df = self.data
@@ -183,7 +200,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
         expected = expected.assign(v=expected.v.astype('float64'))
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_complex_groupby(self):
         df = self.data
@@ -201,7 +218,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         expected = pdf.groupby(pdf['id'] % 2 == 0, as_index=False).apply(normalize.func)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_empty_groupby(self):
         df = self.data
@@ -219,7 +236,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         expected = normalize.func(pdf)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_datatype_string(self):
         df = self.data
@@ -232,7 +249,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
     def test_wrong_return_type(self):
         with QuietTest(self.sc):
@@ -266,8 +283,6 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
                     pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
 
     def test_unsupported_types(self):
-        import pyarrow as pa
-
         common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
         unsupported_types = [
             StructField('map', MapType(StringType(), IntegerType())),
@@ -295,7 +310,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
         foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
         result = df.groupby('time').apply(foo_udf).sort('time')
-        self.assertPandasEqual(df.toPandas(), result.toPandas())
+        assert_frame_equal(df.toPandas(), result.toPandas(), check_column_type=_check_column_type)
 
     def test_udf_with_key(self):
         import numpy as np
@@ -349,29 +364,28 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         expected1 = pdf.groupby('id', as_index=False)\
             .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
             .sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected1, result1)
+        assert_frame_equal(expected1, result1, check_column_type=_check_column_type)
 
         # Test groupby expression
         result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
         expected2 = pdf.groupby(pdf.id % 2, as_index=False)\
             .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
             .sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected2, result2)
+        assert_frame_equal(expected2, result2, check_column_type=_check_column_type)
 
         # Test complex groupby
         result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
         expected3 = pdf.groupby([pdf.id, pdf.v % 2], as_index=False)\
             .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
             .sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected3, result3)
+        assert_frame_equal(expected3, result3, check_column_type=_check_column_type)
 
         # Test empty groupby
         result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
         expected4 = udf3.func((), pdf)
-        self.assertPandasEqual(expected4, result4)
+        assert_frame_equal(expected4, result4, check_column_type=_check_column_type)
 
     def test_column_order(self):
-        import pandas as pd
 
         # Helper function to set column names from a list
         def rename_pdf(pdf, names):
@@ -402,7 +416,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
             .select('id', 'u', 'v').toPandas()
         pd_result = grouped_pdf.apply(change_col_order)
         expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
         # Function returns a pdf with positional columns, indexed by range
         def range_col_order(pdf):
@@ -421,7 +435,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         pd_result = grouped_pdf.apply(range_col_order)
         rename_pdf(pd_result, ['id', 'u', 'v'])
         expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
         # Function returns a pdf with columns indexed with integers
         def int_index(pdf):
@@ -439,7 +453,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         pd_result = grouped_pdf.apply(int_index)
         rename_pdf(pd_result, ['id', 'u', 'v'])
         expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        assert_frame_equal(expected, result, check_column_type=_check_column_type)
 
         @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
         def column_name_typo(pdf):
@@ -452,7 +466,6 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
                 grouped_df.apply(column_name_typo).collect()
-            import pyarrow as pa
             if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
                 # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
                 with self.assertRaisesRegexp(Exception, "No cast implemented"):
@@ -462,8 +475,6 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
                     grouped_df.apply(invalid_positional_types).collect()
 
     def test_positional_assignment_conf(self):
-        import pandas as pd
-
         with self.sql_conf({
                 "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):
 
@@ -492,8 +503,6 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(res.count(), 5)
 
     def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
-        import pandas as pd
-
         df = self.spark.range(0, 10).toDF('v1')
         df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
             .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 7df918b..ebba074 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -41,6 +41,12 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\
     pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
+if have_pandas:
+    import pandas as pd
+
+if have_pyarrow:
+    import pyarrow as pa
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -70,7 +76,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
     @property
     def nondeterministic_vectorized_udf(self):
-        import pandas as pd
         import numpy as np
 
         @pandas_udf('double')
@@ -205,7 +210,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_string_in_udf(self):
-        import pandas as pd
         df = self.spark.range(10)
         str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
         actual = df.select(str_f(col('id')))
@@ -236,8 +240,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_null_binary(self):
-        import pyarrow as pa
-
         if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
             with QuietTest(self.sc):
                 with self.assertRaisesRegexp(
@@ -269,9 +271,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(df.collect(), result.collect())
 
     def test_vectorized_udf_struct_type(self):
-        import pandas as pd
-        import pyarrow as pa
-
         df = self.spark.range(10)
         return_type = StructType([
             StructField('id', LongType()),
@@ -305,8 +304,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertEqual(expected, actual.collect())
 
     def test_vectorized_udf_struct_complex(self):
-        import pandas as pd
-
         df = self.spark.range(10)
         return_type = StructType([
             StructField('ts', TimestampType()),
@@ -359,8 +356,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                 df.select(raise_exception(col('id'))).collect()
 
     def test_vectorized_udf_invalid_length(self):
-        import pandas as pd
-
         df = self.spark.range(10)
         raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
         with QuietTest(self.sc):
@@ -377,8 +372,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_chained_struct_type(self):
-        import pandas as pd
-
         df = self.spark.range(10)
         return_type = StructType([
             StructField('id', LongType()),
@@ -470,7 +463,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
         @pandas_udf(returnType=StringType())
         def check_data(idx, date, date_copy):
-            import pandas as pd
             msgs = []
             is_equal = date.isnull()
             for i in range(len(idx)):
@@ -509,7 +501,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
         @pandas_udf(returnType=StringType())
         def check_data(idx, timestamp, timestamp_copy):
-            import pandas as pd
             msgs = []
             is_equal = timestamp.isnull()  # use this array to check values are equal
             for i in range(len(idx)):
@@ -533,8 +524,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertIsNone(result[i][3])  # "check_data" col
 
     def test_vectorized_udf_return_timestamp_tz(self):
-        import pandas as pd
-
         df = self.spark.range(10)
 
         @pandas_udf(returnType=TimestampType())
@@ -551,8 +540,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertEquals(expected, ts)
 
     def test_vectorized_udf_check_config(self):
-        import pandas as pd
-
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)
 
@@ -565,8 +552,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                 self.assertTrue(r <= 3)
 
     def test_vectorized_udf_timestamps_respect_session_timezone(self):
-        import pandas as pd
-
         schema = StructType([
             StructField("idx", LongType(), True),
             StructField("timestamp", TimestampType(), True)])
@@ -653,7 +638,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
     @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
     def test_type_annotation(self):
-        from pyspark.sql.functions import pandas_udf
         # Regression test to check if type hints can be used. See SPARK-23569.
         # Note that it throws an error during compilation in lower Python versions if 'exec'
         # is not used. Also, note that we explicitly use another dictionary to avoid modifications
@@ -670,8 +654,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEqual(df.first()[0], 0)
 
     def test_mixed_udf(self):
-        import pandas as pd
-
         df = self.spark.range(0, 1).toDF('v')
 
         # Test mixture of multiple UDFs and Pandas UDFs.
@@ -772,8 +754,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         self.assertEquals(expected.collect(), df_multi_2.collect())
 
     def test_mixed_udf_and_sql(self):
-        import pandas as pd
-
         df = self.spark.range(0, 1).toDF('v')
 
         # Test mixture of UDFs, Pandas UDFs and SQL expression.
@@ -831,7 +811,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
     def test_datasource_with_udf(self):
         # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
         # This needs to a separate test because Arrow dependency is optional
-        import pandas as pd
         import numpy as np
 
         path = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py
index 3ba98e7..7d6540d 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_window.py
@@ -25,6 +25,9 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
+if have_pandas:
+    from pandas.util.testing import assert_frame_equal
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -48,8 +51,6 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
 
     @property
     def pandas_agg_count_udf(self):
-        from pyspark.sql.functions import pandas_udf, PandasUDFType
-
         @pandas_udf('long', PandasUDFType.GROUPED_AGG)
         def count(v):
             return len(v)
@@ -127,8 +128,8 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         result2 = df.select(mean_udf(df['v']).over(w))
         expected2 = df.select(mean(df['v']).over(w))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
 
     def test_multiple_udfs(self):
         df = self.data
@@ -142,7 +143,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
             .withColumn('max_v', max(df['v']).over(w)) \
             .withColumn('min_w', min(df['w']).over(w))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_replace_existing(self):
         df = self.data
@@ -151,7 +152,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
         expected1 = df.withColumn('v', mean(df['v']).over(w))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_mixed_sql(self):
         df = self.data
@@ -161,7 +162,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
         expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_mixed_udf(self):
         df = self.data
@@ -185,8 +186,8 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
             'v2',
             time_two(mean(time_two(df['v'])).over(w)))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
 
     def test_without_partitionBy(self):
         df = self.data
@@ -199,8 +200,8 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         result2 = df.select(mean_udf(df['v']).over(w))
         expected2 = df.select(mean(df['v']).over(w))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
 
     def test_mixed_sql_and_udf(self):
         df = self.data
@@ -229,10 +230,10 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
             .withColumn('rank', rank().over(ow))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
 
     def test_array_type(self):
         df = self.data
@@ -276,7 +277,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
             .withColumn('max_v', max(df['v']).over(w2)) \
             .withColumn('min_v', min(df['v']).over(w1))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_growing_window(self):
         from pyspark.sql.functions import mean
@@ -293,7 +294,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
             .withColumn('m2', mean(df['v']).over(w2))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_sliding_window(self):
         from pyspark.sql.functions import mean
@@ -310,7 +311,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
             .withColumn('m2', mean(df['v']).over(w2))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_shrinking_window(self):
         from pyspark.sql.functions import mean
@@ -327,7 +328,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
             .withColumn('m2', mean(df['v']).over(w2))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_bounded_mixed(self):
         from pyspark.sql.functions import mean, max
@@ -347,7 +348,7 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
             .withColumn('max_v', max(df['v']).over(w2)) \
             .withColumn('mean_unbounded_v', mean(df['v']).over(w1))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index afc40cc..13800cf 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -260,9 +260,3 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
         super(ReusedSQLTestCase, cls).tearDownClass()
         cls.spark.stop()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-
-    def assertPandasEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
-               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
-        self.assertTrue(expected.equals(result), msg=msg)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org