You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/19 23:06:43 UTC
[spark] branch master updated: [SPARK-26979][PYTHON][FOLLOW-UP]
Make binary math/string functions take string as columns as well
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c99463d [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well
c99463d is described below
commit c99463d4cfd5c70a28fdf89414207955f60c4789
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Wed Mar 20 08:06:10 2019 +0900
[SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well
## What changes were proposed in this pull request?
This is a follow-up to https://github.com/apache/spark/pull/23882 that makes binary math functions and string functions accept strings as column names. For instance, see the cases below:
**Before:**
```python
>>> from pyspark.sql.functions import lit, ascii
>>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/.../spark/python/pyspark/sql/functions.py", line 51, in _
jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1286, in __call__
File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/protocol.py", line 332, in get_return_value
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.ascii. Trace:
py4j.Py4JException: Method ascii([class java.lang.String]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
at py4j.Gateway.invoke(Gateway.java:276)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
```
```python
>>> from pyspark.sql.functions import atan2
>>> spark.range(1).select(atan2("id", "id"))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/.../spark/python/pyspark/sql/functions.py", line 78, in _
jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
ValueError: could not convert string to float: id
```
**After:**
```python
>>> from pyspark.sql.functions import lit, ascii
>>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
DataFrame[ascii(value): int]
```
```python
>>> from pyspark.sql.functions import atan2
>>> spark.range(1).select(atan2("id", "id"))
DataFrame[ATAN2(id, id): double]
```
Note that,
- This PR causes a slight behaviour change for math functions. For instance, numbers given as strings (e.g., `"1"`) were previously accepted as arguments of binary math functions and converted to floats. After this PR, such strings are interpreted as column names.
- I also intentionally didn't document this behaviour change, since we're moving toward Spark 3.0 and I don't think numbers as strings make much sense in math functions.
- There is another exception, `when`, which treats strings as literal values, as shown below. This PR doesn't fix this ambiguity.
```python
>>> spark.range(1).select(when(lit(True), col("id"))).show()
```
```
+--------------------------+
|CASE WHEN true THEN id END|
+--------------------------+
| 0|
+--------------------------+
```
```python
>>> spark.range(1).select(when(lit(True), "id")).show()
```
```
+--------------------------+
|CASE WHEN true THEN id END|
+--------------------------+
| id|
+--------------------------+
```
This PR also changes the internal helper naming, as below:
https://github.com/apache/spark/pull/23882 changed it to:
- Rename `_create_function` to `_create_name_function`
- Define new `_create_function` to take strings as column names.
In this PR, I propose to:
- Revert `_create_name_function` name to `_create_function`.
- Define new `_create_function_over_column` to take strings as column names.
## How was this patch tested?
Some unit tests were added for binary math / string functions.
Closes #24121 from HyukjinKwon/SPARK-26979.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/functions.py | 79 +++++++++++++++++++-----------
python/pyspark/sql/tests/test_functions.py | 14 +++++-
2 files changed, 64 insertions(+), 29 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3ee485c..0326613 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -30,15 +30,22 @@ if sys.version >= '3':
from pyspark import since, SparkContext
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
+ _create_column_from_name
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf
+# Note to developers: all of PySpark functions here take string as column names whenever possible.
+# Namely, if columns are referred as arguments, they can be always both Column or string,
+# even though there might be few exceptions for legacy or inevitable reasons.
+# If you are fixing other language APIs together, also please note that Scala side is not the case
+# since it requires to make every single overridden definition.
-def _create_name_function(name, doc=""):
- """ Create a function that takes a column name argument, by name"""
+
+def _create_function(name, doc=""):
+ """Create a PySpark function by its name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
@@ -48,8 +55,11 @@ def _create_name_function(name, doc=""):
return _
-def _create_function(name, doc=""):
- """ Create a function that takes a Column object, by name"""
+def _create_function_over_column(name, doc=""):
+ """Similar with `_create_function` but creates a PySpark function that takes a column
+ (as string as well). This is mainly for PySpark functions to take strings as
+ column names.
+ """
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
@@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""):
""" Create a binary mathfunction by name"""
def _(col1, col2):
sc = SparkContext._active_spark_context
- # users might write ints for simplicity. This would throw an error on the JVM side.
- jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
- col2._jc if isinstance(col2, Column) else float(col2))
+ # For legacy reasons, the arguments here can be implicitly converted into floats,
+ # if they are not columns or strings.
+ if isinstance(col1, Column):
+ arg1 = col1._jc
+ elif isinstance(col1, basestring):
+ arg1 = _create_column_from_name(col1)
+ else:
+ arg1 = float(col1)
+
+ if isinstance(col2, Column):
+ arg2 = col2._jc
+ elif isinstance(col2, basestring):
+ arg2 = _create_column_from_name(col2)
+ else:
+ arg2 = float(col2)
+
+ jc = getattr(sc._jvm.functions, name)(arg1, arg2)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
@@ -96,8 +120,7 @@ _lit_doc = """
>>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
[Row(height=5, spark_user=True)]
"""
-_name_functions = {
- # name functions take a column name as their argument
+_functions = {
'lit': _lit_doc,
'col': 'Returns a :class:`Column` based on the given column name.',
'column': 'Returns a :class:`Column` based on the given column name.',
@@ -105,9 +128,7 @@ _name_functions = {
'desc': 'Returns a sort expression based on the descending order of the given column name.',
}
-_functions = {
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to upper case.',
+_functions_over_column = {
'sqrt': 'Computes the square root of the specified float value.',
'abs': 'Computes the absolute value.',
@@ -120,7 +141,7 @@ _functions = {
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}
-_functions_1_4 = {
+_functions_1_4_over_column = {
# unary math functions
'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
@@ -155,7 +176,7 @@ _functions_1_4 = {
'bitwiseNOT': 'Computes bitwise not.',
}
-_name_functions_2_4 = {
+_functions_2_4 = {
'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
' column name, and null values return before non-null values.',
'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
@@ -186,7 +207,7 @@ _collect_set_doc = """
>>> df2.agg(collect_set('age')).collect()
[Row(collect_set(age)=[5, 2])]
"""
-_functions_1_6 = {
+_functions_1_6_over_column = {
# unary math functions
'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
' the expression in a group.',
@@ -203,7 +224,7 @@ _functions_1_6 = {
'collect_set': _collect_set_doc
}
-_functions_2_1 = {
+_functions_2_1_over_column = {
# unary math functions
'degrees': """
Converts an angle measured in radians to an approximately equivalent angle
@@ -268,24 +289,24 @@ _window_functions = {
_functions_deprecated = {
}
-for _name, _doc in _name_functions.items():
- globals()[_name] = since(1.3)(_create_name_function(_name, _doc))
for _name, _doc in _functions.items():
globals()[_name] = since(1.3)(_create_function(_name, _doc))
-for _name, _doc in _functions_1_4.items():
- globals()[_name] = since(1.4)(_create_function(_name, _doc))
+for _name, _doc in _functions_over_column.items():
+ globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_1_4_over_column.items():
+ globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
for _name, _doc in _binary_mathfunctions.items():
globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
for _name, _doc in _window_functions.items():
globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
-for _name, _doc in _functions_1_6.items():
- globals()[_name] = since(1.6)(_create_function(_name, _doc))
-for _name, _doc in _functions_2_1.items():
- globals()[_name] = since(2.1)(_create_function(_name, _doc))
+for _name, _doc in _functions_1_6_over_column.items():
+ globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_2_1_over_column.items():
+ globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
for _name, _message in _functions_deprecated.items():
globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
-for _name, _doc in _name_functions_2_4.items():
- globals()[_name] = since(2.4)(_create_name_function(_name, _doc))
+for _name, _doc in _functions_2_4.items():
+ globals()[_name] = since(2.4)(_create_function(_name, _doc))
del _name, _doc
@@ -1450,6 +1471,8 @@ def hash(*cols):
# ---------------------- String/Binary functions ------------------------------
_string_functions = {
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
'ascii': 'Computes the numeric value of the first character of the string column.',
'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
@@ -1460,7 +1483,7 @@ _string_functions = {
for _name, _doc in _string_functions.items():
- globals()[_name] = since(1.5)(_create_function(_name, _doc))
+ globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
del _name, _doc
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index fe66602..b777573 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -129,6 +129,12 @@ class FunctionsTests(ReusedSQLTestCase):
df.select(functions.pow(df.a, 2.0)).collect())
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
+ assert_close([math.hypot(i, 2 * i) for i in range(10)],
+ df.select(functions.hypot("a", u"b")).collect())
+ assert_close([math.hypot(i, 2) for i in range(10)],
+ df.select(functions.hypot("a", 2)).collect())
+ assert_close([math.hypot(i, 2) for i in range(10)],
+ df.select(functions.hypot(df.a, 2)).collect())
def test_rand_functions(self):
df = self.df
@@ -151,7 +157,8 @@ class FunctionsTests(ReusedSQLTestCase):
self.assertEqual(sorted(rndn1), sorted(rndn2))
def test_string_functions(self):
- from pyspark.sql.functions import col, lit
+ from pyspark.sql import functions
+ from pyspark.sql.functions import col, lit, _string_functions
df = self.spark.createDataFrame([['nick']], schema=['name'])
self.assertRaisesRegexp(
TypeError,
@@ -162,6 +169,11 @@ class FunctionsTests(ReusedSQLTestCase):
TypeError,
lambda: df.select(col('name').substr(long(0), long(1))))
+ for name in _string_functions.keys():
+ self.assertEqual(
+ df.select(getattr(functions, name)("name")).first()[0],
+ df.select(getattr(functions, name)(col("name"))).first()[0])
+
def test_array_contains_function(self):
from pyspark.sql.functions import array_contains
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org