You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/08 18:38:09 UTC

[GitHub] [arrow] paleolimbot commented on a diff in pull request #12751: ARROW-15989: [R] rbind & cbind for Table & RecordBatch

paleolimbot commented on code in PR #12751:
URL: https://github.com/apache/arrow/pull/12751#discussion_r846372100


##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+  if (length != target_length) {
+    abort(
+      sprintf(
+        "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",

Review Comment:
   ```suggestion
           "Non-scalar inputs must have an equal number of rows. ..1 has %d, ..%d has %d",
   ```



##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+  if (length != target_length) {
+    abort(
+      sprintf(
+        "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+        target_length,
+        idx,
+        length
+      )
+    )
+  }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+  inputs <- list(...)
+  num_rows <- inputs[[1]]$num_rows
+
+  batches <- imap(inputs, function(input, idx) {
+    if (inherits(input, "RecordBatch")) {
+      cbind_check_length(num_rows, input$num_rows, idx)
+      input
+    } else if (is.atomic(input) && length(input) == 1) {
+      RecordBatch$create("{idx}" := rep(input, num_rows))
+    } else {
+      tryCatch(
+        {
+          cbind_check_length(num_rows, length(input), idx)
+          RecordBatch$create("{idx}" := input)
+        },
+        error = function(err) {
+          abort(sprintf("Input ..%i cannot be converted to an Arrow Array: %s", idx, err))
+        }
+      )
+    }

Review Comment:
   Does `data.frame()` work here?



##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
   expect_error(record_batch(a = 1:5, b = 1:6), msg)
 })
 
+test_that("RecordBatch doesn't support rbind", {
+  expect_error(
+    rbind(
+      RecordBatch$create(a = 1:10),
+      RecordBatch$create(a = 2:4)
+    ),
+    regexp = "Use Table\\$create"
+  )
+})
+
+test_that("RecordBatch supports cbind", {
+  expect_error(
+    cbind(
+      RecordBatch$create(a = 1:10, ),
+      RecordBatch$create(a = c("a", "b"))
+    ),
+    regexp = "Non-scalar inputs must have an equal number of rows"
+  )
+
+  batches <- list(
+    RecordBatch$create(a = c(1, 2), b = c("a", "b")),
+    RecordBatch$create(a = c("d", "c")),
+    RecordBatch$create(c = c(2, 3))
+  )
+
+  expected <- RecordBatch$create(
+    do.call(cbind, lapply(batches, function(batch) as.data.frame(batch)))
+  )
+  expect_equal(do.call(cbind, batches), expected, ignore_attr = TRUE)
+
+  # Handles a variety of input types
+  inputs <- list(
+    RecordBatch$create(a = 1:2),
+    b = Array$create(4:5),
+    c = factor(c("a", "b")),
+    d = 1L
+  )
+  r_inputs <- inputs
+  r_inputs[[1]] <- as.data.frame(r_inputs[[1]])
+  r_inputs[["b"]] <- as.vector(r_inputs[["b"]])
+
+  expected <- RecordBatch$create(do.call(cbind, r_inputs))
+  actual <- do.call(cbind, inputs)
+  expect_equal(expected, actual, ignore_attr = TRUE)
+  expect_equal(names(actual), c("a", "b", "c", "d"))
+})

Review Comment:
   It's hard for me to visualize what bits of code you are testing here. Could it be rewritten as one or more of
   
   ```r
   expect_equal(
     rbind(...),
     record_batch(...)
   )
   ```



##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")

Review Comment:
   ```suggestion
     abort("Use `Table$create()` to combine record batches")
   ```



##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
 #' @export
 names.Table <- function(x) x$ColumnNames()
 
+#' @export
+rbind.Table <- function(...) {
+  tables <- list(...)
+
+  # assert they have same schema
+  schema <- tables[[1]]$schema
+  unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+  if (unequal_schema_idx != 1) {
+    stop(paste0(

Review Comment:
   ```suggestion
       abort(paste0(
   ```



##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+  if (length != target_length) {
+    abort(
+      sprintf(
+        "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+        target_length,
+        idx,
+        length
+      )
+    )
+  }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+  inputs <- list(...)
+  num_rows <- inputs[[1]]$num_rows
+
+  batches <- imap(inputs, function(input, idx) {
+    if (inherits(input, "RecordBatch")) {
+      cbind_check_length(num_rows, input$num_rows, idx)
+      input
+    } else if (is.atomic(input) && length(input) == 1) {

Review Comment:
   I'm fairly sure `is.atomic()` doesn't mean what you think it means here. The recycling logic currently lives in https://github.com/apache/arrow/blob/master/r/R/util.R#L156-L184 and I wonder if you can't reuse some of that here.



##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+  if (length != target_length) {
+    abort(
+      sprintf(
+        "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+        target_length,
+        idx,
+        length
+      )
+    )
+  }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+  inputs <- list(...)
+  num_rows <- inputs[[1]]$num_rows
+
+  batches <- imap(inputs, function(input, idx) {

Review Comment:
   My reading of this is that `idx` might be an integer (if all arguments are unnamed), `""` (if some arguments were named), or `"the_arg_name"`. Your code below assumes some of both, and you probably want `lapply(seq_along(inputs), function(i) {...})` so that you can access both the name and the location when you need it.



##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
 #' @export
 names.Table <- function(x) x$ColumnNames()
 
+#' @export
+rbind.Table <- function(...) {
+  tables <- list(...)
+
+  # assert they have same schema
+  schema <- tables[[1]]$schema
+  unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+  if (unequal_schema_idx != 1) {
+    stop(paste0(
+      sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+      "Schema 1:\n",
+      schema$ToString(),
+      sprintf("\nSchema %i:\n", unequal_schema_idx),
+      tables[[unequal_schema_idx]]$schema$ToString()
+    ))
+  }
+
+  # create chunked array for each column
+  columns <- map(1:tables[[1]]$num_columns, function(i) {

Review Comment:
   ```suggestion
     columns <- map(seq_len(tables[[1]]$num_columns), function(i) {
   ```
   
   (In case `$num_columns` is 0)



##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
   expect_error(record_batch(a = 1:5, b = 1:6), msg)
 })
 
+test_that("RecordBatch doesn't support rbind", {
+  expect_error(

Review Comment:
   I think `expect_snapshot_error()` will give you the best test here (a few of these passed even though the error wasn't the one you intended)



##########
r/R/record-batch.R:
##########
@@ -189,3 +189,52 @@ record_batch <- RecordBatch$create
 
 #' @export
 names.RecordBatch <- function(x) x$names()
+
+#' @export
+rbind.RecordBatch <- function(...) {
+  abort("Use Table$create to combine record batches")
+}
+
+cbind_check_length <- function(target_length, length, idx) {
+  if (length != target_length) {
+    abort(
+      sprintf(
+        "Non-scalar inputs must have an equal number of rows. ..1 has %i, ..%i has %i",
+        target_length,
+        idx,
+        length
+      )
+    )
+  }
+}
+
+#' @export
+cbind.RecordBatch <- function(...) {
+  inputs <- list(...)
+  num_rows <- inputs[[1]]$num_rows
+
+  batches <- imap(inputs, function(input, idx) {
+    if (inherits(input, "RecordBatch")) {
+      cbind_check_length(num_rows, input$num_rows, idx)
+      input
+    } else if (is.atomic(input) && length(input) == 1) {
+      RecordBatch$create("{idx}" := rep(input, num_rows))
+    } else {
+      tryCatch(
+        {
+          cbind_check_length(num_rows, length(input), idx)
+          RecordBatch$create("{idx}" := input)
+        },
+        error = function(err) {
+          abort(sprintf("Input ..%i cannot be converted to an Arrow Array: %s", idx, err))

Review Comment:
   ```suggestion
             abort(sprintf("Error processing argument ..%d", idx), parent = err)
   ```



##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
 #' @export
 names.Table <- function(x) x$ColumnNames()
 
+#' @export
+rbind.Table <- function(...) {
+  tables <- list(...)
+
+  # assert they have same schema
+  schema <- tables[[1]]$schema
+  unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+  if (unequal_schema_idx != 1) {
+    stop(paste0(
+      sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+      "Schema 1:\n",
+      schema$ToString(),
+      sprintf("\nSchema %i:\n", unequal_schema_idx),
+      tables[[unequal_schema_idx]]$schema$ToString()
+    ))
+  }
+
+  # create chunked array for each column
+  columns <- map(1:tables[[1]]$num_columns, function(i) {
+    do.call(c, map(tables, ~ .[[i]]))
+  })
+
+  Table$create(!!!set_names(columns, names(schema)), schema = schema)
+}
+
+#' @export
+cbind.Table <- function(...) {
+  inputs <- list(...)
+  num_rows <- inputs[[1]]$num_rows
+
+  tables <- imap(inputs, function(input, idx) {

Review Comment:
   See above comments about the `cbind.RecordBatch()`



##########
r/R/table.R:
##########
@@ -149,6 +149,65 @@ Table$create <- function(..., schema = NULL) {
 #' @export
 names.Table <- function(x) x$ColumnNames()
 
+#' @export
+rbind.Table <- function(...) {
+  tables <- list(...)
+
+  # assert they have same schema
+  schema <- tables[[1]]$schema
+  unequal_schema_idx <- which.min(lapply(tables, function(x) x$schema == schema))
+  if (unequal_schema_idx != 1) {
+    stop(paste0(
+      sprintf("Schema at index %i does not match the first schema\n", unequal_schema_idx),
+      "Schema 1:\n",
+      schema$ToString(),
+      sprintf("\nSchema %i:\n", unequal_schema_idx),

Review Comment:
   ```suggestion
         sprintf("\nSchema %d:\n", unequal_schema_idx),
   ```



##########
r/tests/testthat/test-RecordBatch.R:
##########
@@ -513,6 +513,53 @@ test_that("record_batch() with different length arrays", {
   expect_error(record_batch(a = 1:5, b = 1:6), msg)
 })
 
+test_that("RecordBatch doesn't support rbind", {
+  expect_error(
+    rbind(
+      RecordBatch$create(a = 1:10),
+      RecordBatch$create(a = 2:4)
+    ),
+    regexp = "Use Table\\$create"
+  )
+})
+
+test_that("RecordBatch supports cbind", {
+  expect_error(
+    cbind(
+      RecordBatch$create(a = 1:10, ),
+      RecordBatch$create(a = c("a", "b"))
+    ),
+    regexp = "Non-scalar inputs must have an equal number of rows"
+  )
+
+  batches <- list(
+    RecordBatch$create(a = c(1, 2), b = c("a", "b")),
+    RecordBatch$create(a = c("d", "c")),
+    RecordBatch$create(c = c(2, 3))
+  )
+
+  expected <- RecordBatch$create(
+    do.call(cbind, lapply(batches, function(batch) as.data.frame(batch)))
+  )
+  expect_equal(do.call(cbind, batches), expected, ignore_attr = TRUE)

Review Comment:
   I think this will be clearer to future maintainers of this test if the full `record_batch(...)` and `data.frame()` are typed out.



##########
r/tests/testthat/test-Table.R:
##########
@@ -518,6 +518,77 @@ test_that("Table$create() no recycling with tibbles", {
   )
 })
 
+test_that("Table supports rbind", {
+  expect_error(
+    rbind(
+      Table$create(a = 1:10, b = Scalar$create(5)),
+      Table$create(a = c("a", "b"), b = Scalar$create(5))
+    ),
+    regexp = "Schema at index 2 does not match the first schema"
+  )
+
+  tables <- list(
+    Table$create(a = 1:10, b = Scalar$create("x")),
+    Table$create(a = 2:42, b = Scalar$create("y")),
+    Table$create(a = 8:10, b = Scalar$create("z"))
+  )
+  expected <- Table$create(do.call(rbind, lapply(tables, function(table) as.data.frame(table))))
+  actual <- do.call(rbind, tables)
+  expect_equal(actual, expected, ignore_attr = TRUE)
+})
+
+test_that("Table supports cbind", {
+  expect_error(

Review Comment:
   See comments about the `cbind.RecordBatch` test



##########
r/tests/testthat/test-chunked-array.R:
##########
@@ -100,6 +100,18 @@ test_that("print ChunkedArray", {
   })
 })
 
+test_that("ChunkedArray can be concatenated with c()", {
+  a <- chunked_array(c(1, 2), 3)
+  b <- chunked_array(c(4, 5), 6)
+  expected <- chunked_array(c(1, 2), 3, c(4, 5), 6)
+  expect_equal(c(a, b), expected)
+
+  # Can handle Arrays and base vectors
+  vectors <- list(chunked_array(1:10), arrow::Array$create(1:10), 1:10)

Review Comment:
   ```suggestion
     vectors <- list(chunked_array(1:10), Array$create(1:10), 1:10)
   ```



##########
r/tests/testthat/test-Table.R:
##########
@@ -518,6 +518,77 @@ test_that("Table$create() no recycling with tibbles", {
   )
 })
 
+test_that("Table supports rbind", {
+  expect_error(
+    rbind(
+      Table$create(a = 1:10, b = Scalar$create(5)),
+      Table$create(a = c("a", "b"), b = Scalar$create(5))
+    ),
+    regexp = "Schema at index 2 does not match the first schema"
+  )
+
+  tables <- list(
+    Table$create(a = 1:10, b = Scalar$create("x")),
+    Table$create(a = 2:42, b = Scalar$create("y")),
+    Table$create(a = 8:10, b = Scalar$create("z"))
+  )
+  expected <- Table$create(do.call(rbind, lapply(tables, function(table) as.data.frame(table))))
+  actual <- do.call(rbind, tables)
+  expect_equal(actual, expected, ignore_attr = TRUE)
+})
+
+test_that("Table supports cbind", {
+  expect_error(
+    cbind(
+      Table$create(a = 1:10, ),
+      Table$create(a = c("a", "b"))
+    ),
+    regexp = "Non-scalar inputs must have an equal number of rows"
+  )
+
+  tables <- list(
+    Table$create(a = 1:10, b = Scalar$create("x")),
+    Table$create(a = 11:20, b = Scalar$create("y")),
+    Table$create(c = rnorm(10))
+  )
+  expected <- Table$create(do.call(cbind, lapply(tables, function(table) as.data.frame(table))))
+  actual <- do.call(cbind, tables)
+  expect_equal(actual, expected, ignore_attr = TRUE)
+
+  # Handles a variety of input types
+  inputs <- list(
+    Table$create(a = 1L:2L),
+    b = Array$create(4:5),
+    c = factor(c("a", "b")),
+    d = 1L
+  )
+
+  r_inputs <- inputs
+  r_inputs[[1]] <- as.data.frame(r_inputs[[1]])
+  r_inputs[["b"]] <- as.vector(r_inputs[["b"]])
+
+  expected <- Table$create(do.call(cbind, r_inputs))
+  actual <- do.call(cbind, inputs)
+  expect_equal(expected, actual, ignore_attr = TRUE)
+  expect_equal(names(actual), c("a", "b", "c", "d"))
+})
+
+test_that("cbind.Table handles record batches and tables", {
+  # R 3.6 cbind dispatch rules cause cbind to fall back to default impl if
+  # there are multiple arguments with distinct cbind implementations
+  skip_if(getRversion() < "4.0.0", "R 3.6 cbind dispatch rules prevent this behavior")
+
+  inputs <- list(
+    Table$create(a = 1L:2L),
+    RecordBatch$create(b = 4:5)
+  )
+
+  expected <- Table$create(do.call(cbind, map(inputs, as.data.frame)))
+  actual <- do.call(cbind, inputs)
+  expect_equal(expected, actual, ignore_attr = TRUE)
+  expect_equal(names(actual), c("a", "b"))

Review Comment:
   Again, the `do.call()` hides what is being tested... I think it will be clearer if it is just typed out.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org