You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ze...@apache.org on 2022/08/22 18:18:53 UTC
[spark] branch master updated: [SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R
This is an automated email from the ASF dual-hosted git repository.
zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fdd2db48b1b [SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R
fdd2db48b1b is described below
commit fdd2db48b1b6de2cd2bcbac7f063913ad08b3d6b
Author: zero323 <ms...@gmail.com>
AuthorDate: Mon Aug 22 20:18:06 2022 +0200
[SPARK-40166][SPARK-40167][PYTHON][R][SQL] Expose array_sort(column, comparator) in Python and R
### What changes were proposed in this pull request?
This PR exposes array_sort(column, comparator) in Python and R.
### Why are the changes needed?
Feature parity.
### Does this PR introduce _any_ user-facing change?
New signature in Python and R APIs.
### How was this patch tested?
- New doctest, manual testing (Python)
- New unit test (R)
Closes #37600 from zero323/SPARK-40166.
Authored-by: zero323 <ms...@gmail.com>
Signed-off-by: zero323 <ms...@gmail.com>
---
R/pkg/R/functions.R | 21 ++++++++++++++++++---
R/pkg/R/generics.R | 2 +-
R/pkg/tests/fulltests/test_sparkSQL.R | 10 ++++++++++
python/pyspark/sql/functions.py | 25 ++++++++++++++++++++++---
4 files changed, 51 insertions(+), 7 deletions(-)
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d772c9bd4e4..00e2bec670a 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -258,6 +258,13 @@ NULL
#' into accumulator (the first argument).
#' @param finish an unary \code{function} \code{(Column) -> Column} used to
#' apply final transformation on the accumulated data in \code{array_aggregate}.
+#' @param comparator an optional binary (\code{(Column, Column) -> Column}) \code{function}
+#' which is used to compare the elements of the array.
+#' The comparator will take two
+#' arguments representing two elements of the array. It returns a negative integer,
+#' 0, or a positive integer as the first element is less than, equal to,
+#' or greater than the second element.
+#' If the comparator function returns null, the function will fail and raise an error.
#' @param ... additional argument(s).
#' \itemize{
#' \item \code{to_json}, \code{from_json} and \code{schema_of_json}: this contains
@@ -292,6 +299,7 @@ NULL
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
+#' head(select(tmp, array_sort(tmp$v1, function(x, y) coalesce(cast(y - x, "integer"), lit(0L)))))
#' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21)))
#' head(select(tmp, array_transform("v1", function(x) x * 10)))
#' head(select(tmp, array_exists("v1", function(x) x > 120)))
@@ -4141,9 +4149,16 @@ setMethod("array_repeat",
#' @note array_sort since 2.4.0
setMethod("array_sort",
signature(x = "Column"),
- function(x) {
- jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)
- column(jc)
+ function(x, comparator = NULL) {
+ if (is.null(comparator)) {
+ column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc))
+ } else {
+ invoke_higher_order_function(
+ "ArraySort",
+ cols = list(x),
+ funs = list(comparator)
+ )
+ }
})
#' @details
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 089fb882053..93cd0f3bff3 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -840,7 +840,7 @@ setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat")
#' @rdname column_collection_functions
#' @name NULL
-setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
+setGeneric("array_sort", function(x, ...) { standardGeneric("array_sort") })
#' @rdname column_ml_functions
#' @name NULL
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 33ca43b11f9..68fa6aac8e7 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1604,6 +1604,16 @@ test_that("column functions", {
result <- collect(select(df, array_sort(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA)))
+ result <- collect(select(
+ df,
+ array_sort(
+ df[[1]],
+ function(x, y) otherwise(
+ when(isNull(x), 1L), otherwise(when(isNull(y), -1L), cast(y - x, "integer"))
+ )
+ )
+ ))[[1]]
+ expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 474d225b0d0..abedaf24417 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -4914,25 +4914,44 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
return _invoke_function("sort_array", _to_java_column(col), asc)
-def array_sort(col: "ColumnOrName") -> Column:
+def array_sort(
+ col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None
+) -> Column:
"""
Collection function: sorts the input array in ascending order. The elements of the input array
must be orderable. Null elements will be placed at the end of the returned array.
.. versionadded:: 2.4.0
+ .. versionchanged:: 3.4.0
+ Can take a `comparator` function.
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
name of column or expression
+ comparator : callable, optional
+ A binary function ``(Column, Column) -> Column``.
+ The comparator will take two
+ arguments representing two elements of the array. It returns a negative integer, 0, or a
+ positive integer as the first element is less than, equal to, or greater than the second
+ element. If the comparator function returns null, the function will fail and raise an error.
Examples
--------
>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(array_sort(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
- """
- return _invoke_function_over_columns("array_sort", col)
+ >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data'])
+ >>> df.select(array_sort(
+ ... "data",
+ ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))
+ ... ).alias("r")).collect()
+ [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])]
+ """
+ if comparator is None:
+ return _invoke_function_over_columns("array_sort", col)
+ else:
+ return _invoke_higher_order_function("ArraySort", [col], [comparator])
def shuffle(col: "ColumnOrName") -> Column:
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org