Posted to commits@spark.apache.org by li...@apache.org on 2018/01/04 13:07:43 UTC
spark git commit: [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction
Repository: spark
Updated Branches:
refs/heads/master d5861aba9 -> 5aadbc929
[SPARK-22939][PYSPARK] Support Spark UDF in registerFunction
## What changes were proposed in this pull request?
Currently, passing a UDF created via `pyspark.sql.functions.udf` (such as a nondeterministic one) to `registerFunction` fails. For example:
```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType
random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
spark.catalog.registerFunction("random_udf", random_udf, StringType())
spark.sql("SELECT random_udf()").collect()
```
Running it produces the following error:
```
Py4JError: An error occurred while calling o29.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:745)
```
This PR adds support for registering a wrapped or native `UserDefinedFunction` with `registerFunction`, preserving its `deterministic` flag.
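After this change, the snippet above registers successfully. A usage sketch of the new behavior (mirroring the doctests added in this patch; assumes an active `spark` session):
```Python
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

# A nondeterministic UDF can now be passed directly to registerFunction.
random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
new_random_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())

spark.sql("SELECT random_udf()").collect()         # callable from SQL
spark.range(1).select(new_random_udf()).collect()  # the returned wrapped UDF works in the DataFrame API too
```
Note that the UDF is re-wrapped with the `returnType` passed to `registerFunction` (here `StringType`), so the SQL results come back as strings.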
## How was this patch tested?
WIP
Author: gatorsmile <ga...@gmail.com>
Closes #20137 from gatorsmile/registerFunction.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5aadbc92
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5aadbc92
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5aadbc92
Branch: refs/heads/master
Commit: 5aadbc929cb194e06dbd3bab054a161569289af5
Parents: d5861ab
Author: gatorsmile <ga...@gmail.com>
Authored: Thu Jan 4 21:07:31 2018 +0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Jan 4 21:07:31 2018 +0800
----------------------------------------------------------------------
python/pyspark/sql/catalog.py | 27 +++++++++++++++++----
python/pyspark/sql/context.py | 16 ++++++++++---
python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++++-----------
python/pyspark/sql/udf.py | 21 ++++++++++------
4 files changed, 84 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 659bc65..1566031 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -227,15 +227,15 @@ class Catalog(object):
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
+ """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+ as a UDF. The registered UDF can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:param name: name of the UDF
- :param f: python function
+ :param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
@@ -255,9 +255,26 @@ class Catalog(object):
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
+
+ >>> import random
+ >>> from pyspark.sql.functions import udf
+ >>> from pyspark.sql.types import IntegerType, StringType
+ >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+ >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType())
+ >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
+ [Row(random_udf()=u'82')]
+ >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+ [Row(random_udf()=u'62')]
"""
- udf = UserDefinedFunction(f, returnType=returnType, name=name,
- evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+ # This is to check whether the input function is a wrapped/native UserDefinedFunction
+ if hasattr(f, 'asNondeterministic'):
+ udf = UserDefinedFunction(f.func, returnType=returnType, name=name,
+ evalType=PythonEvalType.SQL_BATCHED_UDF,
+ deterministic=f.deterministic)
+ else:
+ udf = UserDefinedFunction(f, returnType=returnType, name=name,
+ evalType=PythonEvalType.SQL_BATCHED_UDF)
self._jsparkSession.udf().registerPython(name, udf._judf)
return udf._wrapped()
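A note on the dispatch above: rather than an isinstance check, the code duck-types on `asNondeterministic`, so both a native `UserDefinedFunction` and the wrapped function returned by `_wrapped()` are recognized; their underlying `func` and `deterministic` flag are carried into the re-created UDF. A minimal, self-contained sketch of the idea (`FakeUDF` and `register` are illustrative stand-ins, not PySpark classes):
```Python
class FakeUDF(object):
    """Mimics UserDefinedFunction: exposes .func, .deterministic, .asNondeterministic()."""
    def __init__(self, func, deterministic=True):
        self.func = func
        self.deterministic = deterministic

    def asNondeterministic(self):
        self.deterministic = False
        return self

def register(name, f):
    # Anything exposing asNondeterministic is treated as a (wrapped) UDF:
    # unwrap its underlying function and preserve its deterministic flag.
    if hasattr(f, 'asNondeterministic'):
        return (name, f.func, f.deterministic)
    # A plain Python callable defaults to a deterministic UDF.
    return (name, f, True)

print(register("plain", lambda: 1))                               # (..., True)
print(register("rand", FakeUDF(lambda: 2).asNondeterministic()))  # (..., False)
```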
http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index b1e723c..b8d86cc 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -175,15 +175,15 @@ class SQLContext(object):
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=StringType()):
- """Registers a python function (including lambda function) as a UDF
- so it can be used in SQL statements.
+ """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
+ as a UDF. The registered UDF can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
:param name: name of the UDF
- :param f: python function
+ :param f: a Python function, or a wrapped/native UserDefinedFunction
:param returnType: a :class:`pyspark.sql.types.DataType` object
:return: a wrapped :class:`UserDefinedFunction`
@@ -203,6 +203,16 @@ class SQLContext(object):
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
+
+ >>> import random
+ >>> from pyspark.sql.functions import udf
+ >>> from pyspark.sql.types import IntegerType, StringType
+ >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+ >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType())
+ >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
+ [Row(random_udf()=u'82')]
+ >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP
+ [Row(random_udf()=u'62')]
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)
http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 67bdb3d..6dc767f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -378,6 +378,41 @@ class SQLTests(ReusedSQLTestCase):
[res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
+ def test_udf3(self):
+ twoargs = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType())
+ self.assertEqual(twoargs.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+ def test_nondeterministic_udf(self):
+ from pyspark.sql.functions import udf
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+ self.assertEqual(udf_random_col.deterministic, False)
+ df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+ udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+ [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+ self.assertEqual(row[0] + 10, row[1])
+
+ def test_nondeterministic_udf2(self):
+ import random
+ from pyspark.sql.functions import udf
+ random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+ self.assertEqual(random_udf.deterministic, False)
+ random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType())
+ self.assertEqual(random_udf1.deterministic, False)
+ [row] = self.spark.sql("SELECT randInt()").collect()
+ self.assertEqual(row[0], "6")
+ [row] = self.spark.range(1).select(random_udf1()).collect()
+ self.assertEqual(row[0], "6")
+ [row] = self.spark.range(1).select(random_udf()).collect()
+ self.assertEqual(row[0], 6)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+ pydoc.render_doc(random_udf)
+ pydoc.render_doc(random_udf1)
+
def test_chained_udf(self):
self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.spark.sql("SELECT double(1)").collect()
@@ -435,15 +470,6 @@ class SQLTests(ReusedSQLTestCase):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)
- def test_nondeterministic_udf(self):
- from pyspark.sql.functions import udf
- import random
- udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
- df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
- udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
- [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
- self.assertEqual(row[0] + 10, row[1])
-
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
@@ -567,7 +593,6 @@ class SQLTests(ReusedSQLTestCase):
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
- from pyspark.sql.types import StringType
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
@@ -575,7 +600,6 @@ class SQLTests(ReusedSQLTestCase):
def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
- from pyspark.sql.types import StringType
def filename(path):
return path
@@ -635,7 +659,6 @@ class SQLTests(ReusedSQLTestCase):
def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction
- from pyspark.sql.types import StringType
non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
@@ -1299,7 +1322,6 @@ class SQLTests(ReusedSQLTestCase):
df.filter(df.a.between(df.b, df.c)).collect())
def test_struct_type(self):
- from pyspark.sql.types import StructType, StringType, StructField
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
@@ -1368,7 +1390,6 @@ class SQLTests(ReusedSQLTestCase):
_parse_datatype_string("a INT, c DOUBLE"))
def test_metadata_null(self):
- from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
StructField("f2", StringType(), True, {'a': None})])
rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
http://git-wip-us.apache.org/repos/asf/spark/blob/5aadbc92/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 54b5a865..5e75eb6 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType):
)
# Set the name of the UserDefinedFunction object to be the name of function f
- udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType)
+ udf_obj = UserDefinedFunction(
+ f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
return udf_obj._wrapped()
@@ -67,8 +68,10 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func,
- returnType=StringType(), name=None,
- evalType=PythonEvalType.SQL_BATCHED_UDF):
+ returnType=StringType(),
+ name=None,
+ evalType=PythonEvalType.SQL_BATCHED_UDF,
+ deterministic=True):
if not callable(func):
raise TypeError(
"Invalid function: not a function or callable (__call__ is not defined): "
@@ -92,7 +95,7 @@ class UserDefinedFunction(object):
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
- self._deterministic = True
+ self.deterministic = deterministic
@property
def returnType(self):
@@ -130,7 +133,7 @@ class UserDefinedFunction(object):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
- self._name, wrapped_func, jdt, self.evalType, self._deterministic)
+ self._name, wrapped_func, jdt, self.evalType, self.deterministic)
return judf
def __call__(self, *cols):
@@ -138,6 +141,9 @@ class UserDefinedFunction(object):
sc = SparkContext._active_spark_context
return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
+ # This function is for improving the online help system in the interactive interpreter.
+ # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
+ # argument annotation. (See: SPARK-19161)
def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
@@ -162,7 +168,8 @@ class UserDefinedFunction(object):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
- wrapper.asNondeterministic = self.asNondeterministic
+ wrapper.deterministic = self.deterministic
+ wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()
return wrapper
@@ -172,5 +179,5 @@ class UserDefinedFunction(object):
.. versionadded:: 2.3
"""
- self._deterministic = False
+ self.deterministic = False
return self
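One consequence of the `_wrapped()` changes above: `asNondeterministic` on a wrapped UDF now returns another wrapped UDF instead of the bare `UserDefinedFunction`, and the wrapper exposes the new `deterministic` attribute, so chaining and `help()`/pydoc keep working (see SPARK-19161). A short usage sketch (assumes a PySpark build containing this patch):
```Python
import pydoc
import random
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

plain = udf(lambda: random.randint(6, 6), IntegerType())
print(plain.deterministic)           # True: deterministic is now exposed on the wrapper

nondet = plain.asNondeterministic()  # returns another wrapped UDF, not a bare object
print(nondet.deterministic)          # False

pydoc.render_doc(nondet)             # no longer raises; the wrapper keeps the docstring
```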