You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/14 12:59:29 UTC
[spark] branch branch-3.4 updated: [SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 4a61ce2cc6d [SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect
4a61ce2cc6d is described below
commit 4a61ce2cc6d6b75dd53efee5418e5a10e13851f7
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Feb 14 21:59:03 2023 +0900
[SPARK-42433][PYTHON][CONNECT] Add `array_insert` to Connect
### What changes were proposed in this pull request?
1. Make `array_insert` accept an int `pos` and an `Any` value
2. Add it to Connect
### Why are the changes needed?
To be consistent with other PySpark functions.
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
Added tests.
Closes #40010 from zhengruifeng/py_array_insert.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 14cc5a5341ad4f50c041c3a721d5f46586c83fd1)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/functions.py | 8 ++++++++
python/pyspark/sql/functions.py | 17 ++++++++++++-----
.../pyspark/sql/tests/connect/test_connect_function.py | 10 ++++++++++
3 files changed, 30 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 7955f8932aa..42b59d18a5b 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1184,6 +1184,14 @@ def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
array_except.__doc__ = pysparkfuncs.array_except.__doc__
+def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
+ _pos = lit(pos) if isinstance(pos, int) else _to_col(pos)
+ return _invoke_function("array_insert", _to_col(arr), _pos, lit(value))
+
+
+array_insert.__doc__ = pysparkfuncs.array_insert.__doc__
+
+
def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("array_intersect", col1, col2)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bfedf473b93..ac842101b28 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7677,7 +7677,7 @@ def array_distinct(col: "ColumnOrName") -> Column:
@try_remote_functions
-def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName") -> Column:
+def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
"""
Collection function: adds an item into a given array at a specified array index.
Array indices start at 1 (or from the end if the index is negative).
@@ -7686,15 +7686,18 @@ def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName"
.. versionadded:: 3.4.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
arr : :class:`~pyspark.sql.Column` or str
name of column containing an array
- pos : :class:`~pyspark.sql.Column` or str
+ pos : :class:`~pyspark.sql.Column` or str or int
name of Numeric type column indicating position of insertion
(starting at index 1, negative position is a start from the back of the array)
- value : :class:`~pyspark.sql.Column` or str
- name of column containing values for insertion into array
+ value :
+ a literal value, or a :class:`~pyspark.sql.Column` expression.
Returns
-------
@@ -7709,8 +7712,12 @@ def array_insert(arr: "ColumnOrName", pos: "ColumnOrName", value: "ColumnOrName"
... )
>>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
[Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'd', 'b', 'a'])]
+ >>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect()
+ [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])]
"""
- return _invoke_function_over_columns("array_insert", arr, pos, value)
+ pos = lit(pos) if isinstance(pos, int) else pos
+
+ return _invoke_function_over_columns("array_insert", arr, pos, lit(value))
@try_remote_functions
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index fbf2eab445a..9e499815107 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1085,6 +1085,16 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
sdf.select(SF.array_append(sdf.a, sdf.f)).toPandas(),
)
+ # test array_insert
+ self.assert_eq(
+ cdf.select(CF.array_insert(cdf.a, -5, "ab")).toPandas(),
+ sdf.select(SF.array_insert(sdf.a, -5, "ab")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.array_insert(cdf.a, 3, cdf.f)).toPandas(),
+ sdf.select(SF.array_insert(sdf.a, 3, sdf.f)).toPandas(),
+ )
+
# test array_join
self.assert_eq(
cdf.select(
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org