Posted to commits@spark.apache.org by ho...@apache.org on 2017/02/15 18:16:38 UTC

spark git commit: [SPARK-19160][PYTHON][SQL] Add udf decorator

Repository: spark
Updated Branches:
  refs/heads/master 6eca21ba8 -> c97f4e17d


[SPARK-19160][PYTHON][SQL] Add udf decorator

## What changes were proposed in this pull request?

This PR adds `udf` decorator syntax as proposed in [SPARK-19160](https://issues.apache.org/jira/browse/SPARK-19160).

This allows users to define a UDF using simplified syntax:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
```

without the need to define a function and a UDF separately.
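
Once decorated, the function is a ready-to-use UDF that accepts column names or `Column` expressions. A minimal end-to-end sketch (the `SparkSession` setup and the sample data are illustrative, not part of this patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

@udf(IntegerType())
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1

# The decorated name is now a UserDefinedFunction, usable in select().
df = spark.createDataFrame([(1,), (2,)], ["value"])
df.select(add_one("value").alias("plus_one")).show()
# +--------+
# |plus_one|
# +--------+
# |       2|
# |       3|
# +--------+
```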

## How was this patch tested?

Existing unit tests ensure backward compatibility, and additional unit tests cover the new functionality.
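
For example, the PySpark SQL test module can be run locally from a built checkout (exact runner flags may differ across Spark versions; this invocation is an assumption, not part of the patch):

```
./python/run-tests --modules=pyspark-sql
```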

Author: zero323 <ze...@users.noreply.github.com>

Closes #16533 from zero323/SPARK-19160.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c97f4e17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c97f4e17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c97f4e17

Branch: refs/heads/master
Commit: c97f4e17de0ce39e8172a5a4ae81f1914816a358
Parents: 6eca21b
Author: zero323 <ze...@users.noreply.github.com>
Authored: Wed Feb 15 10:16:34 2017 -0800
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Wed Feb 15 10:16:34 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py | 41 +++++++++++++++++++++-----
 python/pyspark/sql/tests.py     | 57 ++++++++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c97f4e17/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4f4ae10..d261720 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -20,6 +20,7 @@ A collections of builtin functions
 """
 import math
 import sys
+import functools
 
 if sys.version < "3":
     from itertools import imap as map
@@ -1908,22 +1909,48 @@ class UserDefinedFunction(object):
 
 
 @since(1.3)
-def udf(f, returnType=StringType()):
+def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).
 
     .. note:: The user-defined functions must be deterministic. Due to optimization,
         duplicate invocations may be eliminated or the function may even be invoked more times than
         it is present in the query.
 
-    :param f: python function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
+    :param f: python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.DataType` object
 
     >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
-    >>> df.select(slen(df.name).alias('slen')).collect()
-    [Row(slen=5), Row(slen=3)]
-    """
-    return UserDefinedFunction(f, returnType)
+    >>> @udf
+    ... def to_upper(s):
+    ...     if s is not None:
+    ...         return s.upper()
+    ...
+    >>> @udf(returnType=IntegerType())
+    ... def add_one(x):
+    ...     if x is not None:
+    ...         return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
+    """
+    def _udf(f, returnType=StringType()):
+        return UserDefinedFunction(f, returnType)
+
+    # decorator @udf, @udf() or @udf(dataType())
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(_udf, returnType=return_type)
+    else:
+        return _udf(f=f, returnType=returnType)
+
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
 __all__ = [k for k, v in globals().items()

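The dispatch in `udf` above is the standard optional-argument decorator pattern: applied bare (`@udf`), `f` is the decorated function itself; called first (`@udf()`, `@udf(IntegerType())` or `@udf("long")`), `f` is `None` or a return type, so a partially applied decorator is returned and the function arrives in a second call. A generic, self-contained sketch of the same pattern (the `tagged` decorator is a hypothetical illustration, not part of this commit):

```python
import functools

def tagged(f=None, label="default"):
    """Usable as @tagged, @tagged() or @tagged("name") / @tagged(label="name")."""
    def _tagged(f, label=label):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return (label, f(*args, **kwargs))
        return wrapper

    # Called with arguments (or with none): f is not the target function
    # yet, so return a decorator that still expects it.
    if f is None or isinstance(f, str):
        return functools.partial(_tagged, label=f or label)
    # Applied bare: f is the function itself.
    return _tagged(f)

@tagged("greeting")
def hello(name):
    return "hello, " + name

print(hello("world"))  # ('greeting', 'hello, world')
```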
http://git-wip-us.apache.org/repos/asf/spark/blob/c97f4e17/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 62e1a8c..d8b7b31 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -514,6 +514,63 @@ class SQLTests(ReusedPySparkTestCase):
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
+    def test_udf_with_decorator(self):
+        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.types import IntegerType, DoubleType
+
+        @udf(IntegerType())
+        def add_one(x):
+            if x is not None:
+                return x + 1
+
+        @udf(returnType=DoubleType())
+        def add_two(x):
+            if x is not None:
+                return float(x + 2)
+
+        @udf
+        def to_upper(x):
+            if x is not None:
+                return x.upper()
+
+        @udf()
+        def to_lower(x):
+            if x is not None:
+                return x.lower()
+
+        @udf
+        def substr(x, start, end):
+            if x is not None:
+                return x[start:end]
+
+        @udf("long")
+        def trunc(x):
+            return int(x)
+
+        @udf(returnType="double")
+        def as_double(x):
+            return float(x)
+
+        df = (
+            self.spark
+                .createDataFrame(
+                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+                .select(
+                    add_one("one"), add_two("one"),
+                    to_upper("Foo"), to_lower("Foo"),
+                    substr("foobar", lit(0), lit(3)),
+                    trunc("float"), as_double("one")))
+
+        self.assertListEqual(
+            [tpe for _, tpe in df.dtypes],
+            ["int", "double", "string", "string", "string", "bigint", "double"]
+        )
+
+        self.assertListEqual(
+            list(df.first()),
+            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+        )
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)

