You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/14 12:59:29 UTC

[spark] branch branch-3.4 updated: [SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 4a61ce2cc6d [SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect
4a61ce2cc6d is described below

commit 4a61ce2cc6d6b75dd53efee5418e5a10e13851f7
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Feb 14 21:59:03 2023 +0900

    [SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect
    
    ### What changes were proposed in this pull request?
    1. Make `array_insert` accept an int `pos` and an `Any` value.
    2. Add it to Connect.
    
    ### Why are the changes needed?
    To be consistent with other PySpark functions.
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    added tests
    
    Closes #40010 from zhengruifeng/py_array_insert.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 14cc5a5341ad4f50c041c3a721d5f46586c83fd1)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/functions.py                 |  8 ++++++++
 python/pyspark/sql/functions.py                         | 17 ++++++++++++-----
 .../pyspark/sql/tests/connect/test_connect_function.py  | 10 ++++++++++
 3 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 7955f8932aa..42b59d18a5b 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1184,6 +1184,14 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
 array_except.__doc__ = pysparkfuncs.array_except.__doc__
 
 
+def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
+    _pos = lit(pos) if isinstance(pos, int) else _to_col(pos)
+    return _invoke_function("array_insert", _to_col(arr), _pos, lit(value))
+
+
+array_insert.__doc__ = pysparkfuncs.array_insert.__doc__
+
+
 def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("array_intersect", col1, col2)
 
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bfedf473b93..ac842101b28 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7677,7 +7677,7 @@ def array_distinct(col: "ColumnOrName") -> Column:
 
 
 @try_remote_functions
-def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName") -> Column:
+def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
     """
     Collection function: adds an item into a given array at a specified array index.
     Array indices start at 1 (or from the end if the index is negative).
@@ -7686,15 +7686,18 @@ def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName"
 
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
+
     Parameters
     ----------
     arr : :class:`~pyspark.sql.Column` or str
         name of column containing an array
-    pos : :class:`~pyspark.sql.Column` or str
+    pos : :class:`~pyspark.sql.Column` or str or int
         name of Numeric type column indicating position of insertion
         (starting at index 1, negative position is a start from the back of the array)
-    value : :class:`~pyspark.sql.Column` or str
-        name of column containing values for insertion into array
+    value :
+        a literal value, or a :class:`~pyspark.sql.Column` expression.
 
     Returns
     -------
@@ -7709,8 +7712,12 @@ def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName"
     ... )
     >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
     [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
+    >>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect()
+    [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])]
     """
-    return _invoke_function_over_columns("array_insert", arr, pos, value)
+    pos = lit(pos) if isinstance(pos, int) else pos
+
+    return _invoke_function_over_columns("array_insert", arr, pos, lit(value))
 
 
 @try_remote_functions
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index fbf2eab445a..9e499815107 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1085,6 +1085,16 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             sdf.select(SF.array_append(sdf.a, sdf.f)).toPandas(),
         )
 
+        # test array_insert
+        self.assert_eq(
+            cdf.select(CF.array_insert(cdf.a, -5, "ab")).toPandas(),
+            sdf.select(SF.array_insert(sdf.a, -5, "ab")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array_insert(cdf.a, 3, cdf.f)).toPandas(),
+            sdf.select(SF.array_insert(sdf.a, 3, sdf.f)).toPandas(),
+        )
+
         # test array_join
         self.assert_eq(
             cdf.select(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org