You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by np...@apache.org on 2023/01/18 17:38:14 UTC
[arrow] branch master updated: GH-18818: [R] Create a field ref to a field in a struct (#19706)
This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 1d9366f19e GH-18818: [R] Create a field ref to a field in a struct (#19706)
1d9366f19e is described below
commit 1d9366f19e4b9846b33cc0c7bd7941cb5f482d74
Author: Neal Richardson <ne...@gmail.com>
AuthorDate: Wed Jan 18 12:38:06 2023 -0500
GH-18818: [R] Create a field ref to a field in a struct (#19706)
This PR implements `$.Expression` and `[[.Expression` methods such that, if the Expression is a FieldRef, they return a nested FieldRef. This required revising some assumptions in a few places, particularly that an Expression that is a FieldRef has a `name`, and that every FieldRef corresponds to a Field in a Schema. When the Expression is not a FieldRef, it creates an Expression call to `struct_field` to extract the field, if and only if the Expression has a knowable `type`, the [...]
Things not done because they weren't needed to get this working:
* Let `Expression$field_ref()` take a vector to construct a nested ref
* A method to return the vector of nested components of a field ref in R
Next steps for future PRs:
* Wrap this in a [tidyr::unpack()](https://tidyr.tidyverse.org/reference/pack.html) method (unfortunately, `unpack()` is not a generic)
* https://github.com/apache/arrow/issues/33756
* https://github.com/apache/arrow/issues/33757
* https://github.com/apache/arrow/issues/33760
* Closes: #18818
Authored-by: Neal Richardson <ne...@gmail.com>
Signed-off-by: Neal Richardson <ne...@gmail.com>
---
r/NAMESPACE | 3 ++
r/R/arrow-object.R | 2 +-
r/R/arrowExports.R | 9 ++++-
r/R/expression.R | 55 +++++++++++++++++++++++++++++
r/R/type.R | 3 ++
r/src/arrowExports.cpp | 19 ++++++++++
r/src/compute.cpp | 14 ++++++++
r/src/expression.cpp | 40 +++++++++++++++++++--
r/tests/testthat/test-dplyr-query.R | 70 +++++++++++++++++++++++++++++++++++++
r/tests/testthat/test-expression.R | 26 ++++++++++++++
10 files changed, 237 insertions(+), 4 deletions(-)
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 3df107a2d8..3ab828a958 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -2,6 +2,7 @@
S3method("!=",ArrowObject)
S3method("$",ArrowTabular)
+S3method("$",Expression)
S3method("$",Schema)
S3method("$",StructArray)
S3method("$",SubTreeFileSystem)
@@ -14,6 +15,7 @@ S3method("[",Dataset)
S3method("[",Schema)
S3method("[",arrow_dplyr_query)
S3method("[[",ArrowTabular)
+S3method("[[",Expression)
S3method("[[",Schema)
S3method("[[",StructArray)
S3method("[[<-",ArrowTabular)
@@ -137,6 +139,7 @@ S3method(names,Scanner)
S3method(names,ScannerBuilder)
S3method(names,Schema)
S3method(names,StructArray)
+S3method(names,StructType)
S3method(names,Table)
S3method(names,arrow_dplyr_query)
S3method(print,"arrow-enum")
diff --git a/r/R/arrow-object.R b/r/R/arrow-object.R
index 516f407aaf..5c2cf4691f 100644
--- a/r/R/arrow-object.R
+++ b/r/R/arrow-object.R
@@ -32,7 +32,7 @@ ArrowObject <- R6Class("ArrowObject",
assign(".:xp:.", xp, envir = self)
},
class_title = function() {
- if (!is.null(self$.class_title)) {
+ if (".class_title" %in% ls(self, all.names = TRUE)) { # not self$.class_title: `$.Expression` treats unknown names as nested field lookups
# Allow subclasses to override just printing the class name first
class_title <- self$.class_title()
} else {
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 38f1ecfb97..2eeca24dbd 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -1084,6 +1084,10 @@ compute___expr__call <- function(func_name, argument_list, options) {
.Call(`_arrow_compute___expr__call`, func_name, argument_list, options)
}
+compute___expr__is_field_ref <- function(x) {
+ .Call(`_arrow_compute___expr__is_field_ref`, x)
+}
+
field_names_in_expression <- function(x) {
.Call(`_arrow_field_names_in_expression`, x)
}
@@ -1096,6 +1100,10 @@ compute___expr__field_ref <- function(name) {
.Call(`_arrow_compute___expr__field_ref`, name)
}
+compute___expr__nested_field_ref <- function(x, name) {
+ .Call(`_arrow_compute___expr__nested_field_ref`, x, name)
+}
+
compute___expr__scalar <- function(x) {
.Call(`_arrow_compute___expr__scalar`, x)
}
@@ -2087,4 +2095,3 @@ SetIOThreadPoolCapacity <- function(threads) {
Array__infer_type <- function(x) {
.Call(`_arrow_Array__infer_type`, x)
}
-
diff --git a/r/R/expression.R b/r/R/expression.R
index a1163c12a8..8f84b4b31e 100644
--- a/r/R/expression.R
+++ b/r/R/expression.R
@@ -57,6 +57,9 @@ Expression <- R6Class("Expression",
assert_that(!is.null(schema))
compute___expr__type_id(self, schema)
},
+ is_field_ref = function() { # TRUE for plain and nested field refs (C++: field_ref() != nullptr)
+ compute___expr__is_field_ref(self)
+ },
cast = function(to_type, safe = TRUE, ...) {
opts <- cast_options(safe, ...)
opts$to_type <- as_type(to_type)
@@ -89,7 +92,59 @@ Expression$create <- function(function_name,
expr
}
+
+#' @export
+`[[.Expression` <- function(x, i, ...) get_nested_field(x, i) # no R6 member lookup, unlike `$`
+
+#' @export
+`$.Expression` <- function(x, name, ...) {
+ assert_that(is.string(name))
+ if (name %in% ls(x)) { # R6 fields/methods take precedence over struct field names
+ get(name, x)
+ } else {
+ get_nested_field(x, name)
+ }
+}
+
+get_nested_field <- function(expr, name) { # returns an Expression selecting field `name` of `expr`
+ if (expr$is_field_ref()) {
+ # Field ref: append `name` to build a nested field ref (no type check here)
+ # TODO(#33756): integer (positional) field refs are supported in C++
+ assert_that(is.string(name)) # NOTE(review): not checked in the else branch, e.g. x[[1]]
+ out <- compute___expr__nested_field_ref(expr, name)
+ } else {
+ # Not a field ref: use the struct_field kernel, which needs a struct type
+ expr_type <- expr$type() # errors if no schema set
+ if (inherits(expr_type, "StructType")) {
+ # Because we have the type, we can validate that the field exists
+ if (!(name %in% names(expr_type))) {
+ stop(
+ "field '", name, "' not found in ",
+ expr_type$ToString(),
+ call. = FALSE
+ )
+ }
+ out <- Expression$create(
+ "struct_field",
+ expr,
+ options = list(field_ref = Expression$field_ref(name))
+ )
+ } else {
+ # TODO(#33757): if expr is list type and name is integer or Expression,
+ # call list_element
+ stop(
+ "Cannot extract a field from an Expression of type ", expr_type$ToString(),
+ call. = FALSE
+ )
+ }
+ }
+ # Propagate the parent's schema so out$type() works without passing one
+ out$schema <- expr$schema
+ out
+}
+
Expression$field_ref <- function(name) {
+ # TODO(#33756): allow construction of field ref from integer
assert_that(is.string(name))
compute___expr__field_ref(name)
}
diff --git a/r/R/type.R b/r/R/type.R
index d1578dd822..bd69311b25 100644
--- a/r/R/type.R
+++ b/r/R/type.R
@@ -641,6 +641,9 @@ StructType$create <- function(...) struct__(.fields(list(...)))
#' @export
struct <- StructType$create
+#' @export
+names.StructType <- function(x) StructType__field_names(x) # used by get_nested_field() to validate names
+
ListType <- R6Class("ListType",
inherit = NestedType,
public = list(
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index b7bda1870f..e918390e26 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -2732,6 +2732,14 @@ BEGIN_CPP11
END_CPP11
}
// expression.cpp
+bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x);
+extern "C" SEXP _arrow_compute___expr__is_field_ref(SEXP x_sexp){
+BEGIN_CPP11
+ arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp);
+ return cpp11::as_sexp(compute___expr__is_field_ref(x));
+END_CPP11
+}
+// expression.cpp
std::vector<std::string> field_names_in_expression(const std::shared_ptr<compute::Expression>& x);
extern "C" SEXP _arrow_field_names_in_expression(SEXP x_sexp){
BEGIN_CPP11
@@ -2756,6 +2764,15 @@ BEGIN_CPP11
END_CPP11
}
// expression.cpp
+std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(const std::shared_ptr<compute::Expression>& x, std::string name);
+extern "C" SEXP _arrow_compute___expr__nested_field_ref(SEXP x_sexp, SEXP name_sexp){
+BEGIN_CPP11
+ arrow::r::Input<const std::shared_ptr<compute::Expression>&>::type x(x_sexp);
+ arrow::r::Input<std::string>::type name(name_sexp);
+ return cpp11::as_sexp(compute___expr__nested_field_ref(x, name));
+END_CPP11
+}
+// expression.cpp
std::shared_ptr<compute::Expression> compute___expr__scalar(const std::shared_ptr<arrow::Scalar>& x);
extern "C" SEXP _arrow_compute___expr__scalar(SEXP x_sexp){
BEGIN_CPP11
@@ -5569,9 +5586,11 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_MapType__keys_sorted", (DL_FUNC) &_arrow_MapType__keys_sorted, 1},
{ "_arrow_compute___expr__equals", (DL_FUNC) &_arrow_compute___expr__equals, 2},
{ "_arrow_compute___expr__call", (DL_FUNC) &_arrow_compute___expr__call, 3},
+ { "_arrow_compute___expr__is_field_ref", (DL_FUNC) &_arrow_compute___expr__is_field_ref, 1},
{ "_arrow_field_names_in_expression", (DL_FUNC) &_arrow_field_names_in_expression, 1},
{ "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1},
{ "_arrow_compute___expr__field_ref", (DL_FUNC) &_arrow_compute___expr__field_ref, 1},
+ { "_arrow_compute___expr__nested_field_ref", (DL_FUNC) &_arrow_compute___expr__nested_field_ref, 2},
{ "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1},
{ "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1},
{ "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2},
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index b4b4c5fdc8..578ce74d05 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -564,6 +564,20 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}
+ if (func_name == "struct_field") {
+ using Options = arrow::compute::StructFieldOptions;
+ if (!Rf_isNull(options["indices"])) { // positional indices win when supplied
+ return std::make_shared<Options>(
+ cpp11::as_cpp<std::vector<int>>(options["indices"]));
+ } else {
+ // Otherwise options$field_ref must be a FieldRef Expression (null field_ref() is not checked)
+ return std::make_shared<Options>(
+ *cpp11::as_cpp<std::shared_ptr<arrow::compute::Expression>>(
+ options["field_ref"])
+ ->field_ref());
+ }
+ }
+
return nullptr;
}
diff --git a/r/src/expression.cpp b/r/src/expression.cpp
index a845137e09..d7a511e760 100644
--- a/r/src/expression.cpp
+++ b/r/src/expression.cpp
@@ -46,13 +46,26 @@ std::shared_ptr<compute::Expression> compute___expr__call(std::string func_name,
compute::call(std::move(func_name), std::move(arguments), std::move(options_ptr)));
}
+// [[arrow::export]]
+bool compute___expr__is_field_ref(const std::shared_ptr<compute::Expression>& x) {
+ return x->field_ref() != nullptr;
+}

// [[arrow::export]]
std::vector<std::string> field_names_in_expression(
const std::shared_ptr<compute::Expression>& x) {
std::vector<std::string> out;
+ std::vector<arrow::FieldRef> nested; // scratch copy of a nested ref's components
+
auto field_refs = FieldsInExpression(*x);
for (auto f : field_refs) {
- out.push_back(*f.name());
+ if (f.IsNested()) {
+ // Keep only the top-level field name (the column the nested ref starts from).
+ nested = *f.nested_refs();
+ out.push_back(*nested[0].name()); // NOTE(review): assumes name-based refs; a positional ref's name() would be null
+ } else {
+ out.push_back(*f.name());
+ }
}
return out;
}
@@ -61,7 +74,11 @@ std::vector<std::string> field_names_in_expression(
std::string compute___expr__get_field_ref_name(
const std::shared_ptr<compute::Expression>& x) {
if (auto field_ref = x->field_ref()) {
- return *field_ref->name();
+ // Exclude nested field refs because we only use this to determine if we have simple
+ // field refs
+ if (!field_ref->IsNested()) {
+ return *field_ref->name();
+ }
}
return "";
}
@@ -71,6 +88,25 @@ std::shared_ptr<compute::Expression> compute___expr__field_ref(std::string name)
return std::make_shared<compute::Expression>(compute::field_ref(std::move(name)));
}
+// [[arrow::export]]
+std::shared_ptr<compute::Expression> compute___expr__nested_field_ref(
+ const std::shared_ptr<compute::Expression>& x, std::string name) {
+ if (auto field_ref = x->field_ref()) {
+ std::vector<arrow::FieldRef> ref_vec;
+ if (field_ref->IsNested()) {
+ ref_vec = *field_ref->nested_refs();
+ } else {
+ // Not nested yet: x is a single top-level ref
+ ref_vec.push_back(*field_ref);
+ }
+ // Append the new child name; field_ref() of the whole path builds the nested ref
+ ref_vec.push_back(arrow::FieldRef(std::move(name)));
+ return std::make_shared<compute::Expression>(compute::field_ref(std::move(ref_vec)));
+ } else {
+ cpp11::stop("'x' must be a FieldRef Expression");
+ }
+}
+
// [[arrow::export]]
std::shared_ptr<compute::Expression> compute___expr__scalar(
const std::shared_ptr<arrow::Scalar>& x) {
diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R
index ee11cd6678..a91c0b6ccb 100644
--- a/r/tests/testthat/test-dplyr-query.R
+++ b/r/tests/testthat/test-dplyr-query.R
@@ -714,3 +714,73 @@ test_that("Scalars in expressions match the type of the field, if possible", {
collect()
expect_equal(result$tpc_h_1, result$as_dbl)
})
+
+test_that("Can use nested field refs", {
+ nested_data <- tibble(int = 1:5, df_col = tibble(a = 6:10, b = 11:15)) # df_col becomes a struct column
+
+ compare_dplyr_binding(
+ .input %>%
+ mutate(
+ nested = df_col$a,
+ times2 = df_col$a * 2
+ ) %>%
+ filter(nested > 7) %>%
+ collect(),
+ nested_data
+ )
+
+ compare_dplyr_binding(
+ .input %>%
+ mutate(
+ nested = df_col$a,
+ times2 = df_col$a * 2
+ ) %>%
+ filter(nested > 7) %>%
+ summarize(sum(times2)) %>%
+ collect(),
+ nested_data
+ )
+
+ # Now with Dataset: make sure column pushdown in ScanNode works
+ expect_equal(
+ nested_data %>%
+ InMemoryDataset$create() %>%
+ mutate(
+ nested = df_col$a,
+ times2 = df_col$a * 2
+ ) %>%
+ filter(nested > 7) %>%
+ collect(),
+ nested_data %>%
+ mutate(
+ nested = df_col$a,
+ times2 = df_col$a * 2
+ ) %>%
+ filter(nested > 7)
+ )
+})
+
+test_that("Use struct_field for $ on non-field-ref", {
+ compare_dplyr_binding(
+ .input %>%
+ mutate(
+ df_col = tibble(i = int, d = dbl) # computed struct, so df_col$i is not a field ref
+ ) %>%
+ transmute(
+ int2 = df_col$i,
+ dbl2 = df_col$d
+ ) %>%
+ collect(),
+ example_data
+ )
+})
+
+test_that("nested field ref error handling", {
+ expect_error(
+ example_data %>%
+ arrow_table() %>%
+ mutate(x = int$nested) %>% # field-ref path does no type check; error surfaces when the query runs
+ compute(),
+ "No match"
+ )
+})
diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R
index 2b6039b04c..ccb09b9eb0 100644
--- a/r/tests/testthat/test-expression.R
+++ b/r/tests/testthat/test-expression.R
@@ -76,6 +76,15 @@ test_that("Field reference expression schemas and types", {
expect_equal(x$type(), int32())
})
+test_that("Nested field refs", {
+ x <- Expression$field_ref("x")
+ nested <- x$y
+ expect_r6_class(nested, "Expression")
+ expect_r6_class(x[["y"]], "Expression")
+ expect_r6_class(nested$z, "Expression")
+ expect_error(Expression$scalar(42L)$y, "Cannot extract a field from an Expression of type int32") # non-struct, non-field-ref
+})
+
test_that("Scalar expression schemas and types", {
# type() works on scalars without setting the schema
expect_equal(
@@ -127,3 +136,20 @@ test_that("Expression schemas and types", {
int32()
)
})
+
+test_that("Nested field ref types", {
+ nested <- Expression$field_ref("x")$y
+ schm <- schema(x = struct(y = int32(), z = double()))
+ expect_equal(nested$type(schm), int32())
+ # implicit casting and schema propagation
+ x <- Expression$field_ref("x")
+ x$schema <- schm
+ expect_equal((x$y * 2)$type(), int32())
+})
+
+test_that("Nested field from a non-field-ref (struct_field kernel)", {
+ x <- Expression$scalar(data.frame(a = 1, b = "two")) # struct scalar, not a field ref
+ expect_true(inherits(x$a, "Expression"))
+ expect_equal(x$a$type(), float64())
+ expect_error(x$c, "field 'c' not found in struct<a: double, b: string>")
+})