You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/07/17 07:37:48 UTC
spark git commit: [SPARK-21394][SPARK-21432][PYTHON] Reviving
callable object/partial function support in UDF in PySpark
Repository: spark
Updated Branches:
refs/heads/master e398c2814 -> 4ce735eed
[SPARK-21394][SPARK-21432][PYTHON] Reviving callable object/partial function support in UDF in PySpark
## What changes were proposed in this pull request?
This PR proposes to exclude `__name__` (and `__module__`) from the tuple of attribute names that `functools.wraps` copies from the wrapped function to the wrapper, and instead assign them manually — `wrapper.__name__` from `self._name` (`func.__name__` or `obj.__class__.__name__`).
After SPARK-19161, we happened to break callable objects as UDFs in Python as below:
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
```
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/functions.py", line 2142, in udf
return _udf(f=f, returnType=returnType)
File ".../spark/python/pyspark/sql/functions.py", line 2133, in _udf
return udf_obj._wrapped()
File ".../spark/python/pyspark/sql/functions.py", line 2090, in _wrapped
functools.wraps(self.func)
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: F instance has no attribute '__name__'
```
This worked in Spark 2.1:
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```
```
+-----+
|F(id)|
+-----+
| 0|
+-----+
```
**After**
```python
from pyspark.sql import functions
class F(object):
def __call__(self, x):
return x
foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```
```
+-----+
|F(id)|
+-----+
| 0|
+-----+
```
_In addition, we also happened to break partial functions as below_:
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
```
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../spark/python/pyspark/sql/functions.py", line 2154, in udf
return _udf(f=f, returnType=returnType)
File ".../spark/python/pyspark/sql/functions.py", line 2145, in _udf
return udf_obj._wrapped()
File ".../spark/python/pyspark/sql/functions.py", line 2099, in _wrapped
functools.wraps(self.func, assigned=assignments)
File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: 'functools.partial' object has no attribute '__module__'
```
This worked in Spark 2.1:
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```
```
+---------+
|partial()|
+---------+
| 1|
+---------+
```
**After**
```python
from pyspark.sql import functions
from functools import partial
partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```
```
+---------+
|partial()|
+---------+
| 1|
+---------+
```
## How was this patch tested?
Unit tests in `python/pyspark/sql/tests.py` and manual tests.
Author: hyukjinkwon <gu...@gmail.com>
Closes #18615 from HyukjinKwon/callable-object.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ce735ee
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ce735ee
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ce735ee
Branch: refs/heads/master
Commit: 4ce735eed103f3bd055c087126acd1366c2537ec
Parents: e398c28
Author: hyukjinkwon <gu...@gmail.com>
Authored: Mon Jul 17 00:37:36 2017 -0700
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Mon Jul 17 00:37:36 2017 -0700
----------------------------------------------------------------------
python/pyspark/sql/functions.py | 14 +++++++++++++-
python/pyspark/sql/tests.py | 21 +++++++++++++++++++++
2 files changed, 34 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4ce735ee/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d45ff63..2c8c8e2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2087,10 +2087,22 @@ class UserDefinedFunction(object):
"""
Wrap this udf with a function and attach docstring from func
"""
- @functools.wraps(self.func)
+
+ # It is possible for a callable instance without __name__ attribute or/and
+ # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
+ # we should avoid wrapping the attributes from the wrapped function to the wrapper
+ # function. So, we take out these attribute names from the default names to set and
+ # then manually assign it after being wrapped.
+ assignments = tuple(
+ a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
+
+ @functools.wraps(self.func, assigned=assignments)
def wrapper(*args):
return self(*args)
+ wrapper.__name__ = self._name
+ wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
+ else self.func.__class__.__module__)
wrapper.func = self.func
wrapper.returnType = self.returnType
http://git-wip-us.apache.org/repos/asf/spark/blob/4ce735ee/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 29e48a6..be5495c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -679,6 +679,27 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)
+ class F(object):
+ """Identity"""
+ def __call__(self, x):
+ return x
+
+ f = F()
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
+ f = functools.partial(f, x=1)
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org