You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ia...@apache.org on 2021/07/18 03:12:18 UTC

[arrow] branch master updated: ARROW-13200: [R] Add binding for case_when()

This is an automated email from the ASF dual-hosted git repository.

ianmcook pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new b32e0bf  ARROW-13200: [R] Add binding for case_when()
b32e0bf is described below

commit b32e0bf79e53d098813930335bb0bf683970b4eb
Author: Ian Cook <ia...@gmail.com>
AuthorDate: Sat Jul 17 23:10:56 2021 -0400

    ARROW-13200: [R] Add binding for case_when()
    
    Adds support for `case_when()` in dplyr verbs. I followed the example of `dbplyr::case_when()`, which is much simpler and more self-contained than `dplyr::case_when()`.
    
    Closes #10737 from ianmcook/ARROW-13200
    
    Lead-authored-by: Ian Cook <ia...@gmail.com>
    Co-authored-by: Neal Richardson <ne...@gmail.com>
    Signed-off-by: Ian Cook <ia...@gmail.com>
---
 r/NEWS.md                     |   1 +
 r/R/dplyr-functions.R         |  35 ++++++++++++-
 r/src/compute.cpp             |   7 +++
 r/tests/testthat/test-dplyr.R | 119 ++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 161 insertions(+), 1 deletion(-)

diff --git a/r/NEWS.md b/r/NEWS.md
index a1cd67a..9cd7542 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -26,6 +26,7 @@
   * String operations: `strsplit()` and `str_split()`; `strptime()`; `paste()`, `paste0()`, and `str_c()`; `substr()` and `str_sub()`; `str_like()`; `str_pad()`; `stri_reverse()`
   * Date/time operations: `lubridate` methods such as `year()`, `month()`, `wday()`, and so on
   * Math: `log()`, trigonometry (`sin()`, `cos()`, et al.), `abs()`, `sign()`, `pmin()`/`pmax()`
+  * Conditional: `ifelse()` and `if_else()` (fixed-precision decimal numbers do not yet work and factors/dictionaries are converted to character strings); `case_when()` (currently works with numeric data types but not character strings, factors/dictionaries, or lists/structs)
   * `is.*` functions are supported and can be used inside `relocate()`
 
 * The print method for `arrow_dplyr_query` now includes the expression and the resulting type of columns derived by `mutate()`.
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index d118eef..d429920 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -698,6 +698,39 @@ nse_funcs$if_else <- function(condition, true, false, missing = NULL){
 
 # Although base R ifelse allows `yes` and `no` to be different classes
 #
-nse_funcs$ifelse <- function(test, yes, no){
+nse_funcs$ifelse <- function(test, yes, no) {
   nse_funcs$if_else(condition = test, true = yes, false = no)
 }
+
+nse_funcs$case_when <- function(...) {
+  formulas <- list2(...)
+  n <- length(formulas)
+  if (n == 0) {
+    abort("No cases provided in case_when()")
+  }
+  query <- vector("list", n)
+  value <- vector("list", n)
+  mask <- caller_env()
+  for (i in seq_len(n)) {
+    f <- formulas[[i]]
+    if (!inherits(f, "formula")) {
+      abort("Each argument to case_when() must be a two-sided formula")
+    }
+    query[[i]] <- arrow_eval(f[[2]], mask)
+    value[[i]] <- arrow_eval(f[[3]], mask)
+    if (!nse_funcs$is.logical(query[[i]])) {
+      abort("Left side of each formula in case_when() must be a logical expression")
+    }
+  }
+  build_expr(
+    "case_when",
+    args = c(
+      build_expr(
+        "make_struct",
+        args = query,
+        options = list(field_names = as.character(seq_along(query)))
+      ),
+      value
+    )
+  )
+}
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 2c5ee77..3082113 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -241,6 +241,13 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
     return out;
   }
 
+  if (func_name == "make_struct") {
+    using Options = arrow::compute::MakeStructOptions;
+    // TODO (ARROW-13371): accept `field_nullability` and `field_metadata` options
+    return std::make_shared<Options>(
+        cpp11::as_cpp<std::vector<std::string>>(options["field_names"]));
+  }
+
   if (func_name == "match_substring" || func_name == "match_substring_regex" ||
       func_name == "find_substring" || func_name == "find_substring_regex" ||
       func_name == "match_like") {
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index e99f743..468ad85 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -1225,3 +1225,122 @@ test_that("if_else and ifelse", {
     tbl
   )
 })
+
+test_that("case_when()", {
+  expect_dplyr_equal(
+    input %>%
+      transmute(cw = case_when(lgl ~ dbl, !false ~ dbl + dbl2)) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      mutate(cw = case_when(int > 5 ~ 1, TRUE ~ 0)) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      transmute(cw = case_when(chr %in% letters[1:3] ~ 1L) + 41L) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      filter(case_when(
+        dbl + int - 1.1 == dbl2 ~ TRUE,
+        NA ~ NA,
+        TRUE ~ FALSE
+      ) & !is.na(dbl2)) %>%
+      collect(),
+    tbl
+  )
+
+  # dplyr::case_when() errors if values on right side of formulas do not have
+  # exactly the same type, but the Arrow case_when kernel allows compatible types
+  expect_equal(
+    tbl %>%
+      mutate(i64 = as.integer64(1e10)) %>%
+      Table$create() %>%
+      transmute(cw = case_when(
+        is.na(fct) ~ int,
+        is.na(chr) ~ dbl,
+        TRUE ~ i64
+      )) %>%
+      collect(),
+    tbl %>%
+      transmute(
+        cw = ifelse(is.na(fct), int, ifelse(is.na(chr), dbl, 1e10))
+      )
+  )
+
+  # expected errors (which are caught by abandon_ship() and changed to warnings)
+  # TODO: Find a way to test these directly without abandon_ship() interfering
+  expect_error(
+    # no cases
+    expect_warning(
+      tbl %>%
+        Table$create() %>%
+        transmute(cw = case_when()),
+      "case_when"
+    )
+  )
+  expect_error(
+    # argument not a formula
+    expect_warning(
+      tbl %>%
+        Table$create() %>%
+        transmute(cw = case_when(TRUE ~ FALSE, TRUE)),
+      "case_when"
+    )
+  )
+  expect_error(
+    # non-logical R scalar on left side of formula
+    expect_warning(
+      tbl %>%
+        Table$create() %>%
+        transmute(cw = case_when(0L ~ FALSE, TRUE ~ FALSE)),
+      "case_when"
+    )
+  )
+  expect_error(
+    # non-logical Arrow column reference on left side of formula
+    expect_warning(
+      tbl %>%
+        Table$create() %>%
+        transmute(cw = case_when(int ~ FALSE)),
+      "case_when"
+    )
+  )
+  expect_error(
+    # non-logical Arrow expression on left side of formula
+    expect_warning(
+      tbl %>%
+        Table$create() %>%
+        transmute(cw = case_when(dbl + 3.14159 ~ TRUE)),
+      "case_when"
+    )
+  )
+
+  skip("case_when does not yet support with variable-width types (ARROW-13222)")
+  expect_dplyr_equal(
+    input %>%
+      transmute(cw = case_when(lgl ~ "abc")) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      transmute(cw = case_when(lgl ~ verses, !false ~ paste(chr, chr))) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      mutate(
+        cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!")
+      ) %>%
+      collect(),
+    tbl
+  )
+})