You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/08 18:38:09 UTC
[GitHub] [arrow] paleolimbot commented on a diff in pull request #12751: ARROW-15989: [R] rbind & cbind for Table & RecordBatch
paleolimbot commented on code in PR #12751:
URL: https://github.com/apache/arrow/pull/12751#discussion_r846372100
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+ if (length != target_length) {
+ abort(
+ sprintf(
+ "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
Review Comment:
```suggestion
"Non-scalar inputs must have an equal number of rows. ..1 has %d, ..%d has %d",
```
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+ if (length != target_length) {
+ abort(
+ sprintf(
+ "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+ target_length,
+ idx,
+ length
+ )
+ )
+ }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+ inputs <- list(...)
+ num_rows <- inputs[[1]]$num_rows
+
+ batches <- imap(inputs, function(input, idx) {
+ if (inherits(input, "RecordBatch")) {
+ cbind_check_length(num_rows, input$num_rows, idx)
+ input
+ } else if (is.atomic(input) && length(input) == 1) {
+ RecordBatch$create("{idx}" := rep(input, num_rows))
+ } else {
+ tryCatch(
+ {
+ cbind_check_length(num_rows, length(input), idx)
+ RecordBatch$create("{idx}" := input)
+ },
+ error = function(err) {
+ abort(sprintf("Input ..%i cannot be converted to an Arrow Array: %s", idx, err))
+ }
+ )
+ }
Review Comment:
Does `data.frame()` work here?
##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
expect_error(record_batch(a = 1:5, b = 1:6), msg)
})
+test_that("RecordBatch doesn't support rbind", {
+ expect_error(
+ rbind(
+ RecordBatch$create(a = 1:10),
+ RecordBatch$create(a = 2:4)
+ ),
+ regexp = "Use Table\\$create"
+ )
+})
+
+test_that("RecordBatch supports cbind", {
+ expect_error(
+ cbind(
+ RecordBatch$create(a = 1:10, ),
+ RecordBatch$create(a = c("a", "b"))
+ ),
+ regexp = "Non-scalar inputs must have an equal number of rows"
+ )
+
+ batches <- list(
+ RecordBatch$create(a = c(1, 2), b = c("a", "b")),
+ RecordBatch$create(a = c("d", "c")),
+ RecordBatch$create(c = c(2, 3))
+ )
+
+ expected <- RecordBatch$create(
+ do.call(cbind, lapply(batches, function(batch) as.data.frame(batch)))
+ )
+ expect_equal(do.call(cbind, batches), expected, ignore_attr = TRUE)
+
+ # Handles a variety of input types
+ inputs <- list(
+ RecordBatch$create(a = 1:2),
+ b = Array$create(4:5),
+ c = factor(c("a", "b")),
+ d = 1L
+ )
+ r_inputs <- inputs
+ r_inputs[[1]] <- as.data.frame(r_inputs[[1]])
+ r_inputs[["b"]] <- as.vector(r_inputs[["b"]])
+
+ expected <- RecordBatch$create(do.call(cbind, r_inputs))
+ actual <- do.call(cbind, inputs)
+ expect_equal(expected, actual, ignore_attr = TRUE)
+ expect_equal(names(actual), c("a", "b", "c", "d"))
+})
Review Comment:
It's hard for me to see which code paths you are testing here. Could it be rewritten as one or more expectations of the form:
```r
expect_equal(
rbind(...),
record_batch(...)
)
```
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
Review Comment:
```suggestion
abort("Use `Table$create()` to combine record batches")
```
##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
#' @export
names.Table <- function(x) x$ColumnNames()
+#' @export
+rbind.Table <- function(...) {
+ tables <- list(...)
+
+ # assert they have same schema
+ schema <- tables[[1]]$schema
+ unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+ if (unequal_schema_idx != 1) {
+ stop(paste0(
Review Comment:
```suggestion
abort(paste0(
```
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+ if (length != target_length) {
+ abort(
+ sprintf(
+ "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+ target_length,
+ idx,
+ length
+ )
+ )
+ }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+ inputs <- list(...)
+ num_rows <- inputs[[1]]$num_rows
+
+ batches <- imap(inputs, function(input, idx) {
+ if (inherits(input, "RecordBatch")) {
+ cbind_check_length(num_rows, input$num_rows, idx)
+ input
+ } else if (is.atomic(input) && length(input) == 1) {
Review Comment:
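I'm fairly sure `is.atomic()` doesn't mean what you think it means here. It tests whether an object is an atomic R vector (factors and dates included), not whether it is a scalar. It is also `FALSE` for Arrow objects such as `Scalar`. The recycling logic currently lives in https://github.com/apache/arrow/blob/master/r/R/util.R#L156-L184, and I wonder if you could reuse some of that here.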
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+ if (length != target_length) {
+ abort(
+ sprintf(
+ "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+ target_length,
+ idx,
+ length
+ )
+ )
+ }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+ inputs <- list(...)
+ num_rows <- inputs[[1]]$num_rows
+
+ batches <- imap(inputs, function(input, idx) {
Review Comment:
My reading of this is that `idx` might be an integer (if all arguments are unnamed), `""` (for the unnamed arguments when some arguments are named), or `"the_arg_name"`. Your code below assumes both forms in different places: `sprintf("..%i", idx)` expects an integer, while `"{idx}" :=` uses it as a column name. You probably want `lapply(seq_along(inputs), function(i) {...})` so that you can access both the name and the location when you need them.
##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
#' @export
names.Table <- function(x) x$ColumnNames()
+#' @export
+rbind.Table <- function(...) {
+ tables <- list(...)
+
+ # assert they have same schema
+ schema <- tables[[1]]$schema
+ unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+ if (unequal_schema_idx != 1) {
+ stop(paste0(
+ sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+ "Schema 1:\n",
+ schema$ToString(),
+ sprintf("\nSchema %i:\n", unequal_schema_idx),
+ tables[[unequal_schema_idx]]$schema$ToString()
+ ))
+ }
+
+ # create chunked array for each column
+ columns <- map(1:tables[[1]]$num_columns, function(i) {
Review Comment:
```suggestion
columns <- map(seq_len(tables[[1]]$num_columns), function(i) {
```
(In case `$num_columns` is 0: `1:0` evaluates to `c(1, 0)`, whereas `seq_len(0)` is empty.)
##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
expect_error(record_batch(a = 1:5, b = 1:6), msg)
})
+test_that("RecordBatch doesn't support rbind", {
+ expect_error(
Review Comment:
I think `expect_snapshot_error()` will give you the best test here. A few of these tests passed even though the error raised wasn't the one you intended.
##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
#' @export
names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+ abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+ if (length != target_length) {
+ abort(
+ sprintf(
+ "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+ target_length,
+ idx,
+ length
+ )
+ )
+ }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+ inputs <- list(...)
+ num_rows <- inputs[[1]]$num_rows
+
+ batches <- imap(inputs, function(input, idx) {
+ if (inherits(input, "RecordBatch")) {
+ cbind_check_length(num_rows, input$num_rows, idx)
+ input
+ } else if (is.atomic(input) && length(input) == 1) {
+ RecordBatch$create("{idx}" := rep(input, num_rows))
+ } else {
+ tryCatch(
+ {
+ cbind_check_length(num_rows, length(input), idx)
+ RecordBatch$create("{idx}" := input)
+ },
+ error = function(err) {
+ abort(sprintf("Input ..%i cannot be converted to an Arrow Array: %s", idx, err))
Review Comment:
```suggestion
abort(sprintf("Error processing argument ..%d", idx), parent = err)
```
##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
#' @export
names.Table <- function(x) x$ColumnNames()
+#' @export
+rbind.Table <- function(...) {
+ tables <- list(...)
+
+ # assert they have same schema
+ schema <- tables[[1]]$schema
+ unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+ if (unequal_schema_idx != 1) {
+ stop(paste0(
+ sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+ "Schema 1:\n",
+ schema$ToString(),
+ sprintf("\nSchema %i:\n", unequal_schema_idx),
+ tables[[unequal_schema_idx]]$schema$ToString()
+ ))
+ }
+
+ # create chunked array for each column
+ columns <- map(1:tables[[1]]$num_columns, function(i) {
+ do.call(c, map(tables, ~ .[[i]]))
+ })
+
+ Table$create(!!!set_names(columns, names(schema)), schema = schema)
+}
+
+#' @export
+cbind.Table <- function(...) {
+ inputs <- list(...)
+ num_rows <- inputs[[1]]$num_rows
+
+ tables <- imap(inputs, function(input, idx) {
Review Comment:
See the comments above about `cbind.RecordBatch()`.
##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
#' @export
names.Table <- function(x) x$ColumnNames()
+#' @export
+rbind.Table <- function(...) {
+ tables <- list(...)
+
+ # assert they have same schema
+ schema <- tables[[1]]$schema
+ unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+ if (unequal_schema_idx != 1) {
+ stop(paste0(
+ sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+ "Schema 1:\n",
+ schema$ToString(),
+ sprintf("\nSchema %i:\n", unequal_schema_idx),
Review Comment:
```suggestion
sprintf("\nSchema %d:\n", unequal_schema_idx),
```
##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
expect_error(record_batch(a = 1:5, b = 1:6), msg)
})
+test_that("RecordBatch doesn't support rbind", {
+ expect_error(
+ rbind(
+ RecordBatch$create(a = 1:10),
+ RecordBatch$create(a = 2:4)
+ ),
+ regexp = "Use Table\\$create"
+ )
+})
+
+test_that("RecordBatch supports cbind", {
+ expect_error(
+ cbind(
+ RecordBatch$create(a = 1:10, ),
+ RecordBatch$create(a = c("a", "b"))
+ ),
+ regexp = "Non-scalar inputs must have an equal number of rows"
+ )
+
+ batches <- list(
+ RecordBatch$create(a = c(1, 2), b = c("a", "b")),
+ RecordBatch$create(a = c("d", "c")),
+ RecordBatch$create(c = c(2, 3))
+ )
+
+ expected <- RecordBatch$create(
+ do.call(cbind, lapply(batches, function(batch) as.data.frame(batch)))
+ )
+ expect_equal(do.call(cbind, batches), expected, ignore_attr = TRUE)
Review Comment:
I think this will be clearer to future maintainers of this test if the full `record_batch(...)` and `data.frame()` calls are typed out.
##########
r/tests/testthat/test-Table.R:
##########
@@ -518,6 +518,77 @@ test_that("Table$create() no recycling with tibbles", {
)
})
+test_that("Table supports rbind", {
+ expect_error(
+ rbind(
+ Table$create(a = 1:10, b = Scalar$create(5)),
+ Table$create(a = c("a", "b"), b = Scalar$create(5))
+ ),
+ regexp = "Schema at index 2 does not match the first schema"
+ )
+
+ tables <- list(
+ Table$create(a = 1:10, b = Scalar$create("x")),
+ Table$create(a = 2:42, b = Scalar$create("y")),
+ Table$create(a = 8:10, b = Scalar$create("z"))
+ )
+ expected <- Table$create(do.call(rbind, lapply(tables, function(table) as.data.frame(table))))
+ actual <- do.call(rbind, tables)
+ expect_equal(actual, expected, ignore_attr = TRUE)
+})
+
+test_that("Table supports cbind", {
+ expect_error(
Review Comment:
See the comments about the `cbind.RecordBatch` test.
##########
r/tests/testthat/test-chunked-array.R:
##########
@@ -100,6 +100,18 @@ test_that("print ChunkedArray", {
})
})
+test_that("ChunkedArray can be concatenated with c()", {
+ a <- chunked_array(c(1, 2), 3)
+ b <- chunked_array(c(4, 5), 6)
+ expected <- chunked_array(c(1, 2), 3, c(4, 5), 6)
+ expect_equal(c(a, b), expected)
+
+ # Can handle Arrays and base vectors
+ vectors <- list(chunked_array(1:10), arrow::Array$create(1:10), 1:10)
Review Comment:
```suggestion
vectors <- list(chunked_array(1:10), Array$create(1:10), 1:10)
```
##########
r/tests/testthat/test-Table.R:
##########
@@ -518,6 +518,77 @@ test_that("Table$create() no recycling with tibbles", {
)
})
+test_that("Table supports rbind", {
+ expect_error(
+ rbind(
+ Table$create(a = 1:10, b = Scalar$create(5)),
+ Table$create(a = c("a", "b"), b = Scalar$create(5))
+ ),
+ regexp = "Schema at index 2 does not match the first schema"
+ )
+
+ tables <- list(
+ Table$create(a = 1:10, b = Scalar$create("x")),
+ Table$create(a = 2:42, b = Scalar$create("y")),
+ Table$create(a = 8:10, b = Scalar$create("z"))
+ )
+ expected <- Table$create(do.call(rbind, lapply(tables, function(table) as.data.frame(table))))
+ actual <- do.call(rbind, tables)
+ expect_equal(actual, expected, ignore_attr = TRUE)
+})
+
+test_that("Table supports cbind", {
+ expect_error(
+ cbind(
+ Table$create(a = 1:10, ),
+ Table$create(a = c("a", "b"))
+ ),
+ regexp = "Non-scalar inputs must have an equal number of rows"
+ )
+
+ tables <- list(
+ Table$create(a = 1:10, b = Scalar$create("x")),
+ Table$create(a = 11:20, b = Scalar$create("y")),
+ Table$create(c = rnorm(10))
+ )
+ expected <- Table$create(do.call(cbind, lapply(tables, function(table) as.data.frame(table))))
+ actual <- do.call(cbind, tables)
+ expect_equal(actual, expected, ignore_attr = TRUE)
+
+ # Handles a variety of input types
+ inputs <- list(
+ Table$create(a = 1L:2L),
+ b = Array$create(4:5),
+ c = factor(c("a", "b")),
+ d = 1L
+ )
+
+ r_inputs <- inputs
+ r_inputs[[1]] <- as.data.frame(r_inputs[[1]])
+ r_inputs[["b"]] <- as.vector(r_inputs[["b"]])
+
+ expected <- Table$create(do.call(cbind, r_inputs))
+ actual <- do.call(cbind, inputs)
+ expect_equal(expected, actual, ignore_attr = TRUE)
+ expect_equal(names(actual), c("a", "b", "c", "d"))
+})
+
+test_that("cbind.Table handles record batches and tables", {
+ # R 3.6 cbind dispatch rules cause cbind to fall back to default impl if
+ # there are multiple arguments with distinct cbind implementations
+ skip_if(getRversion() < "4.0.0", "R 3.6 cbind dispatch rules prevent this behavior")
+
+ inputs <- list(
+ Table$create(a = 1L:2L),
+ RecordBatch$create(b = 4:5)
+ )
+
+ expected <- Table$create(do.call(cbind, map(inputs, as.data.frame)))
+ actual <- do.call(cbind, inputs)
+ expect_equal(expected, actual, ignore_attr = TRUE)
+ expect_equal(names(actual), c("a", "b"))
Review Comment:
Again, the `do.call()` hides what is being tested. I think it will be clearer if the calls are typed out.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org