You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by cu...@apache.org on 2020/05/28 00:28:26 UTC

[spark] branch master updated: [SPARK-25351][SQL][PYTHON] Handle Pandas category type when converting from Python with Arrow

This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 339b0eca [SPARK-25351][SQL][PYTHON] Handle Pandas category type when converting from Python with Arrow
339b0eca is described below

commit 339b0ecadb9c66ec8a62fd1f8e5a7a266b465aef
Author: Jalpan Randeri <ra...@amazon.com>
AuthorDate: Wed May 27 17:27:29 2020 -0700

    [SPARK-25351][SQL][PYTHON] Handle Pandas category type when converting from Python with Arrow
    
    Handle the pandas category type when converting from Python with Arrow enabled. A category column is converted to the type of its category elements, which matches the behavior when Arrow is disabled.
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    New unit tests were added for `createDataFrame` and scalar `pandas_udf`
    
    Closes #26585 from jalpan-randeri/feature-pyarrow-dictionary-type.
    
    Authored-by: Jalpan Randeri <ra...@amazon.com>
    Signed-off-by: Bryan Cutler <cu...@gmail.com>
---
 python/pyspark/sql/pandas/serializers.py           |  3 +++
 python/pyspark/sql/pandas/types.py                 |  2 ++
 python/pyspark/sql/tests/test_arrow.py             | 26 ++++++++++++++++++++++
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 21 +++++++++++++++++
 4 files changed, 52 insertions(+)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 4dd15d1..ff0b10a 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -154,6 +154,9 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             # Ensure timestamp series are in expected form for Spark internal representation
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
+            elif type(s.dtype) == pd.CategoricalDtype:
+                # Note: This can be removed once minimum pyarrow version is >= 0.16.1
+                s = s.astype(s.dtypes.categories.dtype)
             try:
                 array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
             except pa.ArrowException as e:
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index d1edf3f..4b70c8a 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -114,6 +114,8 @@ def from_arrow_type(at):
         return StructType(
             [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
              for field in at])
+    elif types.is_dictionary(at):
+        spark_type = from_arrow_type(at.value_type)
     else:
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
     return spark_type
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 004c79f..c3c9fb0 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -415,6 +415,32 @@ class ArrowTests(ReusedSQLTestCase):
         for case in cases:
             run_test(*case)
 
+    def test_createDateFrame_with_category_type(self):
+        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+        pdf["B"] = pdf["A"].astype('category')
+        category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
+            arrow_df = self.spark.createDataFrame(pdf)
+            arrow_type = arrow_df.dtypes[1][1]
+            result_arrow = arrow_df.toPandas()
+            arrow_first_category_element = result_arrow["B"][0]
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+            df = self.spark.createDataFrame(pdf)
+            spark_type = df.dtypes[1][1]
+            result_spark = df.toPandas()
+            spark_first_category_element = result_spark["B"][0]
+
+        assert_frame_equal(result_spark, result_arrow)
+
+        # ensure original category elements are string
+        assert isinstance(category_first_element, str)
+        # the column type must be 'string' both with and without Arrow execution enabled
+        assert spark_type == arrow_type == 'string'
+        assert isinstance(arrow_first_category_element, str)
+        assert isinstance(spark_first_category_element, str)
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 7260e80..ae6b8d5 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -897,6 +897,27 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             result = df.withColumn('time', foo_udf(df.time))
             self.assertEquals(df.collect(), result.collect())
 
+    def test_udf_category_type(self):
+
+        @pandas_udf('string')
+        def to_category_func(x):
+            return x.astype('category')
+
+        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+        df = self.spark.createDataFrame(pdf)
+        df = df.withColumn("B", to_category_func(df['A']))
+        result_spark = df.toPandas()
+
+        spark_type = df.dtypes[1][1]
+        # the column returned by the pandas UDF must have Spark type 'string'
+        assert spark_type == 'string'
+
+        # Check that every value in column 'B' equals the corresponding value in column 'A'
+        for i in range(0, len(result_spark["A"])):
+            assert result_spark["A"][i] == result_spark["B"][i]
+            assert isinstance(result_spark["A"][i], str)
+            assert isinstance(result_spark["B"][i], str)
+
     @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
     def test_type_annotation(self):
         # Regression test to check if type hints can be used. See SPARK-23569.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org