You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ia...@apache.org on 2021/07/18 03:12:18 UTC
[arrow] branch master updated: ARROW-13200: [R] Add binding for case_when()
This is an automated email from the ASF dual-hosted git repository.
ianmcook pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new b32e0bf ARROW-13200: [R] Add binding for case_when()
b32e0bf is described below
commit b32e0bf79e53d098813930335bb0bf683970b4eb
Author: Ian Cook <ia...@gmail.com>
AuthorDate: Sat Jul 17 23:10:56 2021 -0400
ARROW-13200: [R] Add binding for case_when()
Adds support for `case_when()` in dplyr verbs. I followed the example of `dbplyr::case_when()`, which is much simpler and more self-contained than `dplyr::case_when()`.
Closes #10737 from ianmcook/ARROW-13200
Lead-authored-by: Ian Cook <ia...@gmail.com>
Co-authored-by: Neal Richardson <ne...@gmail.com>
Signed-off-by: Ian Cook <ia...@gmail.com>
---
r/NEWS.md | 1 +
r/R/dplyr-functions.R | 35 ++++++++++++-
r/src/compute.cpp | 7 +++
r/tests/testthat/test-dplyr.R | 119 ++++++++++++++++++++++++++++++++++++++++++
4 files changed, 161 insertions(+), 1 deletion(-)
diff --git a/r/NEWS.md b/r/NEWS.md
index a1cd67a..9cd7542 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -26,6 +26,7 @@
* String operations: `strsplit()` and `str_split()`; `strptime()`; `paste()`, `paste0()`, and `str_c()`; `substr()` and `str_sub()`; `str_like()`; `str_pad()`; `stri_reverse()`
* Date/time operations: `lubridate` methods such as `year()`, `month()`, `wday()`, and so on
* Math: `log()`, trigonometry (`sin()`, `cos()`, et al.), `abs()`, `sign()`, `pmin()`/`pmax()`
+ * Conditional: `ifelse()` and `if_else()` (fixed-precision decimal numbers do not yet work and factors/dictionaries are converted to character strings); `case_when()` (currently works with numeric data types but not character strings, factors/dictionaries, or lists/structs)
* `is.*` functions are supported and can be used inside `relocate()`
* The print method for `arrow_dplyr_query` now includes the expression and the resulting type of columns derived by `mutate()`.
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index d118eef..d429920 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -698,6 +698,39 @@ nse_funcs$if_else <- function(condition, true, false, missing = NULL){
# Although base R ifelse allows `yes` and `no` to be different classes
#
-nse_funcs$ifelse <- function(test, yes, no){
+nse_funcs$ifelse <- function(test, yes, no) {
nse_funcs$if_else(condition = test, true = yes, false = no)
}
+
+nse_funcs$case_when <- function(...) {
+ formulas <- list2(...)
+ n <- length(formulas)
+ if (n == 0) {
+ abort("No cases provided in case_when()")
+ }
+ query <- vector("list", n)
+ value <- vector("list", n)
+ mask <- caller_env()
+ for (i in seq_len(n)) {
+ f <- formulas[[i]]
+ if (!inherits(f, "formula")) {
+ abort("Each argument to case_when() must be a two-sided formula")
+ }
+ query[[i]] <- arrow_eval(f[[2]], mask)
+ value[[i]] <- arrow_eval(f[[3]], mask)
+ if (!nse_funcs$is.logical(query[[i]])) {
+ abort("Left side of each formula in case_when() must be a logical expression")
+ }
+ }
+ build_expr(
+ "case_when",
+ args = c(
+ build_expr(
+ "make_struct",
+ args = query,
+ options = list(field_names = as.character(seq_along(query)))
+ ),
+ value
+ )
+ )
+}
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 2c5ee77..3082113 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -241,6 +241,13 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}
+ if (func_name == "make_struct") {
+ using Options = arrow::compute::MakeStructOptions;
+ // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
+ return std::make_shared<Options>(
+ cpp11::as_cpp<std::vector<std::string>>(options["field_names"]));
+ }
+
if (func_name == "match_substring" || func_name == "match_substring_regex" ||
func_name == "find_substring" || func_name == "find_substring_regex" ||
func_name == "match_like") {
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index e99f743..468ad85 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -1225,3 +1225,122 @@ test_that("if_else and ifelse", {
tbl
)
})
+
+test_that("case_when()", {
+ expect_dplyr_equal(
+ input %>%
+ transmute(cw = case_when(lgl ~ dbl, !false ~ dbl + dbl2)) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ mutate(cw = case_when(int > 5 ~ 1, TRUE ~ 0)) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ filter(case_when(
+ dbl + int - 1.1 == dbl2 ~ TRUE,
+ NA ~ NA,
+ TRUE ~ FALSE
+ ) & !is.na(dbl2)) %>%
+ collect(),
+ tbl
+ )
+
+ # dplyr::case_when() errors if values on right side of formulas do not have
+ # exactly the same type, but the Arrow case_when kernel allows compatible types
+ expect_equal(
+ tbl %>%
+ mutate(i64 = as.integer64(1e10)) %>%
+ Table$create() %>%
+ transmute(cw = case_when(
+ is.na(fct) ~ int,
+ is.na(chr) ~ dbl,
+ TRUE ~ i64
+ )) %>%
+ collect(),
+ tbl %>%
+ transmute(
+ cw = ifelse(is.na(fct), int, ifelse(is.na(chr), dbl, 1e10))
+ )
+ )
+
+ # expected errors (which are caught by abandon_ship() and changed to warnings)
+ # TODO: Find a way to test these directly without abandon_ship() interfering
+ expect_error(
+ # no cases
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ transmute(cw = case_when()),
+ "case_when"
+ )
+ )
+ expect_error(
+ # argument not a formula
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ transmute(cw = case_when(TRUE ~ FALSE, TRUE)),
+ "case_when"
+ )
+ )
+ expect_error(
+ # non-logical R scalar on left side of formula
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ transmute(cw = case_when(0L ~ FALSE, TRUE ~ FALSE)),
+ "case_when"
+ )
+ )
+ expect_error(
+ # non-logical Arrow column reference on left side of formula
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ transmute(cw = case_when(int ~ FALSE)),
+ "case_when"
+ )
+ )
+ expect_error(
+ # non-logical Arrow expression on left side of formula
+ expect_warning(
+ tbl %>%
+ Table$create() %>%
+ transmute(cw = case_when(dbl + 3.14159 ~ TRUE)),
+ "case_when"
+ )
+ )
+
+ skip("case_when does not yet support with variable-width types (ARROW-13222)")
+ expect_dplyr_equal(
+ input %>%
+ transmute(cw = case_when(lgl ~ "abc")) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ transmute(cw = case_when(lgl ~ verses, !false ~ paste(chr, chr))) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ mutate(
+ cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!")
+ ) %>%
+ collect(),
+ tbl
+ )
+})