You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by np...@apache.org on 2022/11/02 23:15:50 UTC

[arrow] branch master updated: ARROW-15460: [R] Add as.data.frame.Dataset method (#14461)

This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e53978b56 ARROW-15460: [R] Add as.data.frame.Dataset method (#14461)
5e53978b56 is described below

commit 5e53978b56aa13f9c033f2e849cc22f2aed6e2d3
Author: Neal Richardson <ne...@gmail.com>
AuthorDate: Wed Nov 2 19:15:40 2022 -0400

    ARROW-15460: [R] Add as.data.frame.Dataset method (#14461)
    
    Plus some refactoring and disentangling of compute/collect methods
    
    Authored-by: Neal Richardson <ne...@gmail.com>
    Signed-off-by: Neal Richardson <ne...@gmail.com>
---
 r/NAMESPACE                     |  2 ++
 r/R/dataset.R                   |  7 ++++-
 r/R/dplyr-collect.R             | 57 +++++++++++++++--------------------------
 r/R/dplyr-group-by.R            | 30 ++++++++++++++++++----
 r/R/dplyr.R                     |  6 ++++-
 r/R/metadata.R                  | 13 +++++++---
 r/R/table.R                     | 14 +++++++++-
 r/man/as_arrow_table.Rd         |  3 +++
 r/man/open_dataset.Rd           |  2 +-
 r/tests/testthat/test-dataset.R |  5 ++++
 r/tests/testthat/test-udf.R     |  1 +
 11 files changed, 92 insertions(+), 48 deletions(-)

diff --git a/r/NAMESPACE b/r/NAMESPACE
index 4a0c6ed261..0b18ace9ad 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -29,6 +29,7 @@ S3method(as.character,ArrowDatum)
 S3method(as.character,FileFormat)
 S3method(as.character,FragmentScanOptions)
 S3method(as.data.frame,ArrowTabular)
+S3method(as.data.frame,Dataset)
 S3method(as.data.frame,RecordBatchReader)
 S3method(as.data.frame,Schema)
 S3method(as.data.frame,StructArray)
@@ -47,6 +48,7 @@ S3method(as_arrow_array,data.frame)
 S3method(as_arrow_array,default)
 S3method(as_arrow_array,pyarrow.lib.Array)
 S3method(as_arrow_array,vctrs_list_of)
+S3method(as_arrow_table,Dataset)
 S3method(as_arrow_table,RecordBatch)
 S3method(as_arrow_table,RecordBatchReader)
 S3method(as_arrow_table,Schema)
diff --git a/r/R/dataset.R b/r/R/dataset.R
index 54ac30e56b..78b59ecc24 100644
--- a/r/R/dataset.R
+++ b/r/R/dataset.R
@@ -131,7 +131,7 @@
 #' dir.create(tf)
 #' on.exit(unlink(tf))
 #'
-#' write_dataset(mtcars, tf, partitioning="cyl")
+#' write_dataset(mtcars, tf, partitioning = "cyl")
 #'
 #' # You can specify a directory containing the files for your dataset and
 #' # open_dataset will scan all files in your directory.
@@ -397,6 +397,11 @@ dim.Dataset <- function(x) c(x$num_rows, x$num_cols)
 #' @export
 c.Dataset <- function(...) Dataset$create(list(...))
 
+#' @export
+as.data.frame.Dataset <- function(x, row.names = NULL, optional = FALSE, ...) {
+  collect.Dataset(x) # row.names, optional and ... are accepted but unused
+}
+
 #' @export
 head.Dataset <- function(x, n = 6L, ...) {
   head(Scanner$create(x), n)
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
index 8bf22728d6..395026ce78 100644
--- a/r/R/dplyr-collect.R
+++ b/r/R/dplyr-collect.R
@@ -19,19 +19,8 @@
 # The following S3 methods are registered on load if dplyr is present
 
 collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
-  tryCatch(
-    out <- as_arrow_table(x),
-    # n = 4 because we want the error to show up as being from collect()
-    # and not augment_io_error_msg()
-    error = function(e, call = caller_env(n = 4)) {
-      augment_io_error_msg(e, call, schema = x$.data$schema)
-    }
-  )
-
-  if (as_data_frame) {
-    out <- as.data.frame(out)
-  }
-  restore_dplyr_features(out, x)
+  out <- compute.arrow_dplyr_query(x) # Table; group_by info kept in metadata
+  collect.ArrowTabular(out, as_data_frame) # to data.frame if as_data_frame
 }
 collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) {
   if (as_data_frame) {
@@ -40,10 +29,27 @@ collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) {
     x
   }
 }
-collect.Dataset <- collect.RecordBatchReader <- function(x, ...) dplyr::collect(as_adq(x), ...)
+collect.Dataset <- function(x, as_data_frame = TRUE, ...) {
+  collect.ArrowTabular(compute.Dataset(x), as_data_frame) # via a Table
+}
+collect.RecordBatchReader <- collect.Dataset
 
-compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE)
 compute.ArrowTabular <- function(x, ...) x
+compute.arrow_dplyr_query <- function(x, ...) {
+  # TODO: should this tryCatch move down into as_arrow_table()?
+  tryCatch(
+    as_arrow_table(x),
+    # n = 4 because we want the error to show up as being from compute()
+    # and not augment_io_error_msg()
+    error = function(e, call = caller_env(n = 4)) {
+      # Pass a dummy schema() because the CSV file reader handler is only
+      # relevant when read_csv_arrow() is called with a schema, whereas a
+      # Dataset always has a schema.
+      # TODO: clean this up
+      augment_io_error_msg(e, call, schema = schema())
+    }
+  )
+}
 compute.Dataset <- compute.RecordBatchReader <- compute.arrow_dplyr_query
 
 pull.Dataset <- function(.data,
@@ -93,27 +99,6 @@ handle_pull_as_vector <- function(out, as_vector) {
   out
 }
 
-restore_dplyr_features <- function(df, query) {
-  # An arrow_dplyr_query holds some attributes that Arrow doesn't know about
-  # After calling collect(), make sure these features are carried over
-
-  if (length(dplyr::group_vars(query))) {
-    if (is.data.frame(df)) {
-      # Preserve groupings, if present
-      df <- dplyr::group_by(
-        df,
-        !!!syms(dplyr::group_vars(query)),
-        .drop = dplyr::group_by_drop_default(query),
-        .add = FALSE
-      )
-    } else {
-      # This is a Table, via compute() or collect(as_data_frame = FALSE)
-      df$metadata$r$attributes$.group_vars <- dplyr::group_vars(query)
-    }
-  }
-  df
-}
-
 collapse.arrow_dplyr_query <- function(x, ...) {
   # Figure out what schema will result from the query
   x$schema <- implicit_schema(x)
diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R
index a7b1ab9dbc..a9c7ed8bda 100644
--- a/r/R/dplyr-group-by.R
+++ b/r/R/dplyr-group-by.R
@@ -67,9 +67,13 @@ group_vars.ArrowTabular <- function(x) {
 
 # the logical literal in the two functions below controls the default value of
 # the .drop argument to group_by()
-group_by_drop_default.arrow_dplyr_query <-
-  function(.tbl) .tbl$drop_empty_groups %||% TRUE
-group_by_drop_default.Dataset <- group_by_drop_default.ArrowTabular <- group_by_drop_default.RecordBatchReader <-
+group_by_drop_default.arrow_dplyr_query <- function(.tbl) {
+  .tbl$drop_empty_groups %||% TRUE
+}
+group_by_drop_default.ArrowTabular <- function(.tbl) {
+  .tbl$metadata$r$attributes$.group_by_drop %||% TRUE # see set_group_attributes()
+}
+group_by_drop_default.Dataset <- group_by_drop_default.RecordBatchReader <-
   function(.tbl) TRUE
 
 ungroup.arrow_dplyr_query <- function(x, ...) {
@@ -79,6 +83,22 @@ ungroup.arrow_dplyr_query <- function(x, ...) {
 }
 ungroup.Dataset <- ungroup.RecordBatchReader <- force
 ungroup.ArrowTabular <- function(x) {
-  x$metadata$r$attributes$.group_vars <- NULL
-  x
+  set_group_attributes(x, NULL, NULL)
+}
+
+# Call after evaluating a query (as_arrow_table()) to add back any group
+# attributes to the Schema metadata; pass NULL to remove them instead.
+set_group_attributes <- function(tab, group_vars, .drop) {
+  # dplyr::group_vars() returns character(0) for ungrouped data, which leaves
+  # the metadata untouched; pass NULL explicitly to unset it (ungroup)
+  if (is.null(group_vars) || length(group_vars)) {
+    # Since accessing schema metadata does some work, only overwrite if needed
+    new_atts <- old_atts <- tab$metadata$r$attributes
+    new_atts[[".group_vars"]] <- group_vars
+    new_atts[[".group_by_drop"]] <- .drop
+    if (!identical(new_atts, old_atts)) {
+      tab$metadata$r$attributes <- new_atts
+    }
+  }
+  tab
 }
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index c7b58bc027..5b620c6cc4 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -254,7 +254,11 @@ tail.arrow_dplyr_query <- function(x, n = 6L, ...) {
 
   if (!missing(i)) {
     out <- take_dataset_rows(x, i)
-    x <- restore_dplyr_features(out, x)
+    x <- set_group_attributes(
+      out,
+      dplyr::group_vars(x),
+      dplyr::group_by_drop_default(x)
+    )
   }
   x
 }
diff --git a/r/R/metadata.R b/r/R/metadata.R
index 74080e8f48..747f08069e 100644
--- a/r/R/metadata.R
+++ b/r/R/metadata.R
@@ -77,7 +77,7 @@ apply_arrow_r_metadata <- function(x, r_metadata) {
         trace <- trace_back()
         # TODO: remove `trace$calls %||% trace$call` once rlang > 0.4.11 is released
         in_dplyr_collect <- any(map_lgl(trace$calls %||% trace$call, function(x) {
-          grepl("collect.arrow_dplyr_query", x, fixed = TRUE)[[1]]
+          grepl("collect\\.([aA]rrow|Dataset)", x)[[1]]
         }))
         if (in_dplyr_collect) {
           warning(
@@ -103,8 +103,13 @@ apply_arrow_r_metadata <- function(x, r_metadata) {
           attr(x, "row.names") <- NULL
         }
         if (!is.null(attr(x, ".group_vars")) && requireNamespace("dplyr", quietly = TRUE)) {
-          x <- dplyr::group_by(x, !!!syms(attr(x, ".group_vars")))
+          x <- dplyr::group_by(
+            x,
+            !!!syms(attr(x, ".group_vars")),
+            .drop = attr(x, ".group_by_drop") %||% TRUE
+          )
           attr(x, ".group_vars") <- NULL
+          attr(x, ".group_by_drop") <- NULL
         }
       }
     },
@@ -144,9 +149,11 @@ arrow_attributes <- function(x, only_top_level = FALSE) {
     # uses, which may be large
     if (requireNamespace("dplyr", quietly = TRUE)) {
       gv <- dplyr::group_vars(x)
+      drop <- dplyr::group_by_drop_default(x)
       x <- dplyr::ungroup(x)
-      # ungroup() first, then set attribute, bc ungroup() would erase it
+      # ungroup() first, then set attributes, bc ungroup() would erase it
       att[[".group_vars"]] <- gv
+      att[[".group_by_drop"]] <- drop
       removed_attributes <- c(removed_attributes, "groups", "class")
     }
   }
diff --git a/r/R/table.R b/r/R/table.R
index 2007a3887b..3318060ce0 100644
--- a/r/R/table.R
+++ b/r/R/table.R
@@ -330,10 +330,22 @@ as_arrow_table.RecordBatchReader <- function(x, ...) {
   x$read_table()
 }
 
+#' @rdname as_arrow_table
+#' @export
+as_arrow_table.Dataset <- function(x, ...) {
+  Scanner$create(x)$ToTable() # scan the Dataset into a single Table
+}
+
 #' @rdname as_arrow_table
 #' @export
 as_arrow_table.arrow_dplyr_query <- function(x, ...) {
-  as_arrow_table(as_record_batch_reader(x))
+  out <- as_arrow_table(as_record_batch_reader(x))
+  # Carry the query's group_by info (vars and .drop) into the Table metadata.
+  set_group_attributes(
+    out,
+    dplyr::group_vars(x),
+    dplyr::group_by_drop_default(x)
+  )
 }
 
 #' @rdname as_arrow_table
diff --git a/r/man/as_arrow_table.Rd b/r/man/as_arrow_table.Rd
index 22d4ea1c19..1ca11be460 100644
--- a/r/man/as_arrow_table.Rd
+++ b/r/man/as_arrow_table.Rd
@@ -9,6 +9,7 @@
 \alias{as_arrow_table.RecordBatchReader}
 \alias{as_arrow_table.arrow_dplyr_query}
 \alias{as_arrow_table.Schema}
+\alias{as_arrow_table.Dataset}
 \title{Convert an object to an Arrow Table}
 \usage{
 as_arrow_table(x, ..., schema = NULL)
@@ -26,6 +27,8 @@ as_arrow_table(x, ..., schema = NULL)
 \method{as_arrow_table}{arrow_dplyr_query}(x, ...)
 
 \method{as_arrow_table}{Schema}(x, ...)
+
+\method{as_arrow_table}{Dataset}(x, ...)
 }
 \arguments{
 \item{x}{An object to convert to an Arrow Table}
diff --git a/r/man/open_dataset.Rd b/r/man/open_dataset.Rd
index 795c6f3d35..9e1473803a 100644
--- a/r/man/open_dataset.Rd
+++ b/r/man/open_dataset.Rd
@@ -167,7 +167,7 @@ tf <- tempfile()
 dir.create(tf)
 on.exit(unlink(tf))
 
-write_dataset(mtcars, tf, partitioning="cyl")
+write_dataset(mtcars, tf, partitioning = "cyl")
 
 # You can specify a directory containing the files for your dataset and
 # open_dataset will scan all files in your directory.
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 72c1dd7ca9..a46eb8b6a9 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -324,6 +324,11 @@ test_that("Can set schema on dataset", {
   expect_equal(ds$schema, expected_schema)
 })
 
+test_that("as.data.frame.Dataset", {
+  ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
+  expect_identical(dim(as.data.frame(ds)), c(20L, 7L)) # same as dim(ds)
+})
+
 test_that("dim method returns the correct number of rows and columns", {
   ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
   expect_identical(dim(ds), c(20L, 7L))
diff --git a/r/tests/testthat/test-udf.R b/r/tests/testthat/test-udf.R
index 882eea03e5..7836255e86 100644
--- a/r/tests/testthat/test-udf.R
+++ b/r/tests/testthat/test-udf.R
@@ -298,6 +298,7 @@ test_that("nested exec plans can contain user-defined functions", {
 })
 
 test_that("head() on exec plan containing user-defined functions", {
+  skip("ARROW-18101")
   skip_if_not_available("dataset")
   skip_if_not(CanRunWithCapturedR())