You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/25 06:44:49 UTC
[spark] branch master updated: [SPARK-40214][PYTHON][SQL] add 'get' to functions
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 295e98d29b3 [SPARK-40214][PYTHON][SQL] add 'get' to functions
295e98d29b3 is described below
commit 295e98d29b34e2b472c375608b8782c3b9189444
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Aug 25 14:44:18 2022 +0800
[SPARK-40214][PYTHON][SQL] add 'get' to functions
### What changes were proposed in this pull request?
Expose `get` as a DataFrame function.
### Why are the changes needed?
for function parity
### Does this PR introduce _any_ user-facing change?
yes, new API
### How was this patch tested?
added UT
Closes #37652 from zhengruifeng/py_get.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../source/reference/pyspark.sql/functions.rst | 1 +
python/pyspark/sql/functions.py | 70 ++++++++++++++++++++++
.../scala/org/apache/spark/sql/functions.scala | 11 ++++
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 38 ++++++++++++
4 files changed, 120 insertions(+)
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index a799bb8ad0a..027babbf57d 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -176,6 +176,7 @@ Collection Functions
explode_outer
posexplode
posexplode_outer
+ get
get_json_object
json_tuple
from_json
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d59532f52cb..fd7a7247fc8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -4832,6 +4832,10 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column:
-----
The position is not zero based, but 1 based index.
+ See Also
+ --------
+ :meth:`get`
+
Examples
--------
>>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data'])
@@ -4845,6 +4849,72 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column:
return _invoke_function_over_columns("element_at", col, lit(extraction))
+def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
+ """
+ Collection function: Returns element of array at given (0-based) index.
+ If the index points outside of the array boundaries, then this function
+ returns NULL.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ index : :class:`~pyspark.sql.Column` or str or int
+ index to check for in array
+
+ Notes
+ -----
+ The index is 0-based, not 1-based.
+
+ See Also
+ --------
+ :meth:`element_at`
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])
+ >>> df.select(get(df.data, 1)).show()
+ +------------+
+ |get(data, 1)|
+ +------------+
+ | b|
+ +------------+
+
+ >>> df.select(get(df.data, -1)).show()
+ +-------------+
+ |get(data, -1)|
+ +-------------+
+ | null|
+ +-------------+
+
+ >>> df.select(get(df.data, 3)).show()
+ +------------+
+ |get(data, 3)|
+ +------------+
+ | null|
+ +------------+
+
+ >>> df.select(get(df.data, "index")).show()
+ +----------------+
+ |get(data, index)|
+ +----------------+
+ | b|
+ +----------------+
+
+ >>> df.select(get(df.data, col("index") - 1)).show()
+ +----------------------+
+ |get(data, (index - 1))|
+ +----------------------+
+ | a|
+ +----------------------+
+ """
+ index = lit(index) if isinstance(index, int) else index
+
+ return _invoke_function_over_columns("get", col, index)
+
+
def array_remove(col: "ColumnOrName", element: Any) -> Column:
"""
Collection function: Remove all elements that equal to element from the given array.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bd7473706ca..69da277d5e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3958,6 +3958,17 @@ object functions {
ElementAt(column.expr, lit(value).expr)
}
+ /**
+ * Returns element of array at given (0-based) index. If the index points
+ * outside of the array boundaries, then this function returns NULL.
+ *
+ * @group collection_funcs
+ * @since 3.4.0
+ */
+ def get(column: Column, index: Column): Column = withExpr {
+ new Get(column.expr, index.expr)
+ }
+
/**
* Sorts the input array in ascending order. The elements of the input array must be orderable.
* NaN is greater than any non-NaN elements for double/float type.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b80925f8638..ee41b1efba2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1628,6 +1628,44 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(e3.message.contains(errorMsg3))
}
+ test("SPARK-40214: get function") {
+ val df = Seq(
+ (Seq[String]("1", "2", "3"), 2),
+ (Seq[String](null, ""), 1),
+ (Seq[String](), 2),
+ (null, 3)
+ ).toDF("a", "b")
+
+ checkAnswer(
+ df.select(get(df("a"), lit(-1))),
+ Seq(Row(null), Row(null), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), lit(0))),
+ Seq(Row("1"), Row(null), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), lit(1))),
+ Seq(Row("2"), Row(""), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), lit(2))),
+ Seq(Row("3"), Row(null), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), lit(3))),
+ Seq(Row(null), Row(null), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), df("b"))),
+ Seq(Row("3"), Row(""), Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(get(df("a"), df("b") - 1)),
+ Seq(Row("2"), Row(null), Row(null), Row(null))
+ )
+ }
+
test("array_union functions") {
val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
val ans1 = Row(Seq(1, 2, 3, 4))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org