Posted to commits@spark.apache.org by gu...@apache.org on 2018/02/06 09:30:56 UTC

spark git commit: [SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2.

Repository: spark
Updated Branches:
  refs/heads/master 8141c3e3d -> 63c5bf13c


[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2.

## What changes were proposed in this pull request?

In Python 2, when a `pandas_udf` returns string values created inside the UDF with a plain `".."` literal (that is, `str` rather than `unicode`), execution fails. For example,

```python
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

df = spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string")
df.select(str_f(col('id'))).show()
```

raises the following exception:

```
...

java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType
	at scala.Predef$.assert(Predef.scala:170)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:93)

...
```

It seems that pyarrow ignores the `type` parameter of `pa.Array.from_pandas()` and treats the data as binary type when the requested type is string and the values are `str` instead of `unicode` in Python 2.

This PR adds a workaround for this case.
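
The idea behind the workaround can be sketched in isolation: decode any byte-string values to text before handing the data to Arrow. The snippet below is only an illustration in plain Python, not the code from this patch; `decode_byte_strings` and `mixed` are hypothetical names, and it uses a list rather than a pandas Series so that it runs without pyarrow or pandas.

```python
def decode_byte_strings(values, encoding="utf-8"):
    """Return a list in which byte strings are decoded to text.

    Values that are not byte strings (text, None, numbers) pass through unchanged.
    """
    # `bytes` is an alias of `str` in Python 2, so this check corresponds to the
    # `isinstance(v, str)` test used there, and it also works in Python 3.
    return [v.decode(encoding) if isinstance(v, bytes) else v for v in values]


# Hypothetical input mixing byte strings, text, and a missing value.
mixed = [b"0", u"1", None, b"caf\xc3\xa9"]
decoded = decode_byte_strings(mixed)
```

After this step every string value is text, which is presumably what lets the Arrow conversion produce the requested string type instead of binary.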

## How was this patch tested?

Added a new test and ran the existing tests.

Author: Takuya UESHIN <ue...@databricks.com>

Closes #20507 from ueshin/issues/SPARK-23334.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63c5bf13
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63c5bf13
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63c5bf13

Branch: refs/heads/master
Commit: 63c5bf13ce5cd3b8d7e7fb88de881ed207fde720
Parents: 8141c3e
Author: Takuya UESHIN <ue...@databricks.com>
Authored: Tue Feb 6 18:30:50 2018 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Tue Feb 6 18:30:50 2018 +0900

----------------------------------------------------------------------
 python/pyspark/serializers.py | 4 ++++
 python/pyspark/sql/tests.py   | 9 +++++++++
 2 files changed, 13 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63c5bf13/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index e870325..91a7f09 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -230,6 +230,10 @@ def _create_batch(series, timezone):
             s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
             # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
             return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+        elif t is not None and pa.types.is_string(t) and sys.version < '3':
+            # TODO: need decode before converting to Arrow in Python 2
+            return pa.Array.from_pandas(s.apply(
+                lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)
 
     arrs = [create_array(s, t) for s, t in series]

http://git-wip-us.apache.org/repos/asf/spark/blob/63c5bf13/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 545ec5a..89b7c21 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3922,6 +3922,15 @@ class ScalarPandasUDF(ReusedSQLTestCase):
         res = df.select(str_f(col('str')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_string_in_udf(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+        str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
+        actual = df.select(str_f(col('id')))
+        expected = df.select(col('id').cast('string'))
+        self.assertEquals(expected.collect(), actual.collect())
+
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10).select(

