You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by np...@apache.org on 2023/01/18 17:38:14 UTC

[arrow] branch master updated: GH-18818: [R] Create a field ref to a field in a struct (#19706)

This is an automated email from the ASF dual-hosted git repository.

npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d9366f19e GH-18818: [R] Create a field ref to a field in a struct (#19706)
1d9366f19e is described below

commit 1d9366f19e4b9846b33cc0c7bd7941cb5f482d74
Author: Neal Richardson <ne...@gmail.com>
AuthorDate: Wed Jan 18 12:38:06 2023 -0500

    GH-18818: [R] Create a field ref to a field in a struct (#19706)
    
    This PR implements `$.Expression` and `[[.Expression` methods, such that if the Expression is a FieldRef, it returns a nested FieldRef. This required revising some assumptions in a few places, particularly that if an Expression is a FieldRef, it has a `name`, and that all FieldRefs correspond to a Field in a Schema. In the case where the Expression is not a FieldRef, it will create an Expression call to `struct_field` to extract the field, iff the Expression has a knowable `type`, the type is a StructType, and the requested field exists in that struct; otherwise it raises an error.
    
    Things not done because they weren't needed to get this working:
    
      * Allowing `Expression$field_ref()` to take a vector to construct a nested ref
      * A method to return the vector of nested components of a field ref in R
    
    Next steps for future PRs:
    
    * Wrap this in a [tidyr::unpack()](https://tidyr.tidyverse.org/reference/pack.html) method (unfortunately, unpack() is not a generic, so this is not straightforward)
    * https://github.com/apache/arrow/issues/33756
    * https://github.com/apache/arrow/issues/33757
    * https://github.com/apache/arrow/issues/33760
    
    * Closes: #18818
    
    Authored-by: Neal Richardson <ne...@gmail.com>
    Signed-off-by: Neal Richardson <ne...@gmail.com>
---
 r/NAMESPACE                         |  3 ++
 r/R/arrow-object.R                  |  2 +-
 r/R/arrowExports.R                  |  9 ++++-
 r/R/expression.R                    | 55 +++++++++++++++++++++++++++++
 r/R/type.R                          |  3 ++
 r/src/arrowExports.cpp              | 19 ++++++++++
 r/src/compute.cpp                   | 14 ++++++++
 r/src/expression.cpp                | 40 +++++++++++++++++++--
 r/tests/testthat/test-dplyr-query.R | 70 +++++++++++++++++++++++++++++++++++++
 r/tests/testthat/test-expression.R  | 26 ++++++++++++++
 10 files changed, 237 insertions(+), 4 deletions(-)

diff --git a/r/NAMESPACE b/r/NAMESPACE
index 3df107a2d8..3ab828a958 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -2,6 +2,7 @@
 
 S3method("!=",ArrowObject)
 S3method("$",ArrowTabular)
+S3method("$",Expression)
 S3method("$",Schema)
 S3method("$",StructArray)
 S3method("$",SubTreeFileSystem)
@@ -14,6 +15,7 @@ S3method("[",Dataset)
 S3method("[",Schema)
 S3method("[",arrow_dplyr_query)
 S3method("[[",ArrowTabular)
+S3method("[[",Expression)
 S3method("[[",Schema)
 S3method("[[",StructArray)
 S3method("[[<-",ArrowTabular)
@@ -137,6 +139,7 @@ S3method(names,Scanner)
 S3method(names,ScannerBuilder)
 S3method(names,Schema)
 S3method(names,StructArray)
+S3method(names,StructType)
 S3method(names,Table)
 S3method(names,arrow_dplyr_query)
 S3method(print,"arrow-enum")
diff --git a/r/R/arrow-object.R b/r/R/arrow-object.R
index 516f407aaf..5c2cf4691f 100644
--- a/r/R/arrow-object.R
+++ b/r/R/arrow-object.R
@@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject",
       assign(".:xp:.", xp, envir = self)
     },
     class_title = function() {
-      if (!is.null(self$.class_title)) {
+      if (".class_title" %in% ls(self, all.names = TRUE)) {
         # Allow subclasses to override just printing the class name first
         class_title <- self$.class_title()
       } else {
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 38f1ecfb97..2eeca24dbd 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -1084,6 +1084,10 @@ compute___expr__call <- function(func_name, argument_list, options) {
   .Call(`_arrow_compute___expr__call`, func_name, argument_list, options)
 }
 
+compute___expr__is_field_ref <- function(x) {
+  .Call(`_arrow_compute___expr__is_field_ref`, x)
+}
+
 field_names_in_expression <- function(x) {
   .Call(`_arrow_field_names_in_expression`, x)
 }
@@ -1096,6 +1100,10 @@ compute___expr__field_ref <- function(name) {
   .Call(`_arrow_compute___expr__field_ref`, name)
 }
 
+compute___expr__nested_field_ref <- function(x, name) {
+  .Call(`_arrow_compute___expr__nested_field_ref`, x, name)
+}
+
 compute___expr__scalar <- function(x) {
   .Call(`_arrow_compute___expr__scalar`, x)
 }
@@ -2087,4 +2095,3 @@ SetIOThreadPoolCapacity <- function(threads) {
 Array__infer_type <- function(x) {
   .Call(`_arrow_Array__infer_type`, x)
 }
-
diff --git a/r/R/expression.R b/r/R/expression.R
index a1163c12a8..8f84b4b31e 100644
--- a/r/R/expression.R
+++ b/r/R/expression.R
@@ -57,6 +57,9 @@ Expression <- R6Class("Expression",
       assert_that(!is.null(schema))
       compute___expr__type_id(self, schema)
     },
+    is_field_ref = function() {
+      compute___expr__is_field_ref(self)
+    },
     cast = function(to_type, safe = TRUE, ...) {
       opts <- cast_options(safe, ...)
       opts$to_type <- as_type(to_type)
@@ -89,7 +92,59 @@ Expression$create <- function(function_name,
   expr
 }
 
+
+#' @export
+`[[.Expression` <- function(x, i, ...) get_nested_field(x, i)
+
+#' @export
+`$.Expression` <- function(x, name, ...) {
+  assert_that(is.string(name))
+  if (name %in% ls(x)) {
+    get(name, x)
+  } else {
+    get_nested_field(x, name)
+  }
+}
+
+get_nested_field <- function(expr, name) {
+  if (expr$is_field_ref()) {
+    # Make a nested field ref
+    # TODO(#33756): integer (positional) field refs are supported in C++
+    assert_that(is.string(name))
+    out <- compute___expr__nested_field_ref(expr, name)
+  } else {
+    # Use the struct_field kernel if expr is a struct:
+    expr_type <- expr$type() # errors if no schema set
+    if (inherits(expr_type, "StructType")) {
+      # Because we have the type, we can validate that the field exists
+      if (!(name %in% names(expr_type))) {
+        stop(
+          "field '", name, "' not found in ",
+          expr_type$ToString(),
+          call. = FALSE
+        )
+      }
+      out <- Expression$create(
+        "struct_field",
+        expr,
+        options = list(field_ref = Expression$field_ref(name))
+      )
+    } else {
+      # TODO(#33757): if expr is list type and name is integer or Expression,
+      # call list_element
+      stop(
+        "Cannot extract a field from an Expression of type ", expr_type$ToString(),
+        call. = FALSE
+      )
+    }
+  }
+  # Schema bookkeeping
+  out$schema <- expr$schema
+  out
+}
+
 Expression$field_ref <- function(name) {
+  # TODO(#33756): allow construction of field ref from integer
   assert_that(is.string(name))
   compute___expr__field_ref(name)
 }
diff --git a/r/R/type.R b/r/R/type.R
index d1578dd822..bd69311b25 100644
--- a/r/R/type.R
+++ b/r/R/type.R
@@ -641,6 +641,9 @@ StructType$create <- function(...) struct__(.fields(list(...)))
 #' @export
 struct <- StructType$create
 
+#' @export
+names.StructType <- function(x) StructType__field_names(x)
+
 ListType <- R6Class("ListType",
   inherit = NestedType,
   public = list(
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index b7bda1870f..e918390e26 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -2732,6 +2732,14 @@ BEGIN_CPP11
 END_CPP11
 }
 // expression.cpp
+bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x);
+extern "C" SEXP _arrow_compute___expr__is_field_ref(SEXP x_sexp){
+BEGIN_CPP11
+	arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp);
+	return cpp11::as_sexp(compute___expr__is_field_ref(x));
+END_CPP11
+}
+// expression.cpp
 std::vector<std::string> field_names_in_expression(const std::shared_ptr<compute::Expression>& x);
 extern "C" SEXP _arrow_field_names_in_expression(SEXP x_sexp){
 BEGIN_CPP11
@@ -2756,6 +2764,15 @@ BEGIN_CPP11
 END_CPP11
 }
 // expression.cpp
+std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(const std::shared_ptr<compute::Expression>& x, std::string name);
+extern "C" SEXP _arrow_compute___expr__nested_field_ref(SEXP x_sexp, SEXP name_sexp){
+BEGIN_CPP11
+	arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp);
+	arrow::r::Input<std::string>::type name(name_sexp);
+	return cpp11::as_sexp(compute___expr__nested_field_ref(x, name));
+END_CPP11
+}
+// expression.cpp
 std::shared_ptr<compute::Expression> compute___expr__scalar(const std::shared_ptr<arrow::Scalar>& x);
 extern "C" SEXP _arrow_compute___expr__scalar(SEXP x_sexp){
 BEGIN_CPP11
@@ -5569,9 +5586,11 @@ static const R_CallMethodDef CallEntries[] = {
 		{ "_arrow_MapType__keys_sorted", (DL_FUNC) &_arrow_MapType__keys_sorted, 1}, 
 		{ "_arrow_compute___expr__equals", (DL_FUNC) &_arrow_compute___expr__equals, 2}, 
 		{ "_arrow_compute___expr__call", (DL_FUNC) &_arrow_compute___expr__call, 3}, 
+		{ "_arrow_compute___expr__is_field_ref", (DL_FUNC) &_arrow_compute___expr__is_field_ref, 1}, 
 		{ "_arrow_field_names_in_expression", (DL_FUNC) &_arrow_field_names_in_expression, 1}, 
 		{ "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, 
 		{ "_arrow_compute___expr__field_ref", (DL_FUNC) &_arrow_compute___expr__field_ref, 1}, 
+		{ "_arrow_compute___expr__nested_field_ref", (DL_FUNC) &_arrow_compute___expr__nested_field_ref, 2}, 
 		{ "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1}, 
 		{ "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, 
 		{ "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, 
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index b4b4c5fdc8..578ce74d05 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -564,6 +564,20 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
     return out;
   }
 
+  if (func_name == "struct_field") {
+    using Options = arrow::compute::StructFieldOptions;
+    if (!Rf_isNull(options["indices"])) {
+      return std::make_shared<Options>(
+          cpp11::as_cpp<std::vector<int>>(options["indices"]));
+    } else {
+      // field_ref
+      return std::make_shared<Options>(
+          *cpp11::as_cpp<std::shared_ptr<arrow::compute::Expression>>(
+               options["field_ref"])
+               ->field_ref());
+    }
+  }
+
   return nullptr;
 }
 
diff --git a/r/src/expression.cpp b/r/src/expression.cpp
index a845137e09..d7a511e760 100644
--- a/r/src/expression.cpp
+++ b/r/src/expression.cpp
@@ -46,13 +46,26 @@ std::shared_ptr<compute::Expression> compute___expr__call(std::string func_name,
       compute::call(std::move(func_name), std::move(arguments), std::move(options_ptr)));
 }
 
+// [[arrow::export]]
+bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x) {
+  return x->field_ref() != nullptr;
+}
+
 // [[arrow::export]]
 std::vector<std::string> field_names_in_expression(
     const std::shared_ptr<compute::Expression>& x) {
   std::vector<std::string> out;
+  std::vector<arrow::FieldRef> nested;
+
   auto field_refs = FieldsInExpression(*x);
   for (auto f : field_refs) {
-    out.push_back(*f.name());
+    if (f.IsNested()) {
+      // We keep the top-level field name.
+      nested = *f.nested_refs();
+      out.push_back(*nested[0].name());
+    } else {
+      out.push_back(*f.name());
+    }
   }
   return out;
 }
@@ -61,7 +74,11 @@ std::vector<std::string> field_names_in_expression(
 std::string compute___expr__get_field_ref_name(
     const std::shared_ptr<compute::Expression>& x) {
   if (auto field_ref = x->field_ref()) {
-    return *field_ref->name();
+    // Exclude nested field refs because we only use this to determine if we have simple
+    // field refs
+    if (!field_ref->IsNested()) {
+      return *field_ref->name();
+    }
   }
   return "";
 }
@@ -71,6 +88,25 @@ std::shared_ptr<compute::Expression> compute___expr__field_ref(std::string name)
   return std::make_shared<compute::Expression>(compute::field_ref(std::move(name)));
 }
 
+// [[arrow::export]]
+std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(
+    const std::shared_ptr<compute::Expression>& x, std::string name) {
+  if (auto field_ref = x->field_ref()) {
+    std::vector<arrow::FieldRef> ref_vec;
+    if (field_ref->IsNested()) {
+      ref_vec = *field_ref->nested_refs();
+    } else {
+      // There's just one
+      ref_vec.push_back(*field_ref);
+    }
+    // Add the new ref
+    ref_vec.push_back(arrow::FieldRef(std::move(name)));
+    return std::make_shared<compute::Expression>(compute::field_ref(std::move(ref_vec)));
+  } else {
+    cpp11::stop("'x' must be a FieldRef Expression");
+  }
+}
+
 // [[arrow::export]]
 std::shared_ptr<compute::Expression> compute___expr__scalar(
     const std::shared_ptr<arrow::Scalar>& x) {
diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R
index ee11cd6678..a91c0b6ccb 100644
--- a/r/tests/testthat/test-dplyr-query.R
+++ b/r/tests/testthat/test-dplyr-query.R
@@ -714,3 +714,73 @@ test_that("Scalars in expressions match the type of the field, if possible", {
     collect()
   expect_equal(result$tpc_h_1, result$as_dbl)
 })
+
+test_that("Can use nested field refs", {
+  nested_data <- tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15))
+
+  compare_dplyr_binding(
+    .input %>%
+      mutate(
+        nested = df_col$a,
+        times2 = df_col$a * 2
+      ) %>%
+      filter(nested > 7) %>%
+      collect(),
+    nested_data
+  )
+
+  compare_dplyr_binding(
+    .input %>%
+      mutate(
+        nested = df_col$a,
+        times2 = df_col$a * 2
+      ) %>%
+      filter(nested > 7) %>%
+      summarize(sum(times2)) %>%
+      collect(),
+    nested_data
+  )
+
+  # Now with Dataset: make sure column pushdown in ScanNode works
+  expect_equal(
+    nested_data %>%
+      InMemoryDataset$create() %>%
+      mutate(
+        nested = df_col$a,
+        times2 = df_col$a * 2
+      ) %>%
+      filter(nested > 7) %>%
+      collect(),
+    nested_data %>%
+      mutate(
+        nested = df_col$a,
+        times2 = df_col$a * 2
+      ) %>%
+      filter(nested > 7)
+  )
+})
+
+test_that("Use struct_field for $ on non-field-ref", {
+  compare_dplyr_binding(
+    .input %>%
+      mutate(
+        df_col = tibble(i = int, d = dbl)
+      ) %>%
+      transmute(
+        int2 = df_col$i,
+        dbl2 = df_col$d
+      ) %>%
+      collect(),
+    example_data
+  )
+})
+
+test_that("nested field ref error handling", {
+  expect_error(
+    example_data %>%
+      arrow_table() %>%
+      mutate(x = int$nested) %>%
+      compute(),
+    "No match"
+  )
+})
diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R
index 2b6039b04c..ccb09b9eb0 100644
--- a/r/tests/testthat/test-expression.R
+++ b/r/tests/testthat/test-expression.R
@@ -76,6 +76,15 @@ test_that("Field reference expression schemas and types", {
   expect_equal(x$type(), int32())
 })
 
+test_that("Nested field refs", {
+  x <- Expression$field_ref("x")
+  nested <- x$y
+  expect_r6_class(nested, "Expression")
+  expect_r6_class(x[["y"]], "Expression")
+  expect_r6_class(nested$z, "Expression")
+  expect_error(Expression$scalar(42L)$y, "Cannot extract a field from an Expression of type int32")
+})
+
 test_that("Scalar expression schemas and types", {
   # type() works on scalars without setting the schema
   expect_equal(
@@ -127,3 +136,20 @@ test_that("Expression schemas and types", {
     int32()
   )
 })
+
+test_that("Nested field ref types", {
+  nested <- Expression$field_ref("x")$y
+  schm <- schema(x = struct(y = int32(), z = double()))
+  expect_equal(nested$type(schm), int32())
+  # implicit casting and schema propagation
+  x <- Expression$field_ref("x")
+  x$schema <- schm
+  expect_equal((x$y * 2)$type(), int32())
+})
+
+test_that("Nested field from a non-field-ref (struct_field kernel)", {
+  x <- Expression$scalar(data.frame(a = 1, b = "two"))
+  expect_true(inherits(x$a, "Expression"))
+  expect_equal(x$a$type(), float64())
+  expect_error(x$c, "field 'c' not found in struct<a: double, b: string>")
+})