You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2018/01/04 13:07:43 UTC

spark git commit: [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction

Repository: spark
Updated Branches:
  refs/heads/master d5861aba9 -> 5aadbc929


[SPARK-22939][PYSPARK] Support Spark UDF in registerFunction

## What changes were proposed in this pull request?
```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType
random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
spark.catalog.registerFunction("random_udf", random_udf, StringType())
spark.sql("SELECT random_udf()").collect()
```

Registering an existing UDF object this way currently fails with the following error:
```
Py4JError: An error occurred while calling o29.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
	at py4j.Gateway.invoke(Gateway.java:274)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:745)
```

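This PR makes `registerFunction` accept a `UserDefinedFunction` (for example, one returned by `udf(...)`) in addition to a plain Python function, and preserves its deterministic/non-deterministic flag when it is registered.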

## How was this patch tested?
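Added unit tests (`test_udf3`, `test_nondeterministic_udf2`) in `python/pyspark/sql/tests.py` and doctest examples in the `registerFunction` docstrings.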

Author: gatorsmile <ga...@gmail.com>

Closes #20137 from gatorsmile/registerFunction.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5aadbc92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5aadbc92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5aadbc92

Branch: refs/heads/master
Commit: 5aadbc929cb194e06dbd3bab054a161569289af5
Parents: d5861ab
Author: gatorsmile <ga...@gmail.com>
Authored: Thu Jan 4 21:07:31 2018 +0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Jan 4 21:07:31 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/catalog.py | 27 +++++++++++++++++----
 python/pyspark/sql/context.py | 16 ++++++++++---
 python/pyspark/sql/tests.py   | 49 +++++++++++++++++++++++++++-----------
 python/pyspark/sql/udf.py     | 21 ++++++++++------
 4 files changed, 84 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 659bc65..1566031 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ class Catalog(object):
     @ignore_unicode_prefix
     @since(2.0)
     def registerFunction(self, name, f, returnType=StringType()):
-        """Registers a python function (including lambda function) as a UDF
-        so it can be used in SQL statements.
+        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+        as a UDF. The registered UDF can be used in SQL statement.
 
         In addition to a name and the function itself, the return type can be optionally specified.
         When the return type is not given it default to a string and conversion will automatically
         be done.  For any other return type, the produced object must match the specified type.
 
         :param name: name of the UDF
-        :param f: python function
+        :param f: a Python function, or a wrapped/native UserDefinedFunction
         :param returnType: a :class:`pyspark.sql.types.DataType` object
         :return: a wrapped :class:`UserDefinedFunction`
 
@@ -255,9 +255,26 @@ class Catalog(object):
         >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
+
+        >>> import random
+        >>> from pyspark.sql.functions import udf
+        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+        >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+        [Row(random_udf()=u'82')]
+        >>> spark.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
+        [Row(random_udf()=u'62')]
         """
-        udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                  evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+        # This is to check whether the input function is a wrapped/native UserDefinedFunction
+        if hasattr(f, 'asNondeterministic'):
+            udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF,
+                                      deterministic=f.deterministic)
+        else:
+            udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                      evalType=PythonEvalType.SQL_BATCHED_UDF)
         self._jsparkSession.udf().registerPython(name, udf._judf)
         return udf._wrapped()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index b1e723c..b8d86cc 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ class SQLContext(object):
     @ignore_unicode_prefix
     @since(1.2)
     def registerFunction(self, name, f, returnType=StringType()):
-        """Registers a python function (including lambda function) as a UDF
-        so it can be used in SQL statements.
+        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+        as a UDF. The registered UDF can be used in SQL statement.
 
         In addition to a name and the function itself, the return type can be optionally specified.
         When the return type is not given it default to a string and conversion will automatically
         be done.  For any other return type, the produced object must match the specified type.
 
         :param name: name of the UDF
-        :param f: python function
+        :param f: a Python function, or a wrapped/native UserDefinedFunction
         :param returnType: a :class:`pyspark.sql.types.DataType` object
         :return: a wrapped :class:`UserDefinedFunction`
 
@@ -203,6 +203,16 @@ class SQLContext(object):
         >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
+
+        >>> import random
+        >>> from pyspark.sql.functions import udf
+        >>> from pyspark.sql.types import IntegerType, StringType
+        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+        >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+        >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+        [Row(random_udf()=u'82')]
+        >>> sqlContext.range(1).select(newRandom_udf()).collect()  # doctest: +SKIP
+        [Row(random_udf()=u'62')]
         """
         return self.sparkSession.catalog.registerFunction(name, f, returnType)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 67bdb3d..6dc767f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -378,6 +378,41 @@ class SQLTests(ReusedSQLTestCase):
         [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
+    def test_udf3(self):
+        twoargs = self.spark.catalog.registerFunction(
+            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
+        self.assertEqual(twoargs.deterministic, True)
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+        self.assertEqual(row[0], 5)
+
+    def test_nondeterministic_udf(self):
+        from pyspark.sql.functions import udf
+        import random
+        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+        self.assertEqual(udf_random_col.deterministic, False)
+        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+        self.assertEqual(row[0] + 10, row[1])
+
+    def test_nondeterministic_udf2(self):
+        import random
+        from pyspark.sql.functions import udf
+        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+        self.assertEqual(random_udf.deterministic, False)
+        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+        self.assertEqual(random_udf1.deterministic, False)
+        [row] = self.spark.sql("SELECT randInt()").collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf1()).collect()
+        self.assertEqual(row[0], "6")
+        [row] = self.spark.range(1).select(random_udf()).collect()
+        self.assertEqual(row[0], 6)
+        # render_doc() reproduces the help() exception without printing output
+        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+        pydoc.render_doc(random_udf)
+        pydoc.render_doc(random_udf1)
+
     def test_chained_udf(self):
         self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
         [row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)
 
-    def test_nondeterministic_udf(self):
-        from pyspark.sql.functions import udf
-        import random
-        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
-        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
-        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
-        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
-        self.assertEqual(row[0] + 10, row[1])
-
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
@@ -567,7 +593,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_with_input_file_name(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType
         sourceFile = udf(lambda path: path, StringType())
         filePath = "python/test_support/sql/people1.json"
         row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
@@ -575,7 +600,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_with_input_file_name_for_hadooprdd(self):
         from pyspark.sql.functions import udf, input_file_name
-        from pyspark.sql.types import StringType
 
         def filename(path):
             return path
@@ -635,7 +659,6 @@ class SQLTests(ReusedSQLTestCase):
 
     def test_udf_shouldnt_accept_noncallable_object(self):
         from pyspark.sql.functions import UserDefinedFunction
-        from pyspark.sql.types import StringType
 
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ class SQLTests(ReusedSQLTestCase):
                          df.filter(df.a.between(df.b, df.c)).collect())
 
     def test_struct_type(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ class SQLTests(ReusedSQLTestCase):
             _parse_datatype_string("a INT, c DOUBLE"))
 
     def test_metadata_null(self):
-        from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
                              StructField("f2", StringType(), True, {'a': None})])
         rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])

http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 54b5a865..5e75eb6 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
             )
 
     # Set the name of the UserDefinedFunction object to be the name of function f
-    udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
     return udf_obj._wrapped()
 
 
@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
     .. versionadded:: 1.3
     """
     def __init__(self, func,
-                 returnType=StringType(), name=None,
-                 evalType=PythonEvalType.SQL_BATCHED_UDF):
+                 returnType=StringType(),
+                 name=None,
+                 evalType=PythonEvalType.SQL_BATCHED_UDF,
+                 deterministic=True):
         if not callable(func):
             raise TypeError(
                 "Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ class UserDefinedFunction(object):
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
         self.evalType = evalType
-        self._deterministic = True
+        self.deterministic = deterministic
 
     @property
     def returnType(self):
@@ -130,7 +133,7 @@ class UserDefinedFunction(object):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.evalType, self._deterministic)
+            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
         return judf
 
     def __call__(self, *cols):
@@ -138,6 +141,9 @@ class UserDefinedFunction(object):
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
 
+    # This function is for improving the online help system in the interactive interpreter.
+    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
+    # argument annotation. (See: SPARK-19161)
     def _wrapped(self):
         """
         Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ class UserDefinedFunction(object):
         wrapper.func = self.func
         wrapper.returnType = self.returnType
         wrapper.evalType = self.evalType
-        wrapper.asNondeterministic = self.asNondeterministic
+        wrapper.deterministic = self.deterministic
+        wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()
 
         return wrapper
 
@@ -172,5 +179,5 @@ class UserDefinedFunction(object):
 
         .. versionadded:: 2.3
         """
-        self._deterministic = False
+        self.deterministic = False
         return self


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org