You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ro...@apache.org on 2019/06/26 12:21:14 UTC
[arrow] branch master updated: ARROW-3811 [R]: Support inferring
data.frame column as StructArray in array constructors
This is an automated email from the ASF dual-hosted git repository.
romainfrancois pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new c631c9b ARROW-3811 [R]: Support inferring data.frame column as StructArray in array constructors
c631c9b is described below
commit c631c9b07c1f06362ab5c292db51f66a2095d362
Author: Romain Francois <ro...@rstudio.com>
AuthorDate: Wed Jun 26 14:21:04 2019 +0200
ARROW-3811 [R]: Support inferring data.frame column as StructArray in array constructors
This makes `array()` and `chunked_array()` handle data frame columns (struct arrays in Arrow terminology); record batches and tables handle them too.
``` r
library(arrow, warn.conflicts = FALSE)
library(tibble)
df <- tibble(x = 1:10, y = letters[1:10])
a <- array(df)
a$type
#> arrow::StructType
#> struct<x: int32, y: string>
a$as_vector()
#> x y
#> 1 1 a
#> 2 2 b
#> 3 3 c
#> 4 4 d
#> 5 5 e
#> 6 6 f
#> 7 7 g
#> 8 8 h
#> 9 9 i
#> 10 10 j
batch <- record_batch(a = rnorm(10), y = df)
batch$schema
#> arrow::Schema
#> a: double
#> y: struct<x: int32, y: string>
as_tibble(batch)
#> # A tibble: 10 x 2
#> a y$x $y
#> <dbl> <int> <chr>
#> 1 -0.747 1 a
#> 2 0.285 2 b
#> 3 0.159 3 c
#> 4 0.707 4 d
#> 5 0.366 5 e
#> 6 0.237 6 f
#> 7 -0.730 7 g
#> 8 -0.880 8 h
#> 9 -0.833 9 i
#> 10 1.21 10 j
tab <- table(a = rnorm(10), y = df)
tab$schema
#> arrow::Schema
#> a: double
#> y: struct<x: int32, y: string>
as_tibble(tab)
#> # A tibble: 10 x 2
#> a y$x $y
#> <dbl> <int> <chr>
#> 1 0.828 1 a
#> 2 1.57 2 b
#> 3 -1.54 3 c
#> 4 0.723 4 d
#> 5 1.07 5 e
#> 6 1.08 6 f
#> 7 -0.338 7 g
#> 8 -2.26 8 h
#> 9 0.198 9 i
#> 10 -1.58 10 j
```
<sup>Created on 2019-06-25 by the [reprex package](https://reprex.tidyverse.org) (v0.3.0.9000)</sup>
Author: Romain Francois <ro...@rstudio.com>
Closes #4690 from romainfrancois/ARROW-3811/struct_column_inference and squashes the following commits:
b24645a2 <Romain Francois> + test using data frame column in record batch
773ffdee <Romain Francois> implement CheckCompatibleStruct()
6e54b209 <Romain Francois> array() handles data frame columns -> struct arrays
2d3bdd9a <Romain Francois> type() can infer struct column types from data frames
---
r/src/array_from_vector.cpp | 73 ++++++++++++++++++++++++++++++++----
r/tests/testthat/test-Array.R | 24 ++++++++++++
r/tests/testthat/test-RecordBatch.R | 20 ++++++++++
r/tests/testthat/test-chunkedarray.R | 7 ++++
r/tests/testthat/test-type.R | 5 +++
5 files changed, 122 insertions(+), 7 deletions(-)
diff --git a/r/src/array_from_vector.cpp b/r/src/array_from_vector.cpp
index 0d4e318..ad66bea 100644
--- a/r/src/array_from_vector.cpp
+++ b/r/src/array_from_vector.cpp
@@ -163,6 +163,15 @@ std::shared_ptr<Array> MakeFactorArray(Rcpp::IntegerVector_ factor,
}
}
+// Build a StructArray from the columns of the data frame `df`.
+//
+// Column `i` of `df` is converted with Array__from_vector() using the type of the
+// i-th child field of `type`. The resulting struct array has the length of its
+// first child. When `type` has no children (zero-column data frame) the result
+// is an empty struct array rather than an out-of-bounds access on `children`.
+std::shared_ptr<Array> MakeStructArray(SEXP df, const std::shared_ptr<DataType>& type) {
+  int n = type->num_children();
+  std::vector<std::shared_ptr<Array>> children(n);
+  for (int i = 0; i < n; i++) {
+    children[i] = Array__from_vector(VECTOR_ELT(df, i), type->child(i)->type(), true);
+  }
+
+  // children[0] only exists when there is at least one field
+  int64_t rows = n ? children[0]->length() : 0;
+  return std::make_shared<StructArray>(type, rows, children);
+}
+
template <typename T>
int64_t time_cast(T value);
@@ -728,14 +737,12 @@ Status GetConverter(const std::shared_ptr<DataType>& type,
SIMPLE_CONVERTER_CASE(DATE32, Date32Converter);
SIMPLE_CONVERTER_CASE(DATE64, Date64Converter);
- // TODO: probably after we merge ARROW-3628
- // case Type::DECIMAL:
+ // TODO: probably after we merge ARROW-3628
+ // case Type::DECIMAL:
- case Type::DICTIONARY:
-
- TIME_CONVERTER_CASE(TIME32, Time32Type, Time32Converter);
- TIME_CONVERTER_CASE(TIME64, Time64Type, Time64Converter);
- TIME_CONVERTER_CASE(TIMESTAMP, TimestampType, TimestampConverter);
+ TIME_CONVERTER_CASE(TIME32, Time32Type, Time32Converter);
+ TIME_CONVERTER_CASE(TIME64, Time64Type, Time64Converter);
+ TIME_CONVERTER_CASE(TIMESTAMP, TimestampType, TimestampConverter);
default:
break;
@@ -801,6 +808,18 @@ std::shared_ptr<arrow::DataType> InferType(SEXP x) {
return int8();
case STRSXP:
return utf8();
+ case VECSXP:
+ if (Rf_inherits(x, "data.frame")) {
+ R_xlen_t n = XLENGTH(x);
+ SEXP names = Rf_getAttrib(x, R_NamesSymbol);
+ std::vector<std::shared_ptr<arrow::Field>> fields(n);
+ for (R_xlen_t i = 0; i < n; i++) {
+ fields[i] = std::make_shared<arrow::Field>(CHAR(STRING_ELT(names, i)),
+ InferType(VECTOR_ELT(x, i)));
+ }
+ return std::make_shared<StructType>(std::move(fields));
+ }
+ break;
default:
break;
}
@@ -920,6 +939,37 @@ bool CheckCompatibleFactor(SEXP obj, const std::shared_ptr<arrow::DataType>& typ
return dict_type->value_type() == utf8();
}
+// Verify that the R object `obj` can be converted to a struct array of `type`.
+//
+// An error Status is returned when `obj` is not a data.frame, when its number of
+// columns differs from the number of fields of `type`, or when a column name
+// differs from the name of the field at the same position.
+// Status::OK() is returned otherwise.
+arrow::Status CheckCompatibleStruct(SEXP obj,
+                                    const std::shared_ptr<arrow::DataType>& type) {
+  const bool is_data_frame = Rf_inherits(obj, "data.frame");
+  if (!is_data_frame) {
+    return Status::RError("Conversion to struct arrays requires a data.frame");
+  }
+
+  // one column is expected per field of the struct
+  int num_fields = type->num_children();
+  R_xlen_t num_columns = XLENGTH(obj);
+  if (num_columns != num_fields) {
+    return Status::RError("Number of fields in struct (", num_fields,
+                          ") incompatible with number of columns in the data frame (",
+                          num_columns, ")");
+  }
+
+  // Only the names are compared here, position by position.
+  //
+  // The column contents are not validated against the field types:
+  // Array__from_vector will error when they are not compatible.
+  SEXP column_names = Rf_getAttrib(obj, R_NamesSymbol);
+  for (int pos = 0; pos < num_fields; ++pos) {
+    const char* column_name = CHAR(STRING_ELT(column_names, pos));
+    const auto field = type->child(pos);
+    if (field->name() == column_name) {
+      continue;
+    }
+    return Status::RError("Field name in position ", pos, " (", field->name(),
+                          ") does not match the name of the column of the data frame (",
+                          column_name, ")");
+  }
+
+  return Status::OK();
+}
+
std::shared_ptr<arrow::Array> Array__from_vector(
SEXP x, const std::shared_ptr<arrow::DataType>& type, bool type_infered) {
// short circuit if `x` is already an Array
@@ -948,6 +998,15 @@ std::shared_ptr<arrow::Array> Array__from_vector(
Rcpp::stop("Object incompatible with dictionary type");
}
+ // struct types
+ if (type->id() == Type::STRUCT) {
+ if (!type_infered) {
+ STOP_IF_NOT_OK(arrow::r::CheckCompatibleStruct(x, type));
+ }
+
+ return arrow::r::MakeStructArray(x, type);
+ }
+
// general conversion with converter and builder
std::unique_ptr<arrow::r::VectorConverter> converter;
STOP_IF_NOT_OK(arrow::r::GetConverter(type, &converter));
diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R
index fd2fe51..36518cb 100644
--- a/r/tests/testthat/test-Array.R
+++ b/r/tests/testthat/test-Array.R
@@ -415,3 +415,27 @@ test_that("array() recognise arrow::Array (ARROW-3815)", {
a <- array(1:10)
expect_equal(a, array(a))
})
+
+test_that("array() handles data frame -> struct arrays (ARROW-3811)", {
+  # integer, double and character columns
+  tbl <- tibble::tibble(x = 1:10, y = x / 2, z = letters[1:10])
+  arr <- array(tbl)
+
+  # the struct type is inferred column by column
+  expected_type <- struct(x = int32(), y = float64(), z = utf8())
+  expect_equal(arr$type, expected_type)
+
+  # converting back to R gives the original data
+  expect_equivalent(arr$as_vector(), tbl)
+})
+
+test_that("array() can handle data frame with custom struct type (not infered)", {
+  tbl <- tibble::tibble(x = 1:10, y = 1:10)
+
+  # same names, different (but convertible) field types
+  matching_type <- struct(x = float64(), y = int16())
+  arr <- array(tbl, type = matching_type)
+  expect_equal(arr$type, matching_type)
+
+  # more fields than columns
+  extra_field_type <- struct(x = float64(), y = int16(), z = int32())
+  expect_error(
+    array(tbl, type = extra_field_type),
+    regexp = "Number of fields in struct.* incompatible with number of columns in the data frame"
+  )
+
+  # same fields, wrong order
+  swapped_type <- struct(y = int16(), x = float64())
+  expect_error(
+    array(tbl, type = swapped_type),
+    regexp = "Field name in position.*does not match the name of the column of the data frame"
+  )
+
+  # integer column requested as string
+  string_field_type <- struct(x = float64(), y = utf8())
+  expect_error(
+    array(tbl, type = string_field_type),
+    regexp = "Cannot convert R object to string array"
+  )
+})
+
diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R
index 7d4b27b..c77e1f1 100644
--- a/r/tests/testthat/test-RecordBatch.R
+++ b/r/tests/testthat/test-RecordBatch.R
@@ -150,3 +150,23 @@ test_that("record_batch() handles arrow::Array", {
batch <- record_batch(x = 1:10, y = arrow::array(1:10))
expect_equal(batch$schema, schema(x = int32(), y = int32()))
})
+
+test_that("record_batch() handles data frame columns", {
+  # `b` is a data frame column, expected to become a struct column
+  tib <- tibble::tibble(x = 1:10, y = 1:10)
+  batch <- record_batch(a = 1:10, b = tib)
+  # the struct field must be named `b`, like the column it was built from
+  expect_equal(batch$schema, schema(a = int32(), b = struct(x = int32(), y = int32())))
+  out <- as.data.frame(batch)
+  expect_equivalent(out, tibble::tibble(a = 1:10, b = tib))
+})
+
+test_that("record_batch() handles data frame columns with schema spec", {
+  nested <- tibble::tibble(x = 1:10, y = 1:10)
+
+  # explicit schema whose struct fields differ from the inferred types
+  spec <- schema(a = int32(), b = struct(x = int16(), y = float64()))
+  batch <- record_batch(a = 1:10, b = nested, schema = spec)
+  expect_equal(batch$schema, spec)
+
+  round_trip <- as.data.frame(batch)
+  expect_equivalent(round_trip, tibble::tibble(a = 1:10, b = nested))
+
+  # integer column `y` requested as string
+  bad_spec <- schema(a = int32(), b = struct(x = int16(), y = utf8()))
+  expect_error(record_batch(a = 1:10, b = nested, schema = bad_spec))
+})
diff --git a/r/tests/testthat/test-chunkedarray.R b/r/tests/testthat/test-chunkedarray.R
index 9505b8f..b8ffe3d 100644
--- a/r/tests/testthat/test-chunkedarray.R
+++ b/r/tests/testthat/test-chunkedarray.R
@@ -272,3 +272,10 @@ test_that("chunked_array() can ingest arrays (ARROW-3815)", {
1:10
)
})
+
+test_that("chunked_array() handles data frame -> struct arrays (ARROW-3811)", {
+  tbl <- tibble::tibble(x = 1:10, y = x / 2, z = letters[1:10])
+
+  # the same data frame is given twice
+  chunked <- chunked_array(tbl, tbl)
+
+  expected_type <- struct(x = int32(), y = float64(), z = utf8())
+  expect_equal(chunked$type, expected_type)
+
+  # converting back to R stacks the rows of both inputs
+  expect_equivalent(chunked$as_vector(), rbind(tbl, tbl))
+})
diff --git a/r/tests/testthat/test-type.R b/r/tests/testthat/test-type.R
index a319033..70f8df6 100644
--- a/r/tests/testthat/test-type.R
+++ b/r/tests/testthat/test-type.R
@@ -50,3 +50,8 @@ test_that("type() infers from R type", {
int64()
)
})
+
+test_that("type() can infer struct types from data frames", {
+  # one field is expected per column, named after the column
+  tbl <- tibble::tibble(x = 1:10, y = rnorm(10), z = letters[1:10])
+  expected_type <- struct(x = int32(), y = float64(), z = utf8())
+  expect_equal(type(tbl), expected_type)
+})