You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ze...@apache.org on 2022/08/22 18:18:53 UTC

[spark] branch master updated: [SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R

This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fdd2db48b1b [SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R
fdd2db48b1b is described below

commit fdd2db48b1b6de2cd2bcbac7f063913ad08b3d6b
Author: zero323 <ms...@gmail.com>
AuthorDate: Mon Aug 22 20:18:06 2022 +0200

    [SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R
    
    ### What changes were proposed in this pull request?
    
    This PR exposes array_sort(column, comparator) in Python and R.
    
    ### Why are the changes needed?
    
    Feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    New signature in Python and R APIs.
    
    ### How was this patch tested?
    
    - New doctest, manual testing (Python)
    - New unit test (R)
    
    Closes #37600 from zero323/SPARK-40166.
    
    Authored-by: zero323 <ms...@gmail.com>
    Signed-off-by: zero323 <ms...@gmail.com>
---
 R/pkg/R/functions.R                   | 21 ++++++++++++++++++---
 R/pkg/R/generics.R                    |  2 +-
 R/pkg/tests/fulltests/test_sparkSQL.R | 10 ++++++++++
 python/pyspark/sql/functions.py       | 25 ++++++++++++++++++++++---
 4 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d772c9bd4e4..00e2bec670a 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -258,6 +258,13 @@ NULL
 #'          into accumulator (the first argument).
 #' @param finish an unary \code{function} \code{(Column) -> Column} used to
 #'          apply final transformation on the accumulated data in \code{array_aggregate}.
+#' @param comparator an optional binary (\code{(Column, Column) -> Column}) \code{function}
+#'          which is used to compare the elements of the array.
+#'          The comparator will take two
+#'          arguments representing two elements of the array. It returns a negative integer,
+#'          0, or a positive integer as the first element is less than, equal to,
+#'          or greater than the second element.
+#'          If the comparator function returns null, the function will fail and raise an error.
 #' @param ... additional argument(s).
 #'          \itemize{
 #'          \item \code{to_json}, \code{from_json} and \code{schema_of_json}: this contains
@@ -292,6 +299,7 @@ NULL
 #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
 #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
 #' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
+#' head(select(tmp, array_sort(tmp$v1, function(x, y) coalesce(cast(y - x, "integer"), lit(0L)))))
 #' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21)))
 #' head(select(tmp, array_transform("v1", function(x) x * 10)))
 #' head(select(tmp, array_exists("v1", function(x) x > 120)))
@@ -4141,9 +4149,16 @@ setMethod("array_repeat",
 #' @note array_sort since 2.4.0
 setMethod("array_sort",
           signature(x = "Column"),
-          function(x) {
-            jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)
-            column(jc)
+          function(x, comparator = NULL) {
+            if (is.null(comparator)) {
+               column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc))
+            } else {
+              invoke_higher_order_function(
+                "ArraySort",
+                cols = list(x),
+                funs = list(comparator)
+              )
+            }
           })
 
 #' @details
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 089fb882053..93cd0f3bff3 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -840,7 +840,7 @@ setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat")
 
 #' @rdname column_collection_functions
 #' @name NULL
-setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
+setGeneric("array_sort", function(x, ...) { standardGeneric("array_sort") })
 
 #' @rdname column_ml_functions
 #' @name NULL
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 33ca43b11f9..68fa6aac8e7 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1604,6 +1604,16 @@ test_that("column functions", {
 
   result <- collect(select(df, array_sort(df[[1]])))[[1]]
   expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA)))
+  result <- collect(select(
+    df,
+    array_sort(
+      df[[1]],
+      function(x, y) otherwise(
+        when(isNull(x), 1L), otherwise(when(isNull(y), -1L), cast(y - x, "integer"))
+      )
+    )
+  ))[[1]]
+  expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
 
   result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
   expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 474d225b0d0..abedaf24417 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -4914,25 +4914,44 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
     return _invoke_function("sort_array", _to_java_column(col), asc)
 
 
-def array_sort(col: "ColumnOrName") -> Column:
+def array_sort(
+    col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None
+) -> Column:
     """
     Collection function: sorts the input array in ascending order. The elements of the input array
     must be orderable. Null elements will be placed at the end of the returned array.
 
     .. versionadded:: 2.4.0
+    .. versionchanged:: 3.4.0
+        Can take a `comparator` function.
 
     Parameters
     ----------
     col : :class:`~pyspark.sql.Column` or str
         name of column or expression
+    comparator : callable, optional
+        A binary function ``(Column, Column) -> Column``.
+        The comparator will take two
+        arguments representing two elements of the array. It returns a negative integer, 0, or a
+        positive integer as the first element is less than, equal to, or greater than the second
+        element. If the comparator function returns null, the function will fail and raise an error.
 
     Examples
     --------
     >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
     >>> df.select(array_sort(df.data).alias('r')).collect()
     [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
-    """
-    return _invoke_function_over_columns("array_sort", col)
+    >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data'])
+    >>> df.select(array_sort(
+    ...     "data",
+    ...     lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))
+    ... ).alias("r")).collect()
+    [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])]
+    """
+    if comparator is None:
+        return _invoke_function_over_columns("array_sort", col)
+    else:
+        return _invoke_higher_order_function("ArraySort", [col], [comparator])
 
 
 def shuffle(col: "ColumnOrName") -> Column:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org