You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/12 21:22:38 UTC

[arrow-adbc] branch main updated: fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e03c922 fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)
0e03c922 is described below

commit 0e03c922afa3284425d1fce726e48aa971209888
Author: Kirill Müller <kr...@users.noreply.github.com>
AuthorDate: Mon Jun 12 23:22:32 2023 +0200

    fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)
    
    Needs tests on the C side and perhaps also on the R side. Please advise.
    
    Closes #734.
    
    ``` r
    library(adbcdrivermanager)
    # pkgload::load_all()
    
    # Use the driver manager to connect to a database
    db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:")
    con <- adbc_connection_init(db)
    
    # Write a table
    flights <- nycflights13::flights
    # Drop the timestamp column (timestamps are not supported yet)
    flights$time_hour <- NULL
    
    stmt <- adbc_statement_init(con, adbc.ingest.target_table = "flights")
    adbc_statement_bind(stmt, flights)
    adbc_statement_execute_query(stmt)
    #> [1] 336776
    adbc_statement_release(stmt)
    
    # March flights
    stmt <- adbc_statement_init(con)
    adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = 3 LIMIT 2")
    stream <- nanoarrow::nanoarrow_allocate_array_stream()
    adbc_statement_execute_query(stmt, stream)
    #> [1] -1
    
    
    result <- tibble::as_tibble(stream)
    adbc_statement_release(stmt)
    
    result
    #> # A tibble: 2 × 18
    #>    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>   <dbl> <dbl> <dbl>    <dbl>          <dbl>     <dbl>    <dbl>          <dbl>
    #> 1  2013     3     1        4           2159       125      318             56
    #> 2  2013     3     1       50           2358        52      526            438
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    
    
    # March flights with a parameter placeholder, but no value bound to it
    # (SQLite treats an unbound parameter as NULL, so no rows match)
    stmt <- adbc_statement_init(con)
    adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2")
    stream <- nanoarrow::nanoarrow_allocate_array_stream()
    adbc_statement_execute_query(stmt, stream)
    #> [1] -1
    result <- tibble::as_tibble(stream)
    adbc_statement_release(stmt)
    
    result
    #> # A tibble: 0 × 18
    #> # ℹ 18 variables: year <dbl>, month <dbl>, day <dbl>, dep_time <dbl>,
    #> #   sched_dep_time <dbl>, dep_delay <dbl>, arr_time <dbl>,
    #> #   sched_arr_time <dbl>, arr_delay <dbl>, carrier <dbl>, flight <dbl>,
    #> #   tailnum <dbl>, origin <dbl>, dest <dbl>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    
    # March flights with a parameter
    stmt <- adbc_statement_init(con)
    adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2")
    adbc_statement_bind_stream(stmt, data.frame(a = 3))
    stream <- nanoarrow::nanoarrow_allocate_array_stream()
    adbc_statement_execute_query(stmt, stream)
    #> [1] -1
    result <- tibble::as_tibble(stream)
    adbc_statement_release(stmt)
    
    result
    #> # A tibble: 2 × 18
    #>    year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>   <dbl> <dbl> <dbl>    <dbl>          <dbl>     <dbl>    <dbl>          <dbl>
    #> 1  2013     3     1        4           2159       125      318             56
    #> 2  2013     3     1       50           2358        52      526            438
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    
    # February-April flights: one parameter bound to multiple values (2:4)
    stmt <- adbc_statement_init(con)
    adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ?")
    adbc_statement_bind_stream(stmt, data.frame(a = 2:4))
    stream <- nanoarrow::nanoarrow_allocate_array_stream()
    adbc_statement_execute_query(stmt, stream)
    #> [1] -1
    result <- tibble::as_tibble(stream)
    adbc_statement_release(stmt)
    
    result
    #> # A tibble: 24,951 × 18
    #>     year month   day dep_time sched_dep_time dep_delay arr_time sched_arr_time
    #>    <dbl> <dbl> <dbl>    <dbl>          <dbl>     <dbl>    <dbl>          <dbl>
    #>  1  2013     2     1      456            500        -4      652            648
    #>  2  2013     2     1      520            525        -5      816            820
    #>  3  2013     2     1      527            530        -3      837            829
    #>  4  2013     2     1      532            540        -8     1007           1017
    #>  5  2013     2     1      540            540         0      859            850
    #>  6  2013     2     1      552            600        -8      714            715
    #>  7  2013     2     1      552            600        -8      919            910
    #>  8  2013     2     1      552            600        -8      655            709
    #>  9  2013     2     1      553            600        -7      833            815
    #> 10  2013     2     1      553            600        -7      821            825
    #> # ℹ 24,941 more rows
    #> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
    #> #   tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
    #> #   hour <dbl>, minute <dbl>
    
    # Clean up
    adbc_connection_release(con)
    adbc_database_release(db)
    ```
    
    <sup>Created on 2023-06-08 with [reprex
    v2.0.2](https://reprex.tidyverse.org)</sup>
    
    ---------
    
    Co-authored-by: David Li <li...@gmail.com>
---
 c/driver/sqlite/sqlite_test.cc     | 41 ++++++++++++++++++
 c/driver/sqlite/statement_reader.c | 85 +++++++++++++++++++++++---------------
 c/validation/adbc_validation.cc    |  3 ++
 3 files changed, 96 insertions(+), 33 deletions(-)

diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index a7088850..5bcca0f7 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -536,6 +536,47 @@ TEST_F(SqliteReaderTest, InferTypedParams) {
                   "[SQLite] Type mismatch in column 0: expected INT64 but got DOUBLE"));
 }
 
+TEST_F(SqliteReaderTest, MultiValueParams) {
+  // Regression test for apache/arrow-adbc#734
+  adbc_validation::StreamReader reader;
+  Handle<struct ArrowSchema> schema;
+  Handle<struct ArrowArray> batch;
+
+  ASSERT_NO_FATAL_FAILURE(Exec("CREATE TABLE foo (col)"));
+  ASSERT_NO_FATAL_FAILURE(
+      Exec("INSERT INTO foo VALUES (1), (2), (2), (3), (3), (3), (4), (4), (4), (4)"));
+
+  ASSERT_THAT(adbc_validation::MakeSchema(&schema.value, {{"", NANOARROW_TYPE_INT64}}),
+              IsOkErrno());
+  ASSERT_THAT(adbc_validation::MakeBatch<int64_t>(&schema.value, &batch.value,
+                                                  /*error=*/nullptr, {4, 1, 3, 2}),
+              IsOkErrno());
+
+  ASSERT_NO_FATAL_FAILURE(Bind(&batch.value, &schema.value));
+  ASSERT_NO_FATAL_FAILURE(
+      Exec("SELECT col FROM foo WHERE col = ?", /*infer_rows=*/3, &reader));
+  ASSERT_EQ(1, reader.schema->n_children);
+  ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[0].type);
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NO_FATAL_FAILURE(
+      CompareArray<int64_t>(reader.array_view->children[0], {4, 4, 4}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NO_FATAL_FAILURE(
+      CompareArray<int64_t>(reader.array_view->children[0], {4, 1, 3}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NO_FATAL_FAILURE(
+      CompareArray<int64_t>(reader.array_view->children[0], {3, 3, 2}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[0], {2}));
+
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_EQ(nullptr, reader.array->release);
+}
+
 template <typename CType>
 class SqliteNumericParamTest : public SqliteReaderTest,
                                public ::testing::WithParamInterface<ArrowType> {
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index abde44a2..2b17364a 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -382,35 +382,41 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out
 
   sqlite3_mutex_enter(sqlite3_db_mutex(reader->db));
   while (batch_size < reader->batch_size) {
-    if (reader->binder) {
-      char finished = 0;
-      struct AdbcError error = {0};
-      AdbcStatusCode status = AdbcSqliteBinderBindNext(reader->binder, reader->db,
-                                                       reader->stmt, &finished, &error);
-      if (status != ADBC_STATUS_OK) {
-        reader->done = 1;
-        status = EIO;
-        if (error.release) {
-          strncpy(reader->error.message, error.message, sizeof(reader->error.message));
-          reader->error.message[sizeof(reader->error.message) - 1] = '\0';
-          error.release(&error);
-        }
-        break;
-      } else if (finished) {
+    int rc = sqlite3_step(reader->stmt);
+    if (rc == SQLITE_DONE) {
+      if (!reader->binder) {
         reader->done = 1;
         break;
+      } else {
+        char finished = 0;
+        struct AdbcError error = {0};
+        status = AdbcSqliteBinderBindNext(reader->binder, reader->db, reader->stmt,
+                                          &finished, &error);
+        if (status != ADBC_STATUS_OK) {
+          reader->done = 1;
+          status = EIO;
+          if (error.release) {
+            strncpy(reader->error.message, error.message, sizeof(reader->error.message));
+            reader->error.message[sizeof(reader->error.message) - 1] = '\0';
+            error.release(&error);
+          }
+          break;
+        } else if (finished) {
+          reader->done = 1;
+          break;
+        }
+        continue;
       }
-    }
-
-    int rc = sqlite3_step(reader->stmt);
-    if (rc == SQLITE_DONE) {
-      reader->done = 1;
-      break;
     } else if (rc == SQLITE_ERROR) {
       reader->done = 1;
       status = EIO;
       StatementReaderSetError(reader);
       break;
+    } else if (rc != SQLITE_ROW) {
+      reader->done = 1;
+      status = ADBC_STATUS_INTERNAL;
+      StatementReaderSetError(reader);
+      break;
     }
 
     for (int col = 0; col < reader->schema.n_children; col++) {
@@ -836,26 +842,39 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt,
 
   AdbcStatusCode status = StatementReaderInitializeInfer(
       num_columns, batch_size, validity, data, binary, current_type, error);
-  if (status == ADBC_STATUS_OK) {
+
+  if (binder) {
+    char finished = 0;
+    status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
+    if (finished) {
+      reader->done = 1;
+    }
+  }
+
+  if (status == ADBC_STATUS_OK && !reader->done) {
     int64_t num_rows = 0;
     while (num_rows < batch_size) {
-      if (binder) {
-        char finished = 0;
-        status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
-        if (status != ADBC_STATUS_OK) break;
-        if (finished) {
+      int rc = sqlite3_step(stmt);
+      if (rc == SQLITE_DONE) {
+        if (!binder) {
           reader->done = 1;
           break;
+        } else {
+          char finished = 0;
+          status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
+          if (status != ADBC_STATUS_OK) break;
+          if (finished) {
+            reader->done = 1;
+            break;
+          }
         }
-      }
-
-      int rc = sqlite3_step(stmt);
-      if (rc == SQLITE_DONE) {
-        reader->done = 1;
-        break;
+        continue;
       } else if (rc == SQLITE_ERROR) {
         status = ADBC_STATUS_IO;
         break;
+      } else if (rc != SQLITE_ROW) {
+        status = ADBC_STATUS_INTERNAL;
+        break;
       }
 
       for (int col = 0; col < num_columns; col++) {
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index b99f469d..8c25f11f 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1463,6 +1463,9 @@ void StatementTest::TestSqlPrepareSelectParams() {
     auto start = nrows;
     auto end = nrows + reader.array->length;
 
+    ASSERT_LT(start, expected_int32.size());
+    ASSERT_LE(end, expected_int32.size());
+
     switch (reader.fields[0].type) {
       case NANOARROW_TYPE_INT32:
         ASSERT_NO_FATAL_FAILURE(CompareArray<int32_t>(