Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/19 23:06:43 UTC

[spark] branch master updated: [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c99463d  [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well
c99463d is described below

commit c99463d4cfd5c70a28fdf89414207955f60c4789
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Wed Mar 20 08:06:10 2019 +0900

    [SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well
    
    ## What changes were proposed in this pull request?
    
    This is a follow-up of https://github.com/apache/spark/pull/23882 to handle binary math/string functions as well. For instance, see the cases below:
    
    **Before:**
    
    ```python
    >>> from pyspark.sql.functions import lit, ascii
    >>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/sql/functions.py", line 51, in _
        jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
      File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1286, in __call__
      File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
        return f(*a, **kw)
      File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/protocol.py", line 332, in get_return_value
    py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.ascii. Trace:
    py4j.Py4JException: Method ascii([class java.lang.String]) does not exist
    	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
    	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
    	at py4j.Gateway.invoke(Gateway.java:276)
    	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    	at py4j.commands.CallCommand.execute(CallCommand.java:79)
    	at py4j.GatewayConnection.run(GatewayConnection.java:238)
    	at java.lang.Thread.run(Thread.java:748)
    ```
    
    ```python
    >>> from pyspark.sql.functions import atan2
    >>> spark.range(1).select(atan2("id", "id"))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/sql/functions.py", line 78, in _
        jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
    ValueError: could not convert string to float: id
    ```
    
    **After:**
    
    ```python
    >>> from pyspark.sql.functions import lit, ascii
    >>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
    DataFrame[ascii(value): int]
    ```
    
    ```python
    >>> from pyspark.sql.functions import atan2
    >>> spark.range(1).select(atan2("id", "id"))
    DataFrame[ATAN2(id, id): double]
    ```
    
    Note that,
    
    - This PR causes a slight behaviour change for math functions. For instance, numbers as strings (e.g., `"1"`) were previously accepted as arguments to binary math functions; after this PR, they are recognised as column names (see the sketch after this list).
    
    - I also intentionally didn't document this behaviour change since we're going ahead with Spark 3.0 and I don't think numbers as strings make much sense as arguments to math functions.
    
    - There is another exception, `when`, which takes strings as literal values, as below. This PR doesn't fix this ambiguity.
      ```python
      >>> spark.range(1).select(when(lit(True), col("id"))).show()
      ```
    
      ```
      +--------------------------+
      |CASE WHEN true THEN id END|
      +--------------------------+
      |                         0|
      +--------------------------+
      ```
    
      ```python
      >>> spark.range(1).select(when(lit(True), "id")).show()
      ```
    
      ```
      +--------------------------+
      |CASE WHEN true THEN id END|
      +--------------------------+
      |                        id|
      +--------------------------+
      ```
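    
    As a minimal sketch of the behaviour change above (an illustration, not part of the original PR description):
    
    ```python
    >>> from pyspark.sql.functions import atan2
    >>> # Before this PR: a string such as "1" was coerced via float("1") and
    >>> # treated as the literal 1.0. After this PR: "1" resolves as a column
    >>> # name and fails unless a column literally named "1" exists.
    >>> # Plain numbers are still implicitly converted to floats:
    >>> spark.range(1).select(atan2(1.0, "id"))
    ```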
    
    This PR also reworks the helper functions as below.
    
    https://github.com/apache/spark/pull/23882 changed them to:
    
    - Rename `_create_function` to `_create_name_function`
    - Define a new `_create_function` that takes strings as column names.
    
    In this PR, I propose to:
    
    - Revert `_create_name_function` back to `_create_function`.
    - Define a new `_create_function_over_column` that takes strings as column names.
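    
    A minimal sketch of what the new `_create_function_over_column` enables, mirroring the added tests (the DataFrame below is illustrative):
    
    ```python
    >>> from pyspark.sql.functions import ascii, col, lit
    >>> df = spark.range(1).select(lit('a').alias("value"))
    >>> # A column name (string) and a Column object now resolve to the same result:
    >>> df.select(ascii("value")).first()[0] == df.select(ascii(col("value"))).first()[0]
    True
    ```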
    
    ## How was this patch tested?
    
    Some unit tests were added for binary math / string functions.
    
    Closes #24121 from HyukjinKwon/SPARK-26979.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/functions.py            | 79 +++++++++++++++++++-----------
 python/pyspark/sql/tests/test_functions.py | 14 +++++-
 2 files changed, 64 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3ee485c..0326613 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -30,15 +30,22 @@ if sys.version >= '3':
 
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
+    _create_column_from_name
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_udf
 
+# Note to developers: all of the PySpark functions here take strings as column names whenever
+# possible. Namely, wherever columns are referred to as arguments, they can always be either
+# Column or string, though there may be a few exceptions for legacy or unavoidable reasons.
+# If you are fixing other language APIs together, note that this is not the case on the Scala
+# side, since it would require adding every single overloaded definition.
 
-def _create_name_function(name, doc=""):
-    """ Create a function that takes a column name argument, by name"""
+
+def _create_function(name, doc=""):
+    """Create a PySpark function by its name"""
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
@@ -48,8 +55,11 @@ def _create_name_function(name, doc=""):
     return _
 
 
-def _create_function(name, doc=""):
-    """ Create a function that takes a Column object, by name"""
+def _create_function_over_column(name, doc=""):
+    """Similar with `_create_function` but creates a PySpark function that takes a column
+    (as string as well). This is mainly for PySpark functions to take strings as
+    column names.
+    """
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
@@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""):
     """ Create a binary mathfunction by name"""
     def _(col1, col2):
         sc = SparkContext._active_spark_context
-        # users might write ints for simplicity. This would throw an error on the JVM side.
-        jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
-                                              col2._jc if isinstance(col2, Column) else float(col2))
+        # For legacy reasons, the arguments here can be implicitly converted into floats,
+        # if they are not columns or strings.
+        if isinstance(col1, Column):
+            arg1 = col1._jc
+        elif isinstance(col1, basestring):
+            arg1 = _create_column_from_name(col1)
+        else:
+            arg1 = float(col1)
+
+        if isinstance(col2, Column):
+            arg2 = col2._jc
+        elif isinstance(col2, basestring):
+            arg2 = _create_column_from_name(col2)
+        else:
+            arg2 = float(col2)
+
+        jc = getattr(sc._jvm.functions, name)(arg1, arg2)
         return Column(jc)
     _.__name__ = name
     _.__doc__ = doc
@@ -96,8 +120,7 @@ _lit_doc = """
     >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
     [Row(height=5, spark_user=True)]
     """
-_name_functions = {
-    # name functions take a column name as their argument
+_functions = {
     'lit': _lit_doc,
     'col': 'Returns a :class:`Column` based on the given column name.',
     'column': 'Returns a :class:`Column` based on the given column name.',
@@ -105,9 +128,7 @@ _name_functions = {
     'desc': 'Returns a sort expression based on the descending order of the given column name.',
 }
 
-_functions = {
-    'upper': 'Converts a string expression to upper case.',
-    'lower': 'Converts a string expression to upper case.',
+_functions_over_column = {
     'sqrt': 'Computes the square root of the specified float value.',
     'abs': 'Computes the absolute value.',
 
@@ -120,7 +141,7 @@ _functions = {
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
-_functions_1_4 = {
+_functions_1_4_over_column = {
     # unary math functions
     'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
     'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
@@ -155,7 +176,7 @@ _functions_1_4 = {
     'bitwiseNOT': 'Computes bitwise not.',
 }
 
-_name_functions_2_4 = {
+_functions_2_4 = {
     'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
                        ' column name, and null values return before non-null values.',
     'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
@@ -186,7 +207,7 @@ _collect_set_doc = """
     >>> df2.agg(collect_set('age')).collect()
     [Row(collect_set(age)=[5, 2])]
     """
-_functions_1_6 = {
+_functions_1_6_over_column = {
     # unary math functions
     'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
               ' the expression in a group.',
@@ -203,7 +224,7 @@ _functions_1_6 = {
     'collect_set': _collect_set_doc
 }
 
-_functions_2_1 = {
+_functions_2_1_over_column = {
     # unary math functions
     'degrees': """
                Converts an angle measured in radians to an approximately equivalent angle
@@ -268,24 +289,24 @@ _window_functions = {
 _functions_deprecated = {
 }
 
-for _name, _doc in _name_functions.items():
-    globals()[_name] = since(1.3)(_create_name_function(_name, _doc))
 for _name, _doc in _functions.items():
     globals()[_name] = since(1.3)(_create_function(_name, _doc))
-for _name, _doc in _functions_1_4.items():
-    globals()[_name] = since(1.4)(_create_function(_name, _doc))
+for _name, _doc in _functions_over_column.items():
+    globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_1_4_over_column.items():
+    globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
 for _name, _doc in _binary_mathfunctions.items():
     globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
 for _name, _doc in _window_functions.items():
     globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
-for _name, _doc in _functions_1_6.items():
-    globals()[_name] = since(1.6)(_create_function(_name, _doc))
-for _name, _doc in _functions_2_1.items():
-    globals()[_name] = since(2.1)(_create_function(_name, _doc))
+for _name, _doc in _functions_1_6_over_column.items():
+    globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_2_1_over_column.items():
+    globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
 for _name, _message in _functions_deprecated.items():
     globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
-for _name, _doc in _name_functions_2_4.items():
-    globals()[_name] = since(2.4)(_create_name_function(_name, _doc))
+for _name, _doc in _functions_2_4.items():
+    globals()[_name] = since(2.4)(_create_function(_name, _doc))
 del _name, _doc
 
 
@@ -1450,6 +1471,8 @@ def hash(*cols):
 # ---------------------- String/Binary functions ------------------------------
 
 _string_functions = {
+    'upper': 'Converts a string expression to upper case.',
+    'lower': 'Converts a string expression to lower case.',
     'ascii': 'Computes the numeric value of the first character of the string column.',
     'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
     'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
@@ -1460,7 +1483,7 @@ _string_functions = {
 
 
 for _name, _doc in _string_functions.items():
-    globals()[_name] = since(1.5)(_create_function(_name, _doc))
+    globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
 del _name, _doc
 
 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index fe66602..b777573 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -129,6 +129,12 @@ class FunctionsTests(ReusedSQLTestCase):
                      df.select(functions.pow(df.a, 2.0)).collect())
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())
+        assert_close([math.hypot(i, 2 * i) for i in range(10)],
+                     df.select(functions.hypot("a", u"b")).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot("a", 2)).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot(df.a, 2)).collect())
 
     def test_rand_functions(self):
         df = self.df
@@ -151,7 +157,8 @@ class FunctionsTests(ReusedSQLTestCase):
         self.assertEqual(sorted(rndn1), sorted(rndn2))
 
     def test_string_functions(self):
-        from pyspark.sql.functions import col, lit
+        from pyspark.sql import functions
+        from pyspark.sql.functions import col, lit, _string_functions
         df = self.spark.createDataFrame([['nick']], schema=['name'])
         self.assertRaisesRegexp(
             TypeError,
@@ -162,6 +169,11 @@ class FunctionsTests(ReusedSQLTestCase):
                 TypeError,
                 lambda: df.select(col('name').substr(long(0), long(1))))
 
+        for name in _string_functions.keys():
+            self.assertEqual(
+                df.select(getattr(functions, name)("name")).first()[0],
+                df.select(getattr(functions, name)(col("name"))).first()[0])
+
     def test_array_contains_function(self):
         from pyspark.sql.functions import array_contains
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org