You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/04/22 06:43:12 UTC

[arrow-adbc] branch main updated: feat(ruby): add support for statement.ingest("table", table) (#601)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b02cf4  feat(ruby): add support for statement.ingest("table", table) (#601)
8b02cf4 is described below

commit 8b02cf43fa1986847ae41d2f24919ad668c10e8c
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Sat Apr 22 15:43:07 2023 +0900

    feat(ruby): add support for statement.ingest("table", table) (#601)
    
    Fixes #600.
---
 ruby/lib/adbc/statement.rb  |  3 ++
 ruby/test/test-statement.rb | 71 ++++++++++++++++++++++++++++-----------------
 2 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/ruby/lib/adbc/statement.rb b/ruby/lib/adbc/statement.rb
index 90f4c23..383267e 100644
--- a/ruby/lib/adbc/statement.rb
+++ b/ruby/lib/adbc/statement.rb
@@ -62,6 +62,9 @@ module ADBC
         message = "wrong number of arguments (given #{n_args}, expected 1 with block)"
         raise ArgumentError, message unless n_args == 1
         values = args[0]
+        if values.is_a?(Arrow::Table)
+          values = Arrow::TableBatchReader.new(values)
+        end
         if values.is_a?(Arrow::RecordBatchReader)
           c_abi_array_stream = values.export
           begin
diff --git a/ruby/test/test-statement.rb b/ruby/test/test-statement.rb
index cac5b94..d306ea6 100644
--- a/ruby/test/test-statement.rb
+++ b/ruby/test/test-statement.rb
@@ -32,33 +32,50 @@ class StatementTest < Test::Unit::TestCase
     end
   end
 
-  def test_ingest
-    numbers = Arrow::Int64Array.new([10, 20, 30])
-    record_batch = Arrow::RecordBatch.new(number: numbers)
-    @statement.ingest("data", record_batch)
-    table, n_rows_affected = @statement.query("SELECT * FROM data")
-    assert_equal([
-                   Arrow::Table.new(number: numbers),
-                   -1,
-                 ],
-                 [
-                   table,
-                   n_rows_affected,
-                 ])
-  end
+  sub_test_case("#ingest") do
+    test("Arrow::RecordBatch") do
+      numbers = Arrow::Int64Array.new([10, 20, 30])
+      record_batch = Arrow::RecordBatch.new(number: numbers)
+      @statement.ingest("data", record_batch)
+      table, n_rows_affected = @statement.query("SELECT * FROM data")
+      assert_equal([
+                     Arrow::Table.new(number: numbers),
+                     -1,
+                   ],
+                   [
+                     table,
+                     n_rows_affected,
+                   ])
+    end
 
-  def test_ingest_stream
-    numbers = Arrow::Int64Array.new([10, 20, 30])
-    record_batch = Arrow::RecordBatch.new(number: numbers)
-    @statement.ingest("data", Arrow::RecordBatchReader.new([record_batch]))
-    table, n_rows_affected = @statement.query("SELECT * FROM data")
-    assert_equal([
-                   Arrow::Table.new(number: numbers),
-                   -1,
-                 ],
-                 [
-                   table,
-                   n_rows_affected,
-                 ])
+    test("Arrow::RecordBatchReader") do
+      numbers = Arrow::Int64Array.new([10, 20, 30])
+      record_batch = Arrow::RecordBatch.new(number: numbers)
+      @statement.ingest("data", Arrow::RecordBatchReader.new([record_batch]))
+      table, n_rows_affected = @statement.query("SELECT * FROM data")
+      assert_equal([
+                     Arrow::Table.new(number: numbers),
+                     -1,
+                   ],
+                   [
+                     table,
+                     n_rows_affected,
+                   ])
+    end
+
+    test("Arrow::Table") do
+      numbers = Arrow::Int64Array.new([10, 20, 30])
+      input_table = Arrow::Table.new(number: numbers)
+      @statement.ingest("data", input_table)
+      table, n_rows_affected = @statement.query("SELECT * FROM data")
+      assert_equal([
+                     input_table,
+                     -1,
+                   ],
+                   [
+                     table,
+                     n_rows_affected,
+                   ])
+    end
   end
 end