Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/25 06:44:49 UTC

[spark] branch master updated: [SPARK-40214][PYTHON][SQL] add 'get' to functions

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 295e98d29b3 [SPARK-40214][PYTHON][SQL] add 'get' to functions
295e98d29b3 is described below

commit 295e98d29b34e2b472c375608b8782c3b9189444
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Aug 25 14:44:18 2022 +0800

    [SPARK-40214][PYTHON][SQL] add 'get' to functions
    
    ### What changes were proposed in this pull request?
    Expose `get` in the Scala and Python DataFrame functions APIs.
    
    ### Why are the changes needed?
    For parity with the built-in SQL function `get`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds a new `get` function to the public API.
    
    ### How was this patch tested?
    Added unit tests: doctests in `functions.py` and a new case in `DataFrameFunctionsSuite`.
    
    Closes #37652 from zhengruifeng/py_get.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../source/reference/pyspark.sql/functions.rst     |  1 +
 python/pyspark/sql/functions.py                    | 70 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     | 11 ++++
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 38 ++++++++++++
 4 files changed, 120 insertions(+)
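
For orientation before the diff: a minimal PySpark sketch of how the new
API is used (assumes Spark 3.4.0+ and an active SparkSession named
`spark`; `get` is 0-based and returns NULL for out-of-range or negative
indices):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(["a", "b", "c"],)], ["data"])

    # 0-based lookup: index 1 returns "b"
    df.select(F.get("data", 1)).show()

    # out-of-range and negative indices return NULL instead of failing
    df.select(F.get("data", 5), F.get("data", -1)).show()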

diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index a799bb8ad0a..027babbf57d 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -176,6 +176,7 @@ Collection Functions
     explode_outer
     posexplode
     posexplode_outer
+    get
     get_json_object
     json_tuple
     from_json
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d59532f52cb..fd7a7247fc8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -4832,6 +4832,10 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column:
     -----
     The position is not zero based, but 1 based index.
 
+    See Also
+    --------
+    :meth:`get`
+
     Examples
     --------
     >>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data'])
@@ -4845,6 +4849,72 @@ def element_at(col: "ColumnOrName", extraction: Any) -> Column:
     return _invoke_function_over_columns("element_at", col, lit(extraction))
 
 
+def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
+    """
+    Collection function: Returns element of array at given (0-based) index.
+    If the index points outside of the array boundaries, then this function
+    returns NULL.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column containing array
+    index : :class:`~pyspark.sql.Column` or str or int
+        index of the element to return
+
+    Notes
+    -----
+    The index is 0-based, unlike :meth:`element_at`, which uses 1-based indexing.
+
+    See Also
+    --------
+    :meth:`element_at`
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])
+    >>> df.select(get(df.data, 1)).show()
+    +------------+
+    |get(data, 1)|
+    +------------+
+    |           b|
+    +------------+
+
+    >>> df.select(get(df.data, -1)).show()
+    +-------------+
+    |get(data, -1)|
+    +-------------+
+    |         null|
+    +-------------+
+
+    >>> df.select(get(df.data, 3)).show()
+    +------------+
+    |get(data, 3)|
+    +------------+
+    |        null|
+    +------------+
+
+    >>> df.select(get(df.data, "index")).show()
+    +----------------+
+    |get(data, index)|
+    +----------------+
+    |               b|
+    +----------------+
+
+    >>> df.select(get(df.data, col("index") - 1)).show()
+    +----------------------+
+    |get(data, (index - 1))|
+    +----------------------+
+    |                     a|
+    +----------------------+
+    """
+    index = lit(index) if isinstance(index, int) else index
+
+    return _invoke_function_over_columns("get", col, index)
+
+
 def array_remove(col: "ColumnOrName", element: Any) -> Column:
     """
     Collection function: Remove all elements that equal to element from the given array.
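
A practical note on the hunk above: because the Python wrapper coerces an
int index through `lit`, both literal and column-valued indices work, and
the 0-based contract differs from `element_at`, which is 1-based and
treats negative indices as counting from the end. A comparison sketch,
reusing `df` and `F` from the example above (default, non-ANSI SQL mode
assumed):

    # in-range lookups agree once the off-by-one is accounted for: both "b"
    df.select(F.get("data", 1), F.element_at("data", 2)).show()

    # negative indices differ: get() returns NULL, while element_at()
    # counts from the end of the array (here "c")
    df.select(F.get("data", -1), F.element_at("data", -1)).show()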
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bd7473706ca..69da277d5e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3958,6 +3958,17 @@ object functions {
     ElementAt(column.expr, lit(value).expr)
   }
 
+  /**
+   * Returns element of array at given (0-based) index. If the index points
+   * outside of the array boundaries, then this function returns NULL.
+   *
+   * @group collection_funcs
+   * @since 3.4.0
+   */
+  def get(column: Column, index: Column): Column = withExpr {
+    new Get(column.expr, index.expr)
+  }
+
   /**
    * Sorts the input array in ascending order. The elements of the input array must be orderable.
    * NaN is greater than any non-NaN elements for double/float type.
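
Since the Scala wrapper above simply builds the existing catalyst `Get`
expression, the same behavior should also be reachable from SQL text (a
sketch, assuming the `get` built-in is registered in this Spark version;
`df` and `F` as in the earlier examples):

    spark.sql("SELECT get(array('a', 'b', 'c'), 1) AS v").show()

    # or as an expression column inside the DataFrame API
    df.select(F.expr("get(data, 1)")).show()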
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b80925f8638..ee41b1efba2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1628,6 +1628,44 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(e3.message.contains(errorMsg3))
   }
 
+  test("SPARK-40214: get function") {
+    val df = Seq(
+      (Seq[String]("1", "2", "3"), 2),
+      (Seq[String](null, ""), 1),
+      (Seq[String](), 2),
+      (null, 3)
+    ).toDF("a", "b")
+
+    checkAnswer(
+      df.select(get(df("a"), lit(-1))),
+      Seq(Row(null), Row(null), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), lit(0))),
+      Seq(Row("1"), Row(null), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), lit(1))),
+      Seq(Row("2"), Row(""), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), lit(2))),
+      Seq(Row("3"), Row(null), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), lit(3))),
+      Seq(Row(null), Row(null), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), df("b"))),
+      Seq(Row("3"), Row(""), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(get(df("a"), df("b") - 1)),
+      Seq(Row("2"), Row(null), Row(null), Row(null))
+    )
+  }
+
   test("array_union functions") {
     val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
     val ans1 = Row(Seq(1, 2, 3, 4))

