You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/02/13 18:37:37 UTC
spark git commit: [SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
Repository: spark
Updated Branches:
refs/heads/master 5e7cd3322 -> ab88b2410
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request?
Add support for a data type string as the `returnType` argument of `UserDefinedFunction`:
```python
f = udf(lambda x: x, "integer")
f.returnType
## IntegerType
```
## How was this patch tested?
Existing unit tests, plus additional unit tests covering the new feature.
Author: zero323 <ze...@users.noreply.github.com>
Closes #16769 from zero323/SPARK-19427.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab88b241
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab88b241
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab88b241
Branch: refs/heads/master
Commit: ab88b2410623e5fdb06d558017bd6d50220e466a
Parents: 5e7cd33
Author: zero323 <ze...@users.noreply.github.com>
Authored: Mon Feb 13 10:37:34 2017 -0800
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Mon Feb 13 10:37:34 2017 -0800
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 8 +++++---
python/pyspark/sql/tests.py | 15 +++++++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ab88b241/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 40727ab..5213a3c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ if sys.version < "3":
from pyspark import since, SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType
+from pyspark.sql.types import StringType, DataType, _parse_datatype_string
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame
@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object):
"""
def __init__(self, func, returnType, name=None):
self.func = func
- self.returnType = returnType
+ self.returnType = (
+ returnType if isinstance(returnType, DataType)
+ else _parse_datatype_string(returnType))
# Stores UserDefinedPythonFunctions jobj, once initialized
self._judf_placeholder = None
self._name = name or (
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()):
it is present in the query.
:param f: python function
- :param returnType: a :class:`pyspark.sql.types.DataType` object
+ :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
http://git-wip-us.apache.org/repos/asf/spark/blob/ab88b241/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 710585c..ab9d3f6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -489,6 +489,21 @@ class SQLTests(ReusedPySparkTestCase):
"judf should be initialized after UDF has been called."
)
+ def test_udf_with_string_return_type(self):
+ from pyspark.sql.functions import UserDefinedFunction
+
+ add_one = UserDefinedFunction(lambda x: x + 1, "integer")
+ make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
+ make_array = UserDefinedFunction(
+ lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
+
+ expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
+ actual = (self.spark.range(1, 2).toDF("x")
+ .select(add_one("x"), make_pair("x"), make_array("x"))
+ .first())
+
+ self.assertTupleEqual(expected, actual)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org