Posted to commits@spark.apache.org by ho...@apache.org on 2017/02/13 18:37:37 UTC

spark git commit: [SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF

Repository: spark
Updated Branches:
  refs/heads/master 5e7cd3322 -> ab88b2410


[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF

## What changes were proposed in this pull request?

Add support for passing a data type string as the `returnType` argument of `UserDefinedFunction`:

```python
from pyspark.sql.functions import udf

f = udf(lambda x: x, "integer")
f.returnType
## IntegerType
```
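The same DDL-style strings cover complex types as well. A short sketch of the cases exercised by the new tests (the variable names are taken from the test added below; assumes a working PySpark environment):

```python
from pyspark.sql.functions import udf

# Simple, struct, and array return types expressed as data type strings:
add_one = udf(lambda x: x + 1, "integer")
make_pair = udf(lambda x: (-x, x), "struct<x:integer,y:integer>")
make_array = udf(lambda x: [float(i) for i in range(x, x + 3)], "array<double>")
```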

## How was this patch tested?

Existing unit tests, plus additional unit tests covering the new feature.

Author: zero323 <ze...@users.noreply.github.com>

Closes #16769 from zero323/SPARK-19427.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab88b241
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab88b241
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab88b241

Branch: refs/heads/master
Commit: ab88b2410623e5fdb06d558017bd6d50220e466a
Parents: 5e7cd33
Author: zero323 <ze...@users.noreply.github.com>
Authored: Mon Feb 13 10:37:34 2017 -0800
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Mon Feb 13 10:37:34 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py |  8 +++++---
 python/pyspark/sql/tests.py     | 15 +++++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab88b241/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 40727ab..5213a3c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ if sys.version < "3":
 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.dataframe import DataFrame
 
@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object):
     """
     def __init__(self, func, returnType, name=None):
         self.func = func
-        self.returnType = returnType
+        self.returnType = (
+            returnType if isinstance(returnType, DataType)
+            else _parse_datatype_string(returnType))
         # Stores UserDefinedPythonFunctions jobj, once initialized
         self._judf_placeholder = None
         self._name = name or (
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()):
         it is present in the query.
 
     :param f: python function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object
+    :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
 
     >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())

http://git-wip-us.apache.org/repos/asf/spark/blob/ab88b241/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 710585c..ab9d3f6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -489,6 +489,21 @@ class SQLTests(ReusedPySparkTestCase):
             "judf should be initialized after UDF has been called."
         )
 
+    def test_udf_with_string_return_type(self):
+        from pyspark.sql.functions import UserDefinedFunction
+
+        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
+        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
+        make_array = UserDefinedFunction(
+            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
+
+        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
+        actual = (self.spark.range(1, 2).toDF("x")
+                  .select(add_one("x"), make_pair("x"), make_array("x"))
+                  .first())
+
+        self.assertTupleEqual(expected, actual)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
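
As a quick interactive check of what the new test asserts (a usage sketch assuming an active `SparkSession` bound to `spark`, as in the test fixture):

```python
from pyspark.sql.functions import udf

add_one = udf(lambda x: x + 1, "integer")

# Mirrors the first assertion of test_udf_with_string_return_type: 1 + 1 == 2.
row = spark.range(1, 2).toDF("x").select(add_one("x")).first()
assert row[0] == 2
```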


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org