You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/12 21:22:38 UTC
[arrow-adbc] branch main updated: fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 0e03c922 fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)
0e03c922 is described below
commit 0e03c922afa3284425d1fce726e48aa971209888
Author: Kirill Müller <kr...@users.noreply.github.com>
AuthorDate: Mon Jun 12 23:22:32 2023 +0200
fix(c/driver/sqlite): Fix parameter binding when inferring types and when retrieving (#742)
Needs tests on the C side and perhaps also on the R side. Please advise.
Closes #734.
``` r
library(adbcdrivermanager)
# pkgload::load_all()
# Use the driver manager to connect to a database
db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:")
con <- adbc_connection_init(db)
# Write a table
flights <- nycflights13::flights
# (timestamp not supported yet)
flights$time_hour <- NULL
stmt <- adbc_statement_init(con, adbc.ingest.target_table = "flights")
adbc_statement_bind(stmt, flights)
adbc_statement_execute_query(stmt)
#> [1] 336776
adbc_statement_release(stmt)
# March flights
stmt <- adbc_statement_init(con)
adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = 3 LIMIT 2")
stream <- nanoarrow::nanoarrow_allocate_array_stream()
adbc_statement_execute_query(stmt, stream)
#> [1] -1
result <- tibble::as_tibble(stream)
adbc_statement_release(stmt)
result
#> # A tibble: 2 × 18
#> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2013 3 1 4 2159 125 318 56
#> 2 2013 3 1 50 2358 52 526 438
#> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
#> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#> # hour <dbl>, minute <dbl>
# Query with a parameter placeholder but no bound value (returns zero rows)
stmt <- adbc_statement_init(con)
adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2")
stream <- nanoarrow::nanoarrow_allocate_array_stream()
adbc_statement_execute_query(stmt, stream)
#> [1] -1
result <- tibble::as_tibble(stream)
adbc_statement_release(stmt)
result
#> # A tibble: 0 × 18
#> # ℹ 18 variables: year <dbl>, month <dbl>, day <dbl>, dep_time <dbl>,
#> # sched_dep_time <dbl>, dep_delay <dbl>, arr_time <dbl>,
#> # sched_arr_time <dbl>, arr_delay <dbl>, carrier <dbl>, flight <dbl>,
#> # tailnum <dbl>, origin <dbl>, dest <dbl>, air_time <dbl>, distance <dbl>,
#> # hour <dbl>, minute <dbl>
# March flights with a parameter
stmt <- adbc_statement_init(con)
adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ? LIMIT 2")
adbc_statement_bind_stream(stmt, data.frame(a = 3))
stream <- nanoarrow::nanoarrow_allocate_array_stream()
adbc_statement_execute_query(stmt, stream)
#> [1] -1
result <- tibble::as_tibble(stream)
adbc_statement_release(stmt)
result
#> # A tibble: 2 × 18
#> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2013 3 1 4 2159 125 318 56
#> 2 2013 3 1 50 2358 52 526 438
#> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
#> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#> # hour <dbl>, minute <dbl>
# Flights for several months: one placeholder bound to multiple values (2 through 4)
stmt <- adbc_statement_init(con)
adbc_statement_set_sql_query(stmt, "SELECT * from flights WHERE month = ?")
adbc_statement_bind_stream(stmt, data.frame(a = 2:4))
stream <- nanoarrow::nanoarrow_allocate_array_stream()
adbc_statement_execute_query(stmt, stream)
#> [1] -1
result <- tibble::as_tibble(stream)
adbc_statement_release(stmt)
result
#> # A tibble: 24,951 × 18
#> year month day dep_time sched_dep_time dep_delay arr_time sched_arr_time
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2013 2 1 456 500 -4 652 648
#> 2 2013 2 1 520 525 -5 816 820
#> 3 2013 2 1 527 530 -3 837 829
#> 4 2013 2 1 532 540 -8 1007 1017
#> 5 2013 2 1 540 540 0 859 850
#> 6 2013 2 1 552 600 -8 714 715
#> 7 2013 2 1 552 600 -8 919 910
#> 8 2013 2 1 552 600 -8 655 709
#> 9 2013 2 1 553 600 -7 833 815
#> 10 2013 2 1 553 600 -7 821 825
#> # ℹ 24,941 more rows
#> # ℹ 10 more variables: arr_delay <dbl>, carrier <chr>, flight <dbl>,
#> # tailnum <chr>, origin <chr>, dest <chr>, air_time <dbl>, distance <dbl>,
#> # hour <dbl>, minute <dbl>
# Clean up
adbc_connection_release(con)
adbc_database_release(db)
```
<sup>Created on 2023-06-08 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
---------
Co-authored-by: David Li <li...@gmail.com>
---
c/driver/sqlite/sqlite_test.cc | 41 ++++++++++++++++++
c/driver/sqlite/statement_reader.c | 85 +++++++++++++++++++++++---------------
c/validation/adbc_validation.cc | 3 ++
3 files changed, 96 insertions(+), 33 deletions(-)
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index a7088850..5bcca0f7 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -536,6 +536,47 @@ TEST_F(SqliteReaderTest, InferTypedParams) {
"[SQLite] Type mismatch in column 0: expected INT64 but got DOUBLE"));
}
+TEST_F(SqliteReaderTest, MultiValueParams) {
+ // Regression test for apache/arrow-adbc#734
+ adbc_validation::StreamReader reader;
+ Handle<struct ArrowSchema> schema;
+ Handle<struct ArrowArray> batch;
+
+ ASSERT_NO_FATAL_FAILURE(Exec("CREATE TABLE foo (col)"));
+ ASSERT_NO_FATAL_FAILURE(
+ Exec("INSERT INTO foo VALUES (1), (2), (2), (3), (3), (3), (4), (4), (4), (4)"));
+
+ ASSERT_THAT(adbc_validation::MakeSchema(&schema.value, {{"", NANOARROW_TYPE_INT64}}),
+ IsOkErrno());
+ ASSERT_THAT(adbc_validation::MakeBatch<int64_t>(&schema.value, &batch.value,
+ /*error=*/nullptr, {4, 1, 3, 2}),
+ IsOkErrno());
+
+ ASSERT_NO_FATAL_FAILURE(Bind(&batch.value, &schema.value));
+ ASSERT_NO_FATAL_FAILURE(
+ Exec("SELECT col FROM foo WHERE col = ?", /*infer_rows=*/3, &reader));
+ ASSERT_EQ(1, reader.schema->n_children);
+ ASSERT_EQ(NANOARROW_TYPE_INT64, reader.fields[0].type);
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NO_FATAL_FAILURE(
+ CompareArray<int64_t>(reader.array_view->children[0], {4, 4, 4}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NO_FATAL_FAILURE(
+ CompareArray<int64_t>(reader.array_view->children[0], {4, 1, 3}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NO_FATAL_FAILURE(
+ CompareArray<int64_t>(reader.array_view->children[0], {3, 3, 2}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[0], {2}));
+
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ ASSERT_EQ(nullptr, reader.array->release);
+}
+
template <typename CType>
class SqliteNumericParamTest : public SqliteReaderTest,
public ::testing::WithParamInterface<ArrowType> {
diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c
index abde44a2..2b17364a 100644
--- a/c/driver/sqlite/statement_reader.c
+++ b/c/driver/sqlite/statement_reader.c
@@ -382,35 +382,41 @@ int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out
sqlite3_mutex_enter(sqlite3_db_mutex(reader->db));
while (batch_size < reader->batch_size) {
- if (reader->binder) {
- char finished = 0;
- struct AdbcError error = {0};
- AdbcStatusCode status = AdbcSqliteBinderBindNext(reader->binder, reader->db,
- reader->stmt, &finished, &error);
- if (status != ADBC_STATUS_OK) {
- reader->done = 1;
- status = EIO;
- if (error.release) {
- strncpy(reader->error.message, error.message, sizeof(reader->error.message));
- reader->error.message[sizeof(reader->error.message) - 1] = '\0';
- error.release(&error);
- }
- break;
- } else if (finished) {
+ int rc = sqlite3_step(reader->stmt);
+ if (rc == SQLITE_DONE) {
+ if (!reader->binder) {
reader->done = 1;
break;
+ } else {
+ char finished = 0;
+ struct AdbcError error = {0};
+ status = AdbcSqliteBinderBindNext(reader->binder, reader->db, reader->stmt,
+ &finished, &error);
+ if (status != ADBC_STATUS_OK) {
+ reader->done = 1;
+ status = EIO;
+ if (error.release) {
+ strncpy(reader->error.message, error.message, sizeof(reader->error.message));
+ reader->error.message[sizeof(reader->error.message) - 1] = '\0';
+ error.release(&error);
+ }
+ break;
+ } else if (finished) {
+ reader->done = 1;
+ break;
+ }
+ continue;
}
- }
-
- int rc = sqlite3_step(reader->stmt);
- if (rc == SQLITE_DONE) {
- reader->done = 1;
- break;
} else if (rc == SQLITE_ERROR) {
reader->done = 1;
status = EIO;
StatementReaderSetError(reader);
break;
+ } else if (rc != SQLITE_ROW) {
+ reader->done = 1;
+ status = ADBC_STATUS_INTERNAL;
+ StatementReaderSetError(reader);
+ break;
}
for (int col = 0; col < reader->schema.n_children; col++) {
@@ -836,26 +842,39 @@ AdbcStatusCode AdbcSqliteExportReader(sqlite3* db, sqlite3_stmt* stmt,
AdbcStatusCode status = StatementReaderInitializeInfer(
num_columns, batch_size, validity, data, binary, current_type, error);
- if (status == ADBC_STATUS_OK) {
+
+ if (binder) {
+ char finished = 0;
+ status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
+ if (finished) {
+ reader->done = 1;
+ }
+ }
+
+ if (status == ADBC_STATUS_OK && !reader->done) {
int64_t num_rows = 0;
while (num_rows < batch_size) {
- if (binder) {
- char finished = 0;
- status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
- if (status != ADBC_STATUS_OK) break;
- if (finished) {
+ int rc = sqlite3_step(stmt);
+ if (rc == SQLITE_DONE) {
+ if (!binder) {
reader->done = 1;
break;
+ } else {
+ char finished = 0;
+ status = AdbcSqliteBinderBindNext(binder, db, stmt, &finished, error);
+ if (status != ADBC_STATUS_OK) break;
+ if (finished) {
+ reader->done = 1;
+ break;
+ }
}
- }
-
- int rc = sqlite3_step(stmt);
- if (rc == SQLITE_DONE) {
- reader->done = 1;
- break;
+ continue;
} else if (rc == SQLITE_ERROR) {
status = ADBC_STATUS_IO;
break;
+ } else if (rc != SQLITE_ROW) {
+ status = ADBC_STATUS_INTERNAL;
+ break;
}
for (int col = 0; col < num_columns; col++) {
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index b99f469d..8c25f11f 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1463,6 +1463,9 @@ void StatementTest::TestSqlPrepareSelectParams() {
auto start = nrows;
auto end = nrows + reader.array->length;
+ ASSERT_LT(start, expected_int32.size());
+ ASSERT_LE(end, expected_int32.size());
+
switch (reader.fields[0].type) {
case NANOARROW_TYPE_INT32:
ASSERT_NO_FATAL_FAILURE(CompareArray<int32_t>(