You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by np...@apache.org on 2021/06/16 23:20:23 UTC
[arrow] branch master updated: ARROW-11705: [R] Support scalar value recycling in RecordBatch/Table$create()
This is an automated email from the ASF dual-hosted git repository.
npr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new dbcd0d9 ARROW-11705: [R] Support scalar value recycling in RecordBatch/Table$create()
dbcd0d9 is described below
commit dbcd0d944ce9cbf30e2e95468276a89450ac97cb
Author: Nic Crane <th...@gmail.com>
AuthorDate: Wed Jun 16 16:18:07 2021 -0700
ARROW-11705: [R] Support scalar value recycling in RecordBatch/Table$create()
This also adds missing spaces (e.g. `if(` to `if (` and `){` to `) {`) in some unrelated R files.
Closes #10269 from thisisnic/ARROW-11705_scalar_recycling
Lead-authored-by: Nic Crane <th...@gmail.com>
Co-authored-by: Nic <th...@gmail.com>
Signed-off-by: Neal Richardson <ne...@gmail.com>
---
r/R/arrow-datum.R | 4 +--
r/R/arrow-package.R | 2 +-
r/R/arrow-tabular.R | 4 +--
r/R/arrowExports.R | 4 +--
r/R/chunked-array.R | 2 +-
r/R/compression.R | 4 +--
r/R/compute.R | 4 +--
r/R/csv.R | 6 ++---
r/R/enums.R | 4 +--
r/R/filesystem.R | 2 +-
r/R/metadata.R | 2 +-
r/R/parquet.R | 4 +--
r/R/record-batch.R | 3 +++
r/R/scalar.R | 2 +-
r/R/table.R | 19 ++++++++------
r/R/util.R | 45 ++++++++++++++++++++++++++++++++
r/data-raw/codegen.R | 8 +++---
r/extra-tests/helpers.R | 4 +--
r/extra-tests/write-files.R | 2 +-
r/man/recycle_scalars.Rd | 18 +++++++++++++
r/man/repeat_value_as_array.Rd | 20 +++++++++++++++
r/src/arrowExports.cpp | 11 ++++----
r/src/scalar.cpp | 4 +--
r/tests/testthat/helper-expectation.R | 6 ++---
r/tests/testthat/test-RecordBatch.R | 46 ++++++++++++++++++++++++++++-----
r/tests/testthat/test-Table.R | 48 ++++++++++++++++++++++++++++++++---
r/tests/testthat/test-dataset.R | 2 +-
r/tools/winlibs.R | 6 ++---
28 files changed, 226 insertions(+), 60 deletions(-)
diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R
index 3be8d75..8becc37 100644
--- a/r/R/arrow-datum.R
+++ b/r/R/arrow-datum.R
@@ -128,7 +128,7 @@ eval_array_expression <- function(FUN,
}
#' @export
-na.omit.ArrowDatum <- function(object, ...){
+na.omit.ArrowDatum <- function(object, ...) {
object$Filter(!is.na(object))
}
@@ -136,7 +136,7 @@ na.omit.ArrowDatum <- function(object, ...){
na.exclude.ArrowDatum <- na.omit.ArrowDatum
#' @export
-na.fail.ArrowDatum <- function(object, ...){
+na.fail.ArrowDatum <- function(object, ...) {
if (object$null_count > 0) {
stop("missing values in object", call. = FALSE)
}
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R
index 6843820..d2bf81c 100644
--- a/r/R/arrow-package.R
+++ b/r/R/arrow-package.R
@@ -279,7 +279,7 @@ ArrowObject <- R6Class("ArrowObject",
class_title <- class(self)[[1]]
}
cat(class_title, "\n", sep = "")
- if (!is.null(self$ToString)){
+ if (!is.null(self$ToString)) {
cat(self$ToString(), "\n", sep = "")
}
invisible(self)
diff --git a/r/R/arrow-tabular.R b/r/R/arrow-tabular.R
index f5535f9..440dcea 100644
--- a/r/R/arrow-tabular.R
+++ b/r/R/arrow-tabular.R
@@ -212,7 +212,7 @@ head.ArrowTabular <- head.ArrowDatum
tail.ArrowTabular <- tail.ArrowDatum
#' @export
-na.fail.ArrowTabular <- function(object, ...){
+na.fail.ArrowTabular <- function(object, ...) {
for (col in seq_len(object$num_columns)) {
if (object$column(col - 1L)$null_count > 0) {
stop("missing values in object", call. = FALSE)
@@ -222,7 +222,7 @@ na.fail.ArrowTabular <- function(object, ...){
}
#' @export
-na.omit.ArrowTabular <- function(object, ...){
+na.omit.ArrowTabular <- function(object, ...) {
not_na <- map(object$columns, ~call_function("is_valid", .x))
not_na_agg <- Reduce("&", not_na)
object$Filter(not_na_agg)
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 45a0ea6..577773c 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -1548,8 +1548,8 @@ Scalar__as_vector <- function(scalar){
.Call(`_arrow_Scalar__as_vector`, scalar)
}
-MakeArrayFromScalar <- function(scalar){
- .Call(`_arrow_MakeArrayFromScalar`, scalar)
+MakeArrayFromScalar <- function(scalar, n){
+ .Call(`_arrow_MakeArrayFromScalar`, scalar, n)
}
Scalar__is_valid <- function(s){
diff --git a/r/R/chunked-array.R b/r/R/chunked-array.R
index fac1eeba..c58e5ac 100644
--- a/r/R/chunked-array.R
+++ b/r/R/chunked-array.R
@@ -83,7 +83,7 @@ ChunkedArray <- R6Class("ChunkedArray", inherit = ArrowDatum,
type_id = function() ChunkedArray__type(self)$id,
chunk = function(i) Array$create(ChunkedArray__chunk(self, i)),
as_vector = function() ChunkedArray__as_vector(self),
- Slice = function(offset, length = NULL){
+ Slice = function(offset, length = NULL) {
if (is.null(length)) {
ChunkedArray__Slice1(self, offset)
} else {
diff --git a/r/R/compression.R b/r/R/compression.R
index 8fd709f..499a75c 100644
--- a/r/R/compression.R
+++ b/r/R/compression.R
@@ -99,7 +99,7 @@ compression_from_name <- function(name) {
#' @export
#' @include arrow-package.R
CompressedOutputStream <- R6Class("CompressedOutputStream", inherit = OutputStream)
-CompressedOutputStream$create <- function(stream, codec = "gzip", compression_level = NA){
+CompressedOutputStream$create <- function(stream, codec = "gzip", compression_level = NA) {
codec <- Codec$create(codec, compression_level = compression_level)
if (is.string(stream)) {
stream <- FileOutputStream$create(stream)
@@ -113,7 +113,7 @@ CompressedOutputStream$create <- function(stream, codec = "gzip", compression_le
#' @format NULL
#' @export
CompressedInputStream <- R6Class("CompressedInputStream", inherit = InputStream)
-CompressedInputStream$create <- function(stream, codec = "gzip", compression_level = NA){
+CompressedInputStream$create <- function(stream, codec = "gzip", compression_level = NA) {
codec <- Codec$create(codec, compression_level = compression_level)
if (is.string(stream)) {
stream <- ReadableFile$create(stream)
diff --git a/r/R/compute.R b/r/R/compute.R
index 4d36f60..5a00e88 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -202,7 +202,7 @@ unique.ArrowDatum <- function(x, incomparables = FALSE, ...) {
}
#' @export
-any.ArrowDatum <- function(..., na.rm = FALSE){
+any.ArrowDatum <- function(..., na.rm = FALSE) {
a <- collect_arrays_from_dots(list(...))
result <- call_function("any", a)
@@ -217,7 +217,7 @@ any.ArrowDatum <- function(..., na.rm = FALSE){
}
#' @export
-all.ArrowDatum <- function(..., na.rm = FALSE){
+all.ArrowDatum <- function(..., na.rm = FALSE) {
a <- collect_arrays_from_dots(list(...))
result <- call_function("all", a)
diff --git a/r/R/csv.R b/r/R/csv.R
index 2708a53..1312a26 100644
--- a/r/R/csv.R
+++ b/r/R/csv.R
@@ -414,7 +414,7 @@ CsvReadOptions$create <- function(use_threads = option_use_threads(),
#' @rdname CsvReadOptions
#' @export
CsvWriteOptions <- R6Class("CsvWriteOptions", inherit = ArrowObject)
-CsvWriteOptions$create <- function(include_header = TRUE, batch_size = 1024L){
+CsvWriteOptions$create <- function(include_header = TRUE, batch_size = 1024L) {
assert_that(is_integerish(batch_size, n = 1, finite = TRUE), batch_size > 0)
csv___WriteOptions__initialize(
list(
@@ -637,9 +637,9 @@ write_csv_arrow <- function(x,
on.exit(sink$close())
}
- if(inherits(x, "RecordBatch")){
+ if (inherits(x, "RecordBatch")) {
csv___WriteCSV__RecordBatch(x, write_options, sink)
- } else if(inherits(x, "Table")){
+ } else if (inherits(x, "Table")) {
csv___WriteCSV__Table(x, write_options, sink)
}
diff --git a/r/R/enums.R b/r/R/enums.R
index ae44ccf..4271f2a 100644
--- a/r/R/enums.R
+++ b/r/R/enums.R
@@ -16,11 +16,11 @@
# under the License.
#' @export
-`print.arrow-enum` <- function(x, ...){
+`print.arrow-enum` <- function(x, ...) {
NextMethod()
}
-enum <- function(class, ..., .list = list(...)){
+enum <- function(class, ..., .list = list(...)) {
structure(
.list,
class = c(class, "arrow-enum")
diff --git a/r/R/filesystem.R b/r/R/filesystem.R
index 6761aca..283fbbb 100644
--- a/r/R/filesystem.R
+++ b/r/R/filesystem.R
@@ -203,7 +203,7 @@ FileSystem <- R6Class("FileSystem", inherit = ArrowObject,
GetFileInfo = function(x) {
if (inherits(x, "FileSelector")) {
fs___FileSystem__GetTargetInfos_FileSelector(self, x)
- } else if (is.character(x)){
+ } else if (is.character(x)) {
fs___FileSystem__GetTargetInfos_Paths(self, clean_path_rel(x))
} else {
abort("incompatible type for FileSystem$GetFileInfo()")
diff --git a/r/R/metadata.R b/r/R/metadata.R
index d3e5e21..408c221 100644
--- a/r/R/metadata.R
+++ b/r/R/metadata.R
@@ -59,7 +59,7 @@ apply_arrow_r_metadata <- function(x, r_metadata) {
x[[name]] <- apply_arrow_r_metadata(x[[name]], columns_metadata[[name]])
}
}
- } else if(is.list(x) && !inherits(x, "POSIXlt") && !is.null(columns_metadata)) {
+ } else if (is.list(x) && !inherits(x, "POSIXlt") && !is.null(columns_metadata)) {
x <- map2(x, columns_metadata, function(.x, .y) {
apply_arrow_r_metadata(.x, .y)
})
diff --git a/r/R/parquet.R b/r/R/parquet.R
index a9aef2c..3006fcb 100644
--- a/r/R/parquet.R
+++ b/r/R/parquet.R
@@ -296,7 +296,7 @@ ParquetWriterPropertiesBuilder <- R6Class("ParquetWriterPropertiesBuilder", inhe
parquet___ArrowWriterProperties___Builder__set_compressions
)
},
- set_compression_level = function(table, compression_level){
+ set_compression_level = function(table, compression_level) {
# cast to integer but keep names
compression_level <- set_names(as.integer(compression_level), names(compression_level))
private$.set(table, compression_level,
@@ -558,7 +558,7 @@ ParquetArrowReaderProperties <- R6Class("ParquetArrowReaderProperties",
),
active = list(
use_threads = function(use_threads) {
- if(missing(use_threads)) {
+ if (missing(use_threads)) {
parquet___arrow___ArrowReaderProperties__get_use_threads(self)
} else {
parquet___arrow___ArrowReaderProperties__set_use_threads(self, use_threads)
diff --git a/r/R/record-batch.R b/r/R/record-batch.R
index 1e41d65..0ba6b4b 100644
--- a/r/R/record-batch.R
+++ b/r/R/record-batch.R
@@ -162,6 +162,9 @@ RecordBatch$create <- function(..., schema = NULL) {
return(dplyr::group_by(out, !!!dplyr::groups(arrays[[1]])))
}
+ # If any arrays are length 1, recycle them
+ arrays <- recycle_scalars(arrays)
+
# TODO: should this also assert that they're all Arrays?
RecordBatch__from_arrays(schema, arrays)
}
diff --git a/r/R/scalar.R b/r/R/scalar.R
index 01a50b0..6e5e63c 100644
--- a/r/R/scalar.R
+++ b/r/R/scalar.R
@@ -58,7 +58,7 @@ Scalar <- R6Class("Scalar",
ToString = function() Scalar__ToString(self),
type_id = function() Scalar__type(self)$id,
as_vector = function() Scalar__as_vector(self),
- as_array = function() MakeArrayFromScalar(self),
+ as_array = function(length = 1L) MakeArrayFromScalar(self, as.integer(length)),
Equals = function(other, ...) {
inherits(other, "Scalar") && Scalar__Equals(self, other)
},
diff --git a/r/R/table.R b/r/R/table.R
index 09be952..3e5c52d 100644
--- a/r/R/table.R
+++ b/r/R/table.R
@@ -166,18 +166,21 @@ Table$create <- function(..., schema = NULL) {
names(dots) <- rep_len("", length(dots))
}
stopifnot(length(dots) > 0)
+
+ if (all_record_batches(dots)) {
+ return(Table__from_record_batches(dots, schema))
+ }
+
+ # If any arrays are length 1, recycle them
+ dots <- recycle_scalars(dots)
+ out <- Table__from_dots(dots, schema, option_use_threads())
+
# Preserve any grouping
if (length(dots) == 1 && inherits(dots[[1]], "grouped_df")) {
- out <- Table__from_dots(dots, schema, option_use_threads())
- return(dplyr::group_by(out, !!!dplyr::groups(dots[[1]])))
- }
-
- if (all_record_batches(dots)) {
- Table__from_record_batches(dots, schema)
- } else {
- Table__from_dots(dots, schema, option_use_threads())
+ out <- dplyr::group_by(out, !!!dplyr::groups(dots[[1]]))
}
+ out
}
#' @export
diff --git a/r/R/util.R b/r/R/util.R
index 8d1f51b..884c346 100644
--- a/r/R/util.R
+++ b/r/R/util.R
@@ -139,3 +139,48 @@ attr(is_writable_table, "fail") <- function(call, env){
)
}
+#' Recycle scalar values in a list of arrays
+#'
+#' @param arrays List of arrays
+#' @return List of arrays with any vector/Scalar/Array/ChunkedArray values of length 1 recycled
+#' @keywords internal
+recycle_scalars <- function(arrays) {
+  # Length of each input (row count for data.frames, via NROW)
+  arr_lens <- map_int(arrays, NROW)
+
+  is_scalar <- arr_lens == 1
+
+  # Recycling is only needed when length-1 inputs are mixed with longer ones
+  if (length(arrays) > 1 && any(is_scalar) && !all(is_scalar)) {
+    is_df <- map_lgl(arrays, ~ inherits(.x, "data.frame"))
+
+    # Recycling not supported for tibbles and data.frames
+    if (all(is_df)) {
+      abort(c(
+        "All input tibbles or data.frames must have the same number of rows",
+        x = paste(
+          "Number of rows in longest and shortest inputs:",
+          oxford_paste(c(max(arr_lens), min(arr_lens)))
+        )
+      ))
+    }
+
+    # Never recycle a one-row data.frame that is mixed with other inputs:
+    # repeat_value_as_array() would pass the whole frame to Scalar$create()
+    # instead of leaving its columns to be spliced in.
+    # NOTE(review): the unrecycled frame is expected to be rejected by the
+    # downstream "same length" check; confirm against RecordBatch/Table creation.
+    to_recycle <- is_scalar & !is_df
+    max_array_len <- max(arr_lens)
+    arrays[to_recycle] <- lapply(
+      arrays[to_recycle],
+      repeat_value_as_array,
+      max_array_len
+    )
+  }
+  arrays
+}
+
+#' Take an object of length 1 and repeat it.
+#'
+#' @param object Object of length 1 to be repeated - vector, `Scalar`, `Array`, or `ChunkedArray`
+#' @param n Number of repetitions
+#'
+#' @return `Array` of length `n`
+#'
+#' @keywords internal
+repeat_value_as_array <- function(object, n) {
+  if (inherits(object, "ChunkedArray")) {
+    # A ChunkedArray of total length 1 may still hold empty chunks ahead of
+    # the one containing the value, so take the first non-empty chunk rather
+    # than assuming the value lives in chunk 1.
+    non_empty_chunks <- Filter(function(chunk) length(chunk) > 0, object$chunks)
+    object <- non_empty_chunks[[1]]
+  }
+  Scalar$create(object)$as_array(n)
+}
\ No newline at end of file
diff --git a/r/data-raw/codegen.R b/r/data-raw/codegen.R
index ad4514a..9b25cb1 100644
--- a/r/data-raw/codegen.R
+++ b/r/data-raw/codegen.R
@@ -67,13 +67,13 @@ get_exported_functions <- function(decorations, export_tag) {
glue_collapse_data <- function(data, ..., sep = ", ", last = "") {
res <- glue_collapse(glue_data(data, ...), sep = sep, last = last)
- if(length(res) == 0) res <- ""
+ if (length(res) == 0) res <- ""
res
}
wrap_call <- function(name, return_type, args) {
call <- glue::glue('{name}({list_params})', list_params = glue_collapse_data(args, "{name}"))
- if(return_type == "void") {
+ if (return_type == "void") {
glue::glue("\t{call};\n\treturn R_NilValue;", .trim = FALSE)
} else {
glue::glue("\treturn cpp11::as_sexp({call});")
@@ -149,7 +149,7 @@ cpp_functions_definitions <- arrow_exports %>%
sep = "\n",
real_params = glue_collapse_data(args, "{type} {name}"),
input_params = glue_collapse_data(args, "\tarrow::r::Input<{type}>::type {name}({name}_sexp);", sep = "\n"),
- return_line = if(nrow(args)) "\n" else "")
+ return_line = if (nrow(args)) "\n" else "")
glue::glue('
// {basename(file)}
@@ -162,7 +162,7 @@ cpp_functions_definitions <- arrow_exports %>%
cpp_functions_registration <- arrow_exports %>%
select(name, return_type, args) %>%
- pmap_chr(function(name, return_type, args){
+ pmap_chr(function(name, return_type, args) {
glue('\t\t{{ "_arrow_{name}", (DL_FUNC) &_arrow_{name}, {nrow(args)}}}, ')
}) %>%
glue_collapse(sep = "\n")
diff --git a/r/extra-tests/helpers.R b/r/extra-tests/helpers.R
index af57d45..3fb450e 100644
--- a/r/extra-tests/helpers.R
+++ b/r/extra-tests/helpers.R
@@ -24,13 +24,13 @@ if_version_less_than <- function(version) {
}
skip_if_version_less_than <- function(version, msg) {
- if(if_version(version, `<`)) {
+ if (if_version(version, `<`)) {
skip(msg)
}
}
skip_if_version_equals <- function(version, msg) {
- if(if_version(version, `==`)) {
+ if (if_version(version, `==`)) {
skip(msg)
}
}
diff --git a/r/extra-tests/write-files.R b/r/extra-tests/write-files.R
index 75889b6..e11405d 100644
--- a/r/extra-tests/write-files.R
+++ b/r/extra-tests/write-files.R
@@ -26,7 +26,7 @@ source("tests/testthat/helper-data.R")
write_parquet(example_with_metadata, "extra-tests/files/ex_data.parquet")
for (comp in c("lz4", "uncompressed", "zstd")) {
- if(!codec_is_available(comp)) break
+ if (!codec_is_available(comp)) break
name <- paste0("extra-tests/files/ex_data_", comp, ".feather")
write_feather(example_with_metadata, name, compression = comp)
diff --git a/r/man/recycle_scalars.Rd b/r/man/recycle_scalars.Rd
new file mode 100644
index 0000000..3d97ecf
--- /dev/null
+++ b/r/man/recycle_scalars.Rd
@@ -0,0 +1,18 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/util.R
+\name{recycle_scalars}
+\alias{recycle_scalars}
+\title{Recycle scalar values in a list of arrays}
+\usage{
+recycle_scalars(arrays)
+}
+\arguments{
+\item{arrays}{List of arrays}
+}
+\value{
+List of arrays with any vector/Scalar/Array/ChunkedArray values of length 1 recycled
+}
+\description{
+Recycle scalar values in a list of arrays
+}
+\keyword{internal}
diff --git a/r/man/repeat_value_as_array.Rd b/r/man/repeat_value_as_array.Rd
new file mode 100644
index 0000000..a493732
--- /dev/null
+++ b/r/man/repeat_value_as_array.Rd
@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/util.R
+\name{repeat_value_as_array}
+\alias{repeat_value_as_array}
+\title{Take an object of length 1 and repeat it.}
+\usage{
+repeat_value_as_array(object, n)
+}
+\arguments{
+\item{object}{Object of length 1 to be repeated - vector, \code{Scalar}, \code{Array}, or \code{ChunkedArray}}
+
+\item{n}{Number of repetitions}
+}
+\value{
+\code{Array} of length \code{n}
+}
+\description{
+Take an object of length 1 and repeat it.
+}
+\keyword{internal}
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 2024483..024e5c5 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -6091,15 +6091,16 @@ extern "C" SEXP _arrow_Scalar__as_vector(SEXP scalar_sexp){
// scalar.cpp
#if defined(ARROW_R_WITH_ARROW)
-std::shared_ptr<arrow::Array> MakeArrayFromScalar(const std::shared_ptr<arrow::Scalar>& scalar);
-extern "C" SEXP _arrow_MakeArrayFromScalar(SEXP scalar_sexp){
+std::shared_ptr<arrow::Array> MakeArrayFromScalar(const std::shared_ptr<arrow::Scalar>& scalar, int n);
+extern "C" SEXP _arrow_MakeArrayFromScalar(SEXP scalar_sexp, SEXP n_sexp){
BEGIN_CPP11
arrow::r::Input<const std::shared_ptr<arrow::Scalar>&>::type scalar(scalar_sexp);
- return cpp11::as_sexp(MakeArrayFromScalar(scalar));
+ arrow::r::Input<int>::type n(n_sexp);
+ return cpp11::as_sexp(MakeArrayFromScalar(scalar, n));
END_CPP11
}
#else
-extern "C" SEXP _arrow_MakeArrayFromScalar(SEXP scalar_sexp){
+extern "C" SEXP _arrow_MakeArrayFromScalar(SEXP scalar_sexp, SEXP n_sexp){
Rf_error("Cannot call MakeArrayFromScalar(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. ");
}
#endif
@@ -7279,7 +7280,7 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_StructScalar__field", (DL_FUNC) &_arrow_StructScalar__field, 2},
{ "_arrow_StructScalar__GetFieldByName", (DL_FUNC) &_arrow_StructScalar__GetFieldByName, 2},
{ "_arrow_Scalar__as_vector", (DL_FUNC) &_arrow_Scalar__as_vector, 1},
- { "_arrow_MakeArrayFromScalar", (DL_FUNC) &_arrow_MakeArrayFromScalar, 1},
+ { "_arrow_MakeArrayFromScalar", (DL_FUNC) &_arrow_MakeArrayFromScalar, 2},
{ "_arrow_Scalar__is_valid", (DL_FUNC) &_arrow_Scalar__is_valid, 1},
{ "_arrow_Scalar__type", (DL_FUNC) &_arrow_Scalar__type, 1},
{ "_arrow_Scalar__Equals", (DL_FUNC) &_arrow_Scalar__Equals, 2},
diff --git a/r/src/scalar.cpp b/r/src/scalar.cpp
index 057e587..5450a6f 100644
--- a/r/src/scalar.cpp
+++ b/r/src/scalar.cpp
@@ -70,8 +70,8 @@ SEXP Scalar__as_vector(const std::shared_ptr<arrow::Scalar>& scalar) {
// [[arrow::export]]
std::shared_ptr<arrow::Array> MakeArrayFromScalar(
- const std::shared_ptr<arrow::Scalar>& scalar) {
- return ValueOrStop(arrow::MakeArrayFromScalar(*scalar, 1, gc_memory_pool()));
+ const std::shared_ptr<arrow::Scalar>& scalar, int n) {
+ return ValueOrStop(arrow::MakeArrayFromScalar(*scalar, n, gc_memory_pool()));
}
// [[arrow::export]]
diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R
index 5b6958a..b815515 100644
--- a/r/tests/testthat/helper-expectation.R
+++ b/r/tests/testthat/helper-expectation.R
@@ -16,7 +16,7 @@
# under the License.
expect_as_vector <- function(x, y, ignore_attr = FALSE, ...) {
- expect_fun <- if(ignore_attr){
+ expect_fun <- if (ignore_attr) {
expect_equivalent
} else {
expect_equal
@@ -28,7 +28,7 @@ expect_data_frame <- function(x, y, ...) {
expect_equal(as.data.frame(x), y, ...)
}
-expect_r6_class <- function(object, class){
+expect_r6_class <- function(object, class) {
expect_s3_class(object, class)
expect_s3_class(object, "R6")
}
@@ -255,7 +255,7 @@ expect_vector_error <- function(expr, # A vectorized R expression containing `in
}
}
-split_vector_as_list <- function(vec){
+split_vector_as_list <- function(vec) {
vec_split <- length(vec) %/% 2
vec1 <- vec[seq(from = min(1, length(vec) - 1), to = min(length(vec) - 1, vec_split), by = 1)]
vec2 <- vec[seq(from = min(length(vec), vec_split + 1), to = length(vec), by = 1)]
diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R
index beb1306..6617805 100644
--- a/r/tests/testthat/test-RecordBatch.R
+++ b/r/tests/testthat/test-RecordBatch.R
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-
test_that("RecordBatch", {
# Note that we're reusing `tbl` and `batch` throughout the tests in this file
tbl <- tibble::tibble(
@@ -415,14 +414,50 @@ test_that("record_batch() handles null type (ARROW-7064)", {
expect_equivalent(batch$schema, schema(a = int32(), n = null()))
})
-test_that("record_batch() scalar recycling", {
- skip("Not implemented (ARROW-11705)")
+test_that("record_batch() scalar recycling with vectors", {
expect_data_frame(
record_batch(a = 1:10, b = 5),
tibble::tibble(a = 1:10, b = 5)
)
})
+test_that("record_batch() scalar recycling with Scalars, Arrays, and ChunkedArrays", {
+  expected <- tibble::tibble(a = 1:10, b = 5)
+  long_col <- Array$create(1:10)
+
+  # The same length-1 value wrapped in each Arrow container type should be
+  # recycled to the length of `a`
+  expect_data_frame(
+    record_batch(a = long_col, b = Scalar$create(5)),
+    expected
+  )
+  expect_data_frame(
+    record_batch(a = long_col, b = Array$create(5)),
+    expected
+  )
+  expect_data_frame(
+    record_batch(a = long_col, b = ChunkedArray$create(5)),
+    expected
+  )
+})
+
+test_that("record_batch() no recycling with tibbles", {
+  # data.frame inputs of differing row counts are rejected rather than recycled
+  msg <- "All input tibbles or data.frames must have the same number of rows"
+  long_tbl <- tibble::tibble(a = 1:10)
+
+  expect_error(
+    record_batch(long_tbl, tibble::tibble(a = 1, b = 5)),
+    regexp = msg
+  )
+  expect_error(
+    record_batch(long_tbl, tibble::tibble(a = 1)),
+    regexp = msg
+  )
+})
+
test_that("RecordBatch$Equals", {
df <- tibble::tibble(x = 1:10, y = letters[1:10])
a <- record_batch(df)
@@ -435,7 +470,7 @@ test_that("RecordBatch$Equals", {
test_that("RecordBatch$Equals(check_metadata)", {
df <- tibble::tibble(x = 1:2, y = c("a", "b"))
rb1 <- record_batch(df)
- rb2 <- record_batch(df, schema = rb1$schema$WithMetadata(list(some="metadata")))
+ rb2 <- record_batch(df, schema = rb1$schema$WithMetadata(list(some = "metadata")))
expect_r6_class(rb1, "RecordBatch")
expect_r6_class(rb2, "RecordBatch")
@@ -467,8 +502,7 @@ test_that("RecordBatch name assignment", {
test_that("record_batch() with different length arrays", {
msg <- "All arrays must have the same length"
- expect_error(record_batch(a=1:5, b = 42), msg)
- expect_error(record_batch(a=1:5, b = 1:6), msg)
+ expect_error(record_batch(a = 1:5, b = 1:6), msg)
})
test_that("Handling string data with embedded nuls", {
diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R
index 1f96288..6dd36b2 100644
--- a/r/tests/testthat/test-Table.R
+++ b/r/tests/testthat/test-Table.R
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-
test_that("read_table handles various input streams (ARROW-3450, ARROW-3505)", {
tbl <- tibble::tibble(
int = 1:10, dbl = as.numeric(1:10),
@@ -471,8 +470,51 @@ test_that("Table name assignment", {
test_that("Table$create() with different length columns", {
msg <- "All columns must have the same length"
- expect_error(Table$create(a=1:5, b = 42), msg)
- expect_error(Table$create(a=1:5, b = 1:6), msg)
+ expect_error(Table$create(a = 1:5, b = 1:6), msg)
+})
+
+test_that("Table$create() scalar recycling with vectors", {
+  # A length-1 R vector is recycled to the length of the longer column
+  tab <- Table$create(a = 1:10, b = 5)
+  expect_data_frame(tab, tibble::tibble(a = 1:10, b = 5))
+})
+
+test_that("Table$create() scalar recycling with Scalars, Arrays, and ChunkedArrays", {
+  expected <- tibble::tibble(a = 1:10, b = 5)
+  long_col <- Array$create(1:10)
+
+  # The same length-1 value wrapped in each Arrow container type should be
+  # recycled to the length of `a`
+  expect_data_frame(
+    Table$create(a = long_col, b = Scalar$create(5)),
+    expected
+  )
+  expect_data_frame(
+    Table$create(a = long_col, b = Array$create(5)),
+    expected
+  )
+  expect_data_frame(
+    Table$create(a = long_col, b = ChunkedArray$create(5)),
+    expected
+  )
+})
+
+test_that("Table$create() no recycling with tibbles", {
+  # data.frame inputs of differing row counts are rejected rather than recycled
+  msg <- "All input tibbles or data.frames must have the same number of rows"
+  long_tbl <- tibble::tibble(a = 1:10, b = 5)
+
+  expect_error(
+    Table$create(long_tbl, tibble::tibble(a = 1, b = 5)),
+    regexp = msg
+  )
+  expect_error(
+    Table$create(long_tbl, tibble::tibble(a = 1)),
+    regexp = msg
+  )
+})
test_that("ARROW-11769 - grouping preserved in table creation", {
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index d84ed03..ad3e7c3 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -90,7 +90,7 @@ test_that("Setup (putting data in the dir)", {
expect_length(dir(tsv_dir, recursive = TRUE), 2)
})
-if(arrow_with_parquet()) {
+if (arrow_with_parquet()) {
files <- c(
file.path(dataset_dir, 1, "file1.parquet", fsep = "/"),
file.path(dataset_dir, 2, "file2.parquet", fsep = "/")
diff --git a/r/tools/winlibs.R b/r/tools/winlibs.R
index f90becb..ccaa5c9 100644
--- a/r/tools/winlibs.R
+++ b/r/tools/winlibs.R
@@ -17,12 +17,12 @@
args <- commandArgs(TRUE)
VERSION <- args[1]
-if(!file.exists(sprintf("windows/arrow-%s/include/arrow/api.h", VERSION))){
- if(length(args) > 1){
+if (!file.exists(sprintf("windows/arrow-%s/include/arrow/api.h", VERSION))) {
+ if (length(args) > 1) {
# Arg 2 would be the path/to/lib.zip
localfile <- args[2]
cat(sprintf("*** Using RWINLIB_LOCAL %s\n", localfile))
- if(!file.exists(localfile)){
+ if (!file.exists(localfile)) {
cat(sprintf("*** %s does not exist; build will fail\n", localfile))
}
file.copy(localfile, "lib.zip")