Posted to commits@spark.apache.org by ze...@apache.org on 2022/01/05 19:13:28 UTC

[spark] branch master updated: [SPARK-37788][PYTHON] Update remaining PySpark functions to use ColumnOrName (over Column), where appropriate

This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bc03cb  [SPARK-37788][PYTHON] Update remaining PySpark functions to use ColumnOrName (over Column), where appropriate
6bc03cb is described below

commit 6bc03cb50cc0b262568096b222327ceebffbd9d6
Author: Daniel-Davies <da...@gmail.com>
AuthorDate: Wed Jan 5 20:12:27 2022 +0100

    [SPARK-37788][PYTHON] Update remaining PySpark functions to use ColumnOrName (over Column), where appropriate
    
    ### What changes were proposed in this pull request?
    Please see https://issues.apache.org/jira/browse/SPARK-37788
    
    A few remaining functions accept only Column where they could accept ColumnOrName. This PR updates the annotations of functions that already support both, and converts input column name strings to Columns where that conversion is not already being done.
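    
    Concretely, the pattern is to accept a Column, a column name string, or (where applicable) a plain int, and normalize it before handing off to the JVM. A minimal, hypothetical sketch of that normalization idea using the public `col`/`lit` helpers (the `_as_column` name is illustrative and not part of this patch; the actual changes below use PySpark's internal `_to_java_column`):
    
        from typing import Union
    
        from pyspark.sql import Column
        from pyspark.sql.functions import col, lit
    
        ColumnOrName = Union[Column, str]
    
        def _as_column(value: Union[ColumnOrName, int]) -> Column:
            # hypothetical helper, not part of this patch
            if isinstance(value, int):
                return lit(value)  # wrap plain int literals as literal columns
            # resolve column name strings, pass Columns through unchanged
            return col(value) if isinstance(value, str) else value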
    
    ### Why are the changes needed?
    API consistency in PySpark
    
    ### Does this PR introduce _any_ user-facing change?
    Yes; two array functions now also accept column name strings (a usage sketch follows below):
    
    - array_repeat: now supports `df.select(array_repeat("data", "repeat_n").alias('r'))`
    - slice: now supports `df.select(slice("data", "index", "length").alias('r'))`
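    
    For illustration, a minimal sketch of the new string-based call style (assuming an active `spark` session; column names match the examples above, and the expected results mirror the updated unit tests in this PR):
    
        >>> from pyspark.sql.functions import array_repeat, lit, slice
        >>> df = spark.createDataFrame([([1, 2, 3], 2, 2)], ["data", "index", "length"])
        >>> df.select(slice("data", "index", "length").alias("r")).collect()
        [Row(r=[2, 3])]
        >>> df2 = spark.range(1).withColumn("repeat_n", lit(3))
        >>> df2.select(array_repeat("id", "repeat_n").alias("r")).collect()
        [Row(r=[0, 0, 0])]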
    
    Affecting developers: there are also annotation changes to the following functions:
    
    - overlay
    - least
    
    Previously, the annotations in these two functions referred only to Column, even though the functions already accepted either column names or Columns at runtime. Their parameter annotations have been updated to ColumnOrName.
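    
    For example (another minimal sketch, assuming an active `spark` session; the data mirrors the new unit tests), plain column name strings now type-check cleanly:
    
        >>> from pyspark.sql.functions import least, overlay
        >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])
        >>> df.select(least("a", "b", "c").alias("least")).collect()
        [Row(least=1)]
        >>> df2 = spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ["x", "y", "pos", "len"])
        >>> df2.select(overlay("x", "y", "pos", "len").alias("ol")).collect()
        [Row(ol='SPARK_CORESQL')]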
    
    ### How was this patch tested?
    Modifications to three existing unit tests, and an additional test for 'least'
    
    Closes #35071 from Daniel-Davies/Daniel-Davies/update-function-input-params.
    
    Lead-authored-by: Daniel-Davies <da...@gmail.com>
    Co-authored-by: Daniel Davies <dd...@palantir.com>
    Co-authored-by: Daniel-Davies <33...@users.noreply.github.com>
    Signed-off-by: zero323 <ms...@gmail.com>
---
 python/pyspark/sql/functions.py            | 107 ++++++++++++++++++++---------
 python/pyspark/sql/tests/test_functions.py |  76 +++++++++++++++++---
 2 files changed, 139 insertions(+), 44 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4791d3c..9061d83 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1715,13 +1715,18 @@ def greatest(*cols: "ColumnOrName") -> Column:
     return Column(sc._jvm.functions.greatest(_to_seq(sc, cols, _to_java_column)))
 
 
-def least(*cols: Column) -> Column:
+def least(*cols: "ColumnOrName") -> Column:
     """
     Returns the least value of the list of column names, skipping null values.
     This function takes at least 2 parameters. It will return null iff all parameters are null.
 
     .. versionadded:: 1.5.0
 
+    Parameters
+    ----------
+    cols : :class:`~pyspark.sql.Column` or str
+        column names or columns to be compared
+
     Examples
     --------
     >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
@@ -1757,6 +1762,8 @@ def when(condition: Column, value: Any) -> Column:
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
+
+    # Explicitly not using ColumnOrName type here, to keep the `condition` check easy to read
     if not isinstance(condition, Column):
         raise TypeError("condition should be a Column")
     v = value._jc if isinstance(value, Column) else value
@@ -2737,11 +2744,11 @@ def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str])
 
     Parameters
     ----------
-    timeColumn : :class:`~pyspark.sql.Column`
-        The column or the expression to use as the timestamp for windowing by time.
+    timeColumn : :class:`~pyspark.sql.Column` or str
+        The column name or column to use as the timestamp for windowing by time.
         The time column must be of TimestampType.
     gapDuration : :class:`~pyspark.sql.Column` or str
-        A column or string specifying the timeout of the session. It could be static value,
+        A Python string literal or column specifying the timeout of the session. It could be a static value,
         e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies gap
         duration dynamically based on the input row.
 
@@ -2884,6 +2891,13 @@ def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None
 
     .. versionadded:: 3.1.0
 
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column name or column that represents the input column to test
+    errMsg : :class:`~pyspark.sql.Column` or str
+        A Python string literal or column containing the error message
+
     Examples
     --------
     >>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
@@ -2913,6 +2927,11 @@ def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None
 def raise_error(errMsg: Union[Column, str]) -> Column:
     """
     Throws an exception with the provided error message.
+
+    Parameters
+    ----------
+    errMsg : :class:`~pyspark.sql.Column` or str
+        A Python string literal or column containing the error message
     """
     if not isinstance(errMsg, (str, Column)):
         raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg)))
@@ -3102,8 +3121,8 @@ def instr(str: "ColumnOrName", substr: str) -> Column:
 def overlay(
     src: "ColumnOrName",
     replace: "ColumnOrName",
-    pos: Union[Column, int],
-    len: Union[Column, int] = -1,
+    pos: Union["ColumnOrName", int],
+    len: Union["ColumnOrName", int] = -1,
 ) -> Column:
     """
     Overlay the specified portion of `src` with `replace`,
@@ -3111,15 +3130,27 @@ def overlay(
 
     .. versionadded:: 3.0.0
 
+    Parameters
+    ----------
+    src : :class:`~pyspark.sql.Column` or str
+        column name or column containing the string that will be replaced
+    replace : :class:`~pyspark.sql.Column` or str
+        column name or column containing the substitution string
+    pos : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the starting position in src
+    len : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the number of bytes to replace in src string by 'replace';
+        defaults to -1, which represents the length of the 'replace' string
+
     Examples
     --------
     >>> df = spark.createDataFrame([("SPARK_SQL", "CORE")], ("x", "y"))
-    >>> df.select(overlay("x", "y", 7).alias("overlayed")).show()
-    +----------+
-    | overlayed|
-    +----------+
-    |SPARK_CORE|
-    +----------+
+    >>> df.select(overlay("x", "y", 7).alias("overlayed")).collect()
+    [Row(overlayed='SPARK_CORE')]
+    >>> df.select(overlay("x", "y", 7, 0).alias("overlayed")).collect()
+    [Row(overlayed='SPARK_CORESQL')]
+    >>> df.select(overlay("x", "y", 7, 2).alias("overlayed")).collect()
+    [Row(overlayed='SPARK_COREL')]
     """
     if not isinstance(pos, (int, str, Column)):
         raise TypeError(
@@ -3707,7 +3738,9 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column:
     return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))
 
 
-def slice(x: "ColumnOrName", start: Union[Column, int], length: Union[Column, int]) -> Column:
+def slice(
+    x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
+) -> Column:
     """
     Collection function: returns an array containing  all the elements in `x` from index `start`
     (array indices start at 1, or from the end if `start` is negative) with the specified `length`.
@@ -3717,11 +3750,11 @@ def slice(x: "ColumnOrName", start: Union[Column, int], length: Union[Column, in
     Parameters
     ----------
     x : :class:`~pyspark.sql.Column` or str
-        the array to be sliced
-    start : :class:`~pyspark.sql.Column` or int
-        the starting index
-    length : :class:`~pyspark.sql.Column` or int
-        the length of the slice
+        column name or column containing the array to be sliced
+    start : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the starting index
+    length : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the length of the slice
 
     Examples
     --------
@@ -3731,11 +3764,15 @@ def slice(x: "ColumnOrName", start: Union[Column, int], length: Union[Column, in
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
+
+    start = lit(start) if isinstance(start, int) else start
+    length = lit(length) if isinstance(length, int) else length
+
     return Column(
         sc._jvm.functions.slice(
             _to_java_column(x),
-            start._jc if isinstance(start, Column) else start,
-            length._jc if isinstance(length, Column) else length,
+            _to_java_column(start),
+            _to_java_column(length),
         )
     )
 
@@ -4172,12 +4209,10 @@ def from_json(
     Parameters
     ----------
     col : :class:`~pyspark.sql.Column` or str
-        string column in json format
+        a column or column name in JSON format
     schema : :class:`DataType` or str
-        a StructType or ArrayType of StructType to use when parsing the json column.
-
-        .. versionchanged:: 2.3
-            the DDL-formatted string is also supported for ``schema``.
+        a StructType, an ArrayType of StructType, or a Python string literal with a DDL-formatted
+        string to use when parsing the JSON column
     options : dict, optional
         options to control parsing. accepts the same options as the json datasource.
         See `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
@@ -4687,12 +4722,19 @@ def map_from_entries(col: "ColumnOrName") -> Column:
     return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))
 
 
-def array_repeat(col: "ColumnOrName", count: Union[Column, int]) -> Column:
+def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column:
     """
     Collection function: creates an array containing a column repeated count times.
 
     .. versionadded:: 2.4.0
 
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column name or column that contains the element to be repeated
+    count : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the number of times to repeat the first argument
+
     Examples
     --------
     >>> df = spark.createDataFrame([('ab',)], ['data'])
@@ -4701,11 +4743,10 @@ def array_repeat(col: "ColumnOrName", count: Union[Column, int]) -> Column:
     """
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
-    return Column(
-        sc._jvm.functions.array_repeat(
-            _to_java_column(col), _to_java_column(count) if isinstance(count, Column) else count
-        )
-    )
+
+    count = lit(count) if isinstance(count, int) else count
+
+    return Column(sc._jvm.functions.array_repeat(_to_java_column(col), _to_java_column(count)))
 
 
 def arrays_zip(*cols: "ColumnOrName") -> Column:
@@ -4806,9 +4847,9 @@ def from_csv(
     Parameters
     ----------
     col : :class:`~pyspark.sql.Column` or str
-        string column in CSV format
+        a column or column name in CSV format
     schema :class:`~pyspark.sql.Column` or str
-        a string with schema in DDL format to use when parsing the CSV column.
+        a column, or Python string literal with schema in DDL format, to use when parsing the CSV column.
     options : dict, optional
         options to control parsing. accepts the same options as the CSV datasource.
         See `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option>`_
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index eb3b433..5021da5 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -46,6 +46,10 @@ from pyspark.sql.functions import (
     date_add,
     date_sub,
     add_months,
+    array_repeat,
+    size,
+    slice,
+    least,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
 
@@ -486,13 +490,31 @@ class FunctionsTests(ReusedSQLTestCase):
             self.assertEqual(result[0], "")
 
     def test_slice(self):
-        from pyspark.sql.functions import lit, size, slice
-
-        df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"])
+        df = self.spark.createDataFrame(
+            [
+                (
+                    [1, 2, 3],
+                    2,
+                    2,
+                ),
+                (
+                    [4, 5],
+                    2,
+                    2,
+                ),
+            ],
+            ["x", "index", "len"],
+        )
 
-        self.assertEqual(
-            df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
-            df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
+        expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
+        self.assertTrue(
+            all(
+                [
+                    df.select(slice(df.x, 2, 2).alias("sliced")).collect() == expected,
+                    df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect() == expected,
+                    df.select(slice("x", "index", "len").alias("sliced")).collect() == expected,
+                ]
+            )
         )
 
         self.assertEqual(
@@ -505,13 +527,18 @@ class FunctionsTests(ReusedSQLTestCase):
         )
 
     def test_array_repeat(self):
-        from pyspark.sql.functions import array_repeat, lit
-
         df = self.spark.range(1)
+        df = df.withColumn("repeat_n", lit(3))
 
-        self.assertEqual(
-            df.select(array_repeat("id", 3)).toDF("val").collect(),
-            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
+        expected = [Row(val=[0, 0, 0])]
+        self.assertTrue(
+            all(
+                [
+                    df.select(array_repeat("id", 3).alias("val")).collect() == expected,
+                    df.select(array_repeat("id", lit(3)).alias("val")).collect() == expected,
+                    df.select(array_repeat("id", "repeat_n").alias("val")).collect() == expected,
+                ]
+            )
         )
 
     def test_input_file_name_udf(self):
@@ -520,6 +547,20 @@ class FunctionsTests(ReusedSQLTestCase):
         file_name = df.collect()[0].file
         self.assertTrue("python/test_support/hello/hello.txt" in file_name)
 
+    def test_least(self):
+        df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])
+
+        expected = [Row(least=1)]
+        self.assertTrue(
+            all(
+                [
+                    df.select(least(df.a, df.b, df.c).alias("least")).collect() == expected,
+                    df.select(least(lit(3), lit(5), lit(1)).alias("least")).collect() == expected,
+                    df.select(least("a", "b", "c").alias("least")).collect() == expected,
+                ]
+            )
+        )
+
     def test_overlay(self):
         from pyspark.sql.functions import col, lit, overlay
         from itertools import chain
@@ -552,6 +593,19 @@ class FunctionsTests(ReusedSQLTestCase):
 
         self.assertListEqual(actual, expected)
 
+        df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", "y", "pos", "len"))
+
+        exp = [Row(ol="SPARK_CORESQL")]
+        self.assertTrue(
+            all(
+                [
+                    df.select(overlay(df.x, df.y, 7, 0).alias("ol")).collect() == exp,
+                    df.select(overlay(df.x, df.y, lit(7), lit(0)).alias("ol")).collect() == exp,
+                    df.select(overlay("x", "y", "pos", "len").alias("ol")).collect() == exp,
+                ]
+            )
+        )
+
     def test_percentile_approx(self):
         actual = list(
             chain.from_iterable(
