You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/06/09 14:53:20 UTC
[arrow-adbc] branch main updated: Refactor initialization for consistency (#6)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a9f3e78 Refactor initialization for consistency (#6)
a9f3e78 is described below
commit a9f3e78eac0d56a991dc19fd23c4b6b88419ea97
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jun 9 10:53:15 2022 -0400
Refactor initialization for consistency (#6)
* Move away from ODBC-style parameter strings
Instead, start refactoring things to follow a multi-step init
similar to DuckDB: New(), then SetOption(), then Init().
Also, update the driver manager so callers no longer have to load the
driver themselves via AdbcLoadDriver(). Instead, just call
AdbcDatabaseNew(), set the "driver" and "entrypoint" options to
indicate the driver to load, then call AdbcDatabaseInit().
* Refactor how statements are handled
Make them more consistent and reduce the surface of the
SQL-specific portion.
* Fix formatting
---
adbc.h | 166 +++++-----
adbc_driver_manager/adbc_driver_manager.cc | 237 ++++++++------
adbc_driver_manager/adbc_driver_manager.h | 14 +-
adbc_driver_manager/adbc_driver_manager_test.cc | 64 ++--
drivers/flight_sql/flight_sql.cc | 394 +++++++++++++-----------
drivers/flight_sql/flight_sql_test.cc | 90 +++---
drivers/sqlite/sqlite.cc | 323 ++++++++++---------
drivers/sqlite/sqlite_test.cc | 275 +++++++----------
drivers/test_util.h | 2 +-
drivers/util.h | 6 +
10 files changed, 832 insertions(+), 739 deletions(-)
diff --git a/adbc.h b/adbc.h
index 29aff22..a692e91 100644
--- a/adbc.h
+++ b/adbc.h
@@ -181,18 +181,6 @@ struct AdbcError {
/// common connection state.
/// @{
-/// \brief A set of database options.
-struct AdbcDatabaseOptions {
- /// \brief A driver-specific database string.
- ///
- /// Should be in ODBC-style format ("Key1=Value1;Key2=Value2").
- const char* target;
-
- /// \brief The associated driver. Required if using the driver
- /// manager; not required if directly calling into a driver.
- AdbcDriver* driver;
-};
-
/// \brief An instance of a database.
///
/// Must be kept alive as long as any connections exist.
@@ -205,9 +193,18 @@ struct AdbcDatabase {
AdbcDriver* private_driver;
};
-/// \brief Initialize a new database.
-AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
- struct AdbcDatabase* out, struct AdbcError* error);
+/// \brief Allocate a new (but uninitialized) database.
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error);
+
+/// \brief Set a char* option.
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error);
+
+/// \brief Finish setting options and initialize the database.
+///
+/// Some backends may support setting options after initialization
+/// as well.
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error);
/// \brief Destroy this database. No connections may exist.
/// \param[in] database The database to release.
@@ -248,9 +245,16 @@ struct AdbcConnection {
AdbcDriver* private_driver;
};
-/// \brief Create a new connection to a database.
-AdbcStatusCode AdbcConnectionInit(const struct AdbcConnectionOptions* options,
- struct AdbcConnection* connection,
+/// \brief Allocate a new (but uninitialized) connection.
+AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
+ struct AdbcConnection* connection,
+ struct AdbcError* error);
+
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error);
+
+/// \brief Finish setting options and initialize the connection.
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
struct AdbcError* error);
/// \brief Destroy this connection.
@@ -260,49 +264,6 @@ AdbcStatusCode AdbcConnectionInit(const struct AdbcConnectionOptions* options,
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error);
-/// \defgroup adbc-connection-sql SQL Semantics
-/// Functions for executing SQL queries, or querying SQL-related
-/// metadata. Drivers are not required to support both SQL and
-/// Substrait semantics. If they do, it may be via converting
-/// between representations internally.
-/// @{
-
-/// \brief Execute a one-shot query.
-///
-/// For queries expected to be executed repeatedly, create a
-/// prepared statement.
-///
-/// \param[in] connection The database connection.
-/// \param[in] query The query to execute.
-/// \param[in,out] statement The result set. Allocate with AdbcStatementInit.
-/// \param[out] error Error details, if an error occurs.
-AdbcStatusCode AdbcConnectionSqlExecute(struct AdbcConnection* connection,
- const char* query,
- struct AdbcStatement* statement,
- struct AdbcError* error);
-
-/// \brief Prepare a query to be executed multiple times.
-///
-/// TODO: this should return AdbcPreparedStatement to disaggregate
-/// preparation and execution
-AdbcStatusCode AdbcConnectionSqlPrepare(struct AdbcConnection* connection,
- const char* query,
- struct AdbcStatement* statement,
- struct AdbcError* error);
-
-/// }@
-
-/// \defgroup adbc-connection-substrait Substrait Semantics
-/// Functions for executing Substrait plans, or querying
-/// Substrait-related metadata. Drivers are not required to support
-/// both SQL and Substrait semantics. If they do, it may be via
-/// converting between representations internally.
-/// @{
-
-// TODO: not yet defined
-
-/// }@
-
/// \defgroup adbc-connection-partition Partitioned Results
/// Some databases may internally partition the results. These
/// partitions are exposed to clients who may wish to integrate them
@@ -449,13 +410,19 @@ struct AdbcStatement {
};
/// \brief Create a new statement for a given connection.
-AdbcStatusCode AdbcStatementInit(struct AdbcConnection* connection,
- struct AdbcStatement* statement,
- struct AdbcError* error);
+///
+/// Set options on the statement, then call AdbcStatementExecute or
+/// AdbcStatementPrepare.
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement, struct AdbcError* error);
-/// \brief Set an integer option on a statement.
-AdbcStatusCode AdbcStatementSetOptionInt64(struct AdbcStatement* statement,
- struct AdbcError* error);
+/// \brief Execute a statement.
+AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
+ struct AdbcError* error);
+
+/// \brief Create a prepared statement to be executed multiple times.
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error);
/// \brief Destroy a statement.
/// \param[in] statement The statement to release.
@@ -464,6 +431,38 @@ AdbcStatusCode AdbcStatementSetOptionInt64(struct AdbcStatement* statement,
AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
struct AdbcError* error);
+/// \defgroup adbc-statement-sql SQL Semantics
+/// Functions for executing SQL queries, or querying SQL-related
+/// metadata. Drivers are not required to support both SQL and
+/// Substrait semantics. If they do, it may be via converting
+/// between representations internally.
+/// @{
+
+/// \brief Set the SQL query to execute.
+///
+/// Execute it with AdbcStatementExecute; for queries expected to be
+/// executed repeatedly, call AdbcStatementPrepare first.
+///
+/// \param[in] connection The statement to set the query on (an
+///   AdbcStatement, despite the parameter name).
+/// \param[in] query The SQL query to set.
+/// \param[out] error Error details, if an error occurs.
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* connection,
+ const char* query, struct AdbcError* error);
+
+/// }@
+
+/// \defgroup adbc-statement-substrait Substrait Semantics
+/// Functions for executing Substrait plans, or querying
+/// Substrait-related metadata. Drivers are not required to support
+/// both SQL and Substrait semantics. If they do, it may be via
+/// converting between representations internally.
+/// @{
+
+// TODO: not yet defined
+
+/// }@
+
/// \brief Bind parameter values for parameterized statements.
/// \param[in] statement The statement to bind to.
/// \param[in] values The values to bind. The driver will not call the
@@ -476,12 +475,6 @@ AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
struct ArrowArray* values, struct ArrowSchema* schema,
struct AdbcError* error);
-/// \brief Execute a statement.
-///
-/// Not called for one-shot queries (e.g. AdbcConnectionSqlExecute).
-AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
- struct AdbcError* error);
-
/// \brief Read the result of a statement.
///
/// This method can be called only once per execution of the
@@ -567,12 +560,17 @@ struct AdbcDriver {
void* private_data;
// TODO: DriverRelease
- AdbcStatusCode (*DatabaseInit)(const struct AdbcDatabaseOptions*, struct AdbcDatabase*,
- struct AdbcError*);
+ AdbcStatusCode (*DatabaseNew)(struct AdbcDatabase*, struct AdbcError*);
+ AdbcStatusCode (*DatabaseSetOption)(struct AdbcDatabase*, const char*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*DatabaseInit)(struct AdbcDatabase*, struct AdbcError*);
AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*);
- AdbcStatusCode (*ConnectionInit)(const struct AdbcConnectionOptions*,
- struct AdbcConnection*, struct AdbcError*);
+ AdbcStatusCode (*ConnectionNew)(struct AdbcDatabase*, struct AdbcConnection*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionSetOption)(struct AdbcConnection*, const char*, const char*,
+ struct AdbcError*);
+ AdbcStatusCode (*ConnectionInit)(struct AdbcConnection*, struct AdbcError*);
AdbcStatusCode (*ConnectionRelease)(struct AdbcConnection*, struct AdbcError*);
AdbcStatusCode (*ConnectionSqlExecute)(struct AdbcConnection*, const char*,
struct AdbcStatement*, struct AdbcError*);
@@ -593,19 +591,21 @@ struct AdbcDriver {
const char*, const char**, struct AdbcStatement*,
struct AdbcError*);
- AdbcStatusCode (*StatementInit)(struct AdbcConnection*, struct AdbcStatement*,
- struct AdbcError*);
- AdbcStatusCode (*StatementSetOptionInt64)(struct AdbcStatement*, struct AdbcError*);
+ AdbcStatusCode (*StatementNew)(struct AdbcConnection*, struct AdbcStatement*,
+ struct AdbcError*);
AdbcStatusCode (*StatementRelease)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementBind)(struct AdbcStatement*, struct ArrowArray*,
struct ArrowSchema*, struct AdbcError*);
AdbcStatusCode (*StatementExecute)(struct AdbcStatement*, struct AdbcError*);
+ AdbcStatusCode (*StatementPrepare)(struct AdbcStatement*, struct AdbcError*);
AdbcStatusCode (*StatementGetStream)(struct AdbcStatement*, struct ArrowArrayStream*,
struct AdbcError*);
AdbcStatusCode (*StatementGetPartitionDescSize)(struct AdbcStatement*, size_t*,
struct AdbcError*);
AdbcStatusCode (*StatementGetPartitionDesc)(struct AdbcStatement*, uint8_t*,
struct AdbcError*);
+ AdbcStatusCode (*StatementSetSqlQuery)(struct AdbcStatement*, const char*,
+ struct AdbcError*);
// Do not edit fields. New fields can only be appended to the end.
};
@@ -622,13 +622,13 @@ struct AdbcDriver {
/// \param[out] error An optional location to return an error message
/// if necessary.
typedef AdbcStatusCode (*AdbcDriverInitFunc)(size_t count, struct AdbcDriver* driver,
- size_t* initialized, struct AdbcError* error);
-// TODO: how best to report errors here?
+ size_t* initialized,
+ struct AdbcError* error);
// TODO: use sizeof() instead of count, or version the
// struct/entrypoint instead?
// For use with count
-#define ADBC_VERSION_0_0_1 19
+#define ADBC_VERSION_0_0_1 21
/// }@
diff --git a/adbc_driver_manager/adbc_driver_manager.cc b/adbc_driver_manager/adbc_driver_manager.cc
index a5f899d..e20815a 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -19,34 +19,11 @@
#include <dlfcn.h>
#include <algorithm>
+#include <cstring>
#include <string>
#include <unordered_map>
namespace {
-std::unordered_map<std::string, std::string> ParseConnectionString(
- const std::string& target) {
- // TODO: this does not properly implement the ODBC connection string format.
- std::unordered_map<std::string, std::string> option_pairs;
- size_t cur = 0;
-
- while (cur < target.size()) {
- auto divider = target.find('=', cur);
- if (divider == std::string::npos) break;
-
- std::string key = target.substr(cur, divider - cur);
- cur = divider + 1;
- auto end = target.find(';', cur);
- if (end == std::string::npos) {
- option_pairs.insert({std::move(key), target.substr(cur)});
- break;
- } else {
- option_pairs.insert({std::string(key), target.substr(cur, end - cur)});
- cur = end + 1;
- }
- }
- return option_pairs;
-}
-
void ReleaseError(struct AdbcError* error) {
if (error) {
delete[] error->message;
@@ -75,24 +52,89 @@ AdbcStatusCode StatementBind(struct AdbcStatement*, struct ArrowArray*,
AdbcStatusCode StatementExecute(struct AdbcStatement*, struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
+
+// Temporary
+struct TempDatabase {
+ std::unordered_map<std::string, std::string> options;
+ std::string driver;
+ std::string entrypoint;
+};
} // namespace
-#define FILL_DEFAULT(DRIVER, STUB) \
- if (!DRIVER->STUB) { \
- DRIVER->STUB = &STUB; \
+// Direct implementations of API methods
+
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
+ // Allocate a temporary structure to store options pre-Init
+ database->private_data = new TempDatabase;
+ database->private_driver = nullptr;
+ return ADBC_STATUS_OK;
+}
+
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ if (database->private_driver) {
+ return database->private_driver->DatabaseSetOption(database, key, value, error);
}
-// Direct implementations of API methods
+ TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
+ if (std::strcmp(key, "driver") == 0) {
+ args->driver = value;
+ } else if (std::strcmp(key, "entrypoint") == 0) {
+ args->entrypoint = value;
+ } else {
+ args->options[key] = value;
+ }
+ return ADBC_STATUS_OK;
+}
-AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
- struct AdbcDatabase* out, struct AdbcError* error) {
- if (!options->driver) {
- // TODO: set error
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
+ if (!database->private_data) {
+ SetError(error, "Must call AdbcDatabaseNew first");
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
+ if (args->driver.empty()) {
+ delete args;
+ SetError(error, "Must provide 'driver' parameter");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ } else if (args->entrypoint.empty()) {
+ delete args;
+ SetError(error, "Must provide 'entrypoint' parameter");
return ADBC_STATUS_INVALID_ARGUMENT;
}
- auto status = options->driver->DatabaseInit(options, out, error);
- out->private_driver = options->driver;
- return status;
+
+ database->private_driver = new AdbcDriver;
+ size_t initialized = 0;
+ AdbcStatusCode status =
+ AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_0_0_1,
+ database->private_driver, &initialized, error);
+ if (status != ADBC_STATUS_OK) {
+ delete args;
+ delete database->private_driver;
+ return status;
+ } else if (initialized < ADBC_VERSION_0_0_1) {
+ delete args;
+ delete database->private_driver;
+ SetError(error, "Database version is too old"); // TODO: clearer error
+ return status;
+ }
+ status = database->private_driver->DatabaseNew(database, error);
+ if (status != ADBC_STATUS_OK) {
+ delete args;
+ delete database->private_driver;
+ return status;
+ }
+ for (const auto& option : args->options) {
+ status = database->private_driver->DatabaseSetOption(database, option.first.c_str(),
+ option.second.c_str(), error);
+ if (status != ADBC_STATUS_OK) {
+ delete args;
+ delete database->private_driver;
+ return status;
+ }
+ }
+ delete args;
+ return database->private_driver->DatabaseInit(database, error);
}
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
@@ -101,19 +143,28 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
return ADBC_STATUS_UNINITIALIZED;
}
auto status = database->private_driver->DatabaseRelease(database, error);
- database->private_driver = nullptr;
+ delete database->private_driver;
return status;
}
-AdbcStatusCode AdbcConnectionInit(const struct AdbcConnectionOptions* options,
- struct AdbcConnection* out, struct AdbcError* error) {
- if (!options->database->private_driver) {
+AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
+ struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!database->private_driver) {
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ auto status = database->private_driver->ConnectionNew(database, connection, error);
+ connection->private_driver = database->private_driver;
+ return status;
+}
+
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!connection->private_driver) {
// TODO: set error
return ADBC_STATUS_INVALID_ARGUMENT;
}
- auto status = options->database->private_driver->ConnectionInit(options, out, error);
- out->private_driver = options->database->private_driver;
- return status;
+ return connection->private_driver->ConnectionInit(connection, error);
}
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
@@ -126,74 +177,67 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
return status;
}
-AdbcStatusCode AdbcConnectionSqlExecute(struct AdbcConnection* connection,
- const char* query,
- struct AdbcStatement* statement,
- struct AdbcError* error) {
- if (!connection->private_driver) {
+AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!statement->private_driver) {
return ADBC_STATUS_UNINITIALIZED;
}
- return connection->private_driver->ConnectionSqlExecute(connection, query, statement,
- error);
+ return statement->private_driver->StatementBind(statement, values, schema, error);
}
-AdbcStatusCode AdbcConnectionSqlPrepare(struct AdbcConnection* connection,
- const char* query,
- struct AdbcStatement* statement,
- struct AdbcError* error) {
- if (!connection->private_driver) {
+AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!statement->private_driver) {
return ADBC_STATUS_UNINITIALIZED;
}
- return connection->private_driver->ConnectionSqlPrepare(connection, query, statement,
- error);
+ return statement->private_driver->StatementExecute(statement, error);
}
-AdbcStatusCode AdbcStatementInit(struct AdbcConnection* connection,
- struct AdbcStatement* statement,
- struct AdbcError* error) {
+AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
+ struct ArrowArrayStream* out,
+ struct AdbcError* error) {
+ if (!statement->private_driver) {
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ return statement->private_driver->StatementGetStream(statement, out, error);
+}
+
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
if (!connection->private_driver) {
- // TODO: set error
return ADBC_STATUS_INVALID_ARGUMENT;
}
- auto status = connection->private_driver->StatementInit(connection, statement, error);
+ auto status = connection->private_driver->StatementNew(connection, statement, error);
statement->private_driver = connection->private_driver;
return status;
}
-AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_UNINITIALIZED;
}
- auto status = statement->private_driver->StatementRelease(statement, error);
- statement->private_driver = nullptr;
- return status;
+ return statement->private_driver->StatementPrepare(statement, error);
}
-AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
- struct ArrowArray* values, struct ArrowSchema* schema,
- struct AdbcError* error) {
- if (!statement->private_driver) {
- return ADBC_STATUS_UNINITIALIZED;
- }
- return statement->private_driver->StatementBind(statement, values, schema, error);
-}
-
-AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_UNINITIALIZED;
}
- return statement->private_driver->StatementExecute(statement, error);
+ auto status = statement->private_driver->StatementRelease(statement, error);
+ statement->private_driver = nullptr;
+ return status;
}
-AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
- struct ArrowArrayStream* out,
- struct AdbcError* error) {
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
if (!statement->private_driver) {
return ADBC_STATUS_UNINITIALIZED;
}
- return statement->private_driver->StatementGetStream(statement, out, error);
+ return statement->private_driver->StatementSetSqlQuery(statement, query, error);
}
const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
@@ -219,24 +263,20 @@ const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
#undef STRINGIFY
}
-AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
- struct AdbcDriver* driver, size_t* initialized,
- struct AdbcError* error) {
- auto params = ParseConnectionString(connection);
-
- auto driver_str = params.find("Driver");
- if (driver_str == params.end()) {
- SetError(error, "Must provide Driver parameter");
- return ADBC_STATUS_INVALID_ARGUMENT;
+AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
+ size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error) {
+#define FILL_DEFAULT(DRIVER, STUB) \
+ if (!DRIVER->STUB) { \
+ DRIVER->STUB = &STUB; \
}
-
- auto entrypoint_str = params.find("Entrypoint");
- if (entrypoint_str == params.end()) {
- SetError(error, "Must provide Entrypoint parameter");
- return ADBC_STATUS_INVALID_ARGUMENT;
+#define CHECK_REQUIRED(DRIVER, STUB) \
+ if (!DRIVER->STUB) { \
+ SetError(error, "Driver does not implement required function Adbc" #STUB); \
+ return ADBC_STATUS_INTERNAL; \
}
- void* handle = dlopen(driver_str->second.c_str(), RTLD_NOW | RTLD_LOCAL);
+ void* handle = dlopen(driver_name, RTLD_NOW | RTLD_LOCAL);
if (!handle) {
std::string message = "dlopen() failed: ";
message += dlerror();
@@ -244,7 +284,7 @@ AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
return ADBC_STATUS_UNKNOWN;
}
- void* load_handle = dlsym(handle, entrypoint_str->second.c_str());
+ void* load_handle = dlsym(handle, entrypoint);
auto* load = reinterpret_cast<AdbcDriverInitFunc>(load_handle);
if (!load) {
std::string message = "dlsym() failed: ";
@@ -258,8 +298,13 @@ AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
return result;
}
+ CHECK_REQUIRED(driver, DatabaseNew);
+ CHECK_REQUIRED(driver, DatabaseInit);
FILL_DEFAULT(driver, ConnectionSqlPrepare);
FILL_DEFAULT(driver, StatementBind);
FILL_DEFAULT(driver, StatementExecute);
return ADBC_STATUS_OK;
+
+#undef FILL_DEFAULT
+#undef CHECK_REQUIRED
}
diff --git a/adbc_driver_manager/adbc_driver_manager.h b/adbc_driver_manager/adbc_driver_manager.h
index 00f1077..3a67706 100644
--- a/adbc_driver_manager/adbc_driver_manager.h
+++ b/adbc_driver_manager/adbc_driver_manager.h
@@ -33,8 +33,10 @@ extern "C" {
/// of functionality for this to be possible, however, and some
/// functions must be implemented by the driver.
///
-/// \param[in] connection The driver to initialize. Should be in
-/// ODBC-style format ("Key1=Value1;Key2=Value2").
+/// \param[in] driver_name An identifier for the driver (e.g. a path to a
+/// shared library on Linux).
+/// \param[in] entrypoint An identifier for the entrypoint (e.g. the
+/// symbol to call for AdbcDriverInitFunc on Linux).
/// \param[in] count The number of entries to initialize. Provides
/// backwards compatibility if the struct definition is changed.
/// \param[out] driver The table of function pointers to initialize.
@@ -42,9 +44,11 @@ extern "C" {
/// initialized (can be less than count).
/// \param[out] error An optional location to return an error message
/// if necessary.
-AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
- struct AdbcDriver* driver, size_t* initialized,
- struct AdbcError* error);
+AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
+ size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error);
+
+const char* AdbcStatusCodeMessage(AdbcStatusCode code);
#endif // ADBC_DRIVER_MANAGER_H
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc b/adbc_driver_manager/adbc_driver_manager_test.cc
index 65ec7e4..ae53e84 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -36,22 +36,24 @@ class DriverManager : public ::testing::Test {
public:
void SetUp() override {
size_t initialized = 0;
- ADBC_ASSERT_OK_WITH_ERROR(error,
- AdbcLoadDriver("Driver=libadbc_driver_sqlite.so;Entrypoint=AdbcSqliteDriverInit",
- ADBC_VERSION_0_0_1, &driver, &initialized, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcLoadDriver("libadbc_driver_sqlite.so", "AdbcSqliteDriverInit",
+ ADBC_VERSION_0_0_1, &driver, &initialized, &error));
ASSERT_EQ(initialized, ADBC_VERSION_0_0_1);
- AdbcDatabaseOptions db_options;
- std::memset(&db_options, 0, sizeof(db_options));
- db_options.driver = &driver;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&db_options, &database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ASSERT_NE(database.private_data, nullptr);
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error,
+ AdbcDatabaseSetOption(&database, "driver", "libadbc_driver_sqlite.so", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error,
+ AdbcDatabaseSetOption(&database, "entrypoint", "AdbcSqliteDriverInit", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
ASSERT_NE(database.private_data, nullptr);
- AdbcConnectionOptions conn_options;
- std::memset(&conn_options, 0, sizeof(conn_options));
- conn_options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error,
- AdbcConnectionInit(&conn_options, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
ASSERT_NE(connection.private_data, nullptr);
}
@@ -74,9 +76,10 @@ TEST_F(DriverManager, SqlExecute) {
std::string query = "SELECT 1";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
@@ -93,9 +96,8 @@ TEST_F(DriverManager, SqlExecuteInvalid) {
std::string query = "INVALID";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ASSERT_NE(AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error),
- ADBC_STATUS_OK);
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ASSERT_NE(AdbcStatementSetSqlQuery(&statement, query.c_str(), &error), ADBC_STATUS_OK);
ADBC_ASSERT_ERROR_THAT(
error, ::testing::AllOf(::testing::HasSubstr("[SQLite3] sqlite3_prepare_v2:"),
::testing::HasSubstr("syntax error")));
@@ -106,9 +108,10 @@ TEST_F(DriverManager, SqlPrepare) {
std::string query = "SELECT 1";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlPrepare(&connection, query.c_str(), &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementPrepare(&statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
@@ -137,9 +140,10 @@ TEST_F(DriverManager, SqlPrepareMultipleParams) {
&export_params));
ASSERT_OK(ExportSchema(*param_schema, &export_schema));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlPrepare(&connection, query.c_str(), &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementPrepare(&statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcStatementBind(&statement, &export_params, &export_schema, &error));
@@ -150,12 +154,12 @@ TEST_F(DriverManager, SqlPrepareMultipleParams) {
ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
ASSERT_SCHEMA_EQ(*schema, *arrow::schema({arrow::field("?", arrow::int64()),
arrow::field("?", arrow::utf8())}));
- EXPECT_THAT(batches, ::testing::UnorderedPointwise(
- PointeesEqual(),
- {
- adbc::RecordBatchFromJSON(schema, R"([[1, "foo"], [2,
- "bar"]])"),
- }));
+ EXPECT_THAT(batches,
+ ::testing::UnorderedPointwise(
+ PointeesEqual(),
+ {
+ adbc::RecordBatchFromJSON(schema, R"([[1, "foo"], [2, "bar"]])"),
+ }));
}
} // namespace adbc
diff --git a/drivers/flight_sql/flight_sql.cc b/drivers/flight_sql/flight_sql.cc
index 584abb1..66226e6 100644
--- a/drivers/flight_sql/flight_sql.cc
+++ b/drivers/flight_sql/flight_sql.cc
@@ -61,15 +61,53 @@ void SetError(struct AdbcError* error, Args&&... args) {
class FlightSqlDatabaseImpl {
public:
- explicit FlightSqlDatabaseImpl(std::unique_ptr<flightsql::FlightSqlClient> client)
- : client_(std::move(client)), connection_count_(0) {}
+ explicit FlightSqlDatabaseImpl() : client_(nullptr), connection_count_(0) {}
flightsql::FlightSqlClient* Connect() {
std::lock_guard<std::mutex> guard(mutex_);
- ++connection_count_;
+ if (client_) ++connection_count_;
return client_.get();
}
+ AdbcStatusCode Init(struct AdbcError* error) {
+ if (client_) {
+ SetError(error, "Database already initialized");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ auto it = options_.find("location");
+ if (it == options_.end()) {
+ SetError(error, "Must provide 'location' option");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+
+ flight::Location location;
+ arrow::Status status = flight::Location::Parse(it->second).Value(&location);
+ if (!status.ok()) {
+ SetError(error, status);
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+
+ std::unique_ptr<flight::FlightClient> flight_client;
+ status = flight::FlightClient::Connect(location).Value(&flight_client);
+ if (!status.ok()) {
+ SetError(error, status);
+ return ADBC_STATUS_IO;
+ }
+
+ client_.reset(new flightsql::FlightSqlClient(std::move(flight_client)));
+ options_.clear();
+ return ADBC_STATUS_OK;
+ }
+
+ AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) {
+ if (client_) {
+ SetError(error, "Database already initialized");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ options_[key] = value;
+ return ADBC_STATUS_OK;
+ }
+
AdbcStatusCode Disconnect(struct AdbcError* error) {
std::lock_guard<std::mutex> guard(mutex_);
if (--connection_count_ < 0) {
@@ -99,19 +137,44 @@ class FlightSqlDatabaseImpl {
private:
std::unique_ptr<flightsql::FlightSqlClient> client_;
int connection_count_;
+ std::unordered_map<std::string, std::string> options_;
std::mutex mutex_;
};
-class FlightSqlStatementImpl : public arrow::RecordBatchReader {
+class FlightSqlConnectionImpl {
public:
- FlightSqlStatementImpl() : client_(nullptr), info_() {}
+ explicit FlightSqlConnectionImpl(std::shared_ptr<FlightSqlDatabaseImpl> database)
+ : database_(std::move(database)), client_(nullptr) {}
+
+ //----------------------------------------------------------
+ // Common Functions
+ //----------------------------------------------------------
+
+ flightsql::FlightSqlClient* client() const { return client_; }
- void Init(flightsql::FlightSqlClient* client,
- std::unique_ptr<flight::FlightInfo> info) {
- client_ = client;
- info_ = std::move(info);
+  /// \brief Obtain the shared Flight SQL client from the database.
+  ///
+  /// Rejects a second Init(): each successful FlightSqlDatabaseImpl::Connect()
+  /// increments the connection count, while Close() disconnects once per call.
+  AdbcStatusCode Init(struct AdbcError* error) {
+    if (client_) {
+      SetError(error, "Connection already initialized");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    client_ = database_->Connect();
+    if (!client_) {
+      SetError(error, "Database not yet initialized!");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    return ADBC_STATUS_OK;
 }
+  /// \brief Release this connection's reference on the database.
+  ///
+  /// FlightSqlDatabaseImpl::Connect() only counts a connection when it returns
+  /// a non-null client, so a connection whose Init() never succeeded must not
+  /// call Disconnect(); doing so would decrement a count held by some other
+  /// live connection.
+  AdbcStatusCode Close(struct AdbcError* error) {
+    if (!client_) return ADBC_STATUS_OK;
+    return database_->Disconnect(error);
+  }
+
+ private:
+ std::shared_ptr<FlightSqlDatabaseImpl> database_;
+ flightsql::FlightSqlClient* client_;
+};
+
+class FlightSqlStatementImpl : public arrow::RecordBatchReader {
+ public:
+ FlightSqlStatementImpl(std::shared_ptr<FlightSqlConnectionImpl> connection)
+ : connection_(std::move(connection)), info_() {}
+
+ void Init(std::unique_ptr<flight::FlightInfo> info) { info_ = std::move(info); }
+
//----------------------------------------------------------
// arrow::RecordBatchReader Methods
//----------------------------------------------------------
@@ -144,9 +207,23 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
// Statement Functions
//----------------------------------------------------------
+  /// \brief Send the query set by SetSqlQuery() to the server and keep the
+  /// resulting FlightInfo for GetStream()/GetPartitionDesc().
+  AdbcStatusCode Execute(const std::shared_ptr<FlightSqlStatementImpl>& self,
+                         struct AdbcError* error) {
+    // client() is null until AdbcConnectionInit succeeds.
+    flightsql::FlightSqlClient* client = connection_->client();
+    if (!client) {
+      SetError(error, "Connection is not initialized, use AdbcConnectionInit");
+      return ADBC_STATUS_UNINITIALIZED;
+    }
+    if (query_.empty()) {
+      SetError(error, "Must set a query with AdbcStatementSetSqlQuery before executing");
+      return ADBC_STATUS_UNINITIALIZED;
+    }
+    flight::FlightCallOptions call_options;
+    std::unique_ptr<flight::FlightInfo> flight_info;
+    auto status = client->Execute(call_options, query_).Value(&flight_info);
+    if (!status.ok()) {
+      SetError(error, status);
+      return ADBC_STATUS_IO;
+    }
+    Init(std::move(flight_info));
+    return ADBC_STATUS_OK;
+  }
+
AdbcStatusCode GetStream(const std::shared_ptr<FlightSqlStatementImpl>& self,
struct ArrowArrayStream* out, struct AdbcError* error) {
- if (!client_) {
+ if (!info_) {
SetError(error, "Statement has not yet been executed");
return ADBC_STATUS_UNINITIALIZED;
}
@@ -168,8 +245,53 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
return ADBC_STATUS_OK;
}
+  /// \brief Remember the SQL text that the next Execute() will send.
+  AdbcStatusCode SetSqlQuery(const std::shared_ptr<FlightSqlStatementImpl>&,
+                             const char* query, struct AdbcError* error) {
+    // Assigning a null const char* to std::string is undefined behavior.
+    if (!query) {
+      SetError(error, "Query must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    query_ = query;
+    return ADBC_STATUS_OK;
+  }
+
+ //----------------------------------------------------------
+ // Metadata
+ //----------------------------------------------------------
+
+  /// \brief Ask the server for its table types and keep the resulting
+  /// FlightInfo so the statement can stream it.
+  AdbcStatusCode GetTableTypes(struct AdbcError* error) {
+    // client() is null until AdbcConnectionInit succeeds.
+    flightsql::FlightSqlClient* client = connection_->client();
+    if (!client) {
+      SetError(error, "Connection is not initialized, use AdbcConnectionInit");
+      return ADBC_STATUS_UNINITIALIZED;
+    }
+    flight::FlightCallOptions call_options;
+    std::unique_ptr<flight::FlightInfo> flight_info;
+    auto status = client->GetTableTypes(call_options).Value(&flight_info);
+    if (!status.ok()) {
+      SetError(error, status);
+      return ADBC_STATUS_IO;
+    }
+    Init(std::move(flight_info));
+    return ADBC_STATUS_OK;
+  }
+
+ //----------------------------------------------------------
+ // Partitioned Results
+ //----------------------------------------------------------
+
+ AdbcStatusCode DeserializePartitionDesc(const uint8_t* serialized_partition,
+ size_t serialized_length,
+ struct AdbcError* error) {
+ std::vector<flight::FlightEndpoint> endpoints(1);
+ endpoints[0].ticket.ticket = std::string(
+ reinterpret_cast<const char*>(serialized_partition), serialized_length);
+ auto maybe_info = flight::FlightInfo::Make(
+ *arrow::schema({}), flight::FlightDescriptor::Command(""), endpoints,
+ /*total_records=*/-1, /*total_bytes=*/-1);
+ if (!maybe_info.ok()) {
+ SetError(error, maybe_info.status());
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ std::unique_ptr<flight::FlightInfo> flight_info(
+ new flight::FlightInfo(maybe_info.MoveValueUnsafe()));
+ Init(std::move(flight_info));
+ return ADBC_STATUS_OK;
+ }
+
AdbcStatusCode GetPartitionDescSize(size_t* length, struct AdbcError* error) const {
- if (!client_) {
+ if (!info_) {
SetError(error, "Statement has not yet been executed");
return ADBC_STATUS_UNINITIALIZED;
}
@@ -184,7 +306,7 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
}
AdbcStatusCode GetPartitionDesc(uint8_t* partition_desc, struct AdbcError* error) {
- if (!client_) {
+ if (!info_) {
SetError(error, "Statement has not yet been executed");
return ADBC_STATUS_UNINITIALIZED;
}
@@ -206,9 +328,9 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
}
// TODO: this needs to account for location
flight::FlightCallOptions call_options;
- ARROW_ASSIGN_OR_RAISE(
- current_stream_,
- client_->DoGet(call_options, info_->endpoints()[next_endpoint_].ticket));
+ ARROW_ASSIGN_OR_RAISE(current_stream_,
+ connection_->client()->DoGet(
+ call_options, info_->endpoints()[next_endpoint_].ticket));
next_endpoint_++;
if (!schema_) {
ARROW_ASSIGN_OR_RAISE(schema_, current_stream_->GetSchema());
@@ -216,148 +338,38 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
return Status::OK();
}
- flightsql::FlightSqlClient* client_;
+ std::shared_ptr<FlightSqlConnectionImpl> connection_;
std::unique_ptr<flight::FlightInfo> info_;
+ std::string query_;
std::shared_ptr<arrow::Schema> schema_;
size_t next_endpoint_ = 0;
std::unique_ptr<flight::FlightStreamReader> current_stream_;
};
-class AdbcFlightSqlImpl {
- public:
- explicit AdbcFlightSqlImpl(std::shared_ptr<FlightSqlDatabaseImpl> database)
- : database_(std::move(database)), client_(database_->Connect()) {}
-
- //----------------------------------------------------------
- // Common Functions
- //----------------------------------------------------------
-
- AdbcStatusCode Close(struct AdbcError* error) { return database_->Disconnect(error); }
-
- //----------------------------------------------------------
- // Metadata
- //----------------------------------------------------------
-
- AdbcStatusCode GetTableTypes(struct AdbcStatement* out, struct AdbcError* error) {
- if (!out->private_data) {
- SetError(error, "Statement is uninitialized, use AdbcStatementInit");
- return ADBC_STATUS_UNINITIALIZED;
- }
- auto* ptr =
- reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(out->private_data);
- auto* impl = ptr->get();
-
- flight::FlightCallOptions call_options;
- std::unique_ptr<flight::FlightInfo> flight_info;
- auto status = client_->GetTableTypes(call_options).Value(&flight_info);
- if (!status.ok()) {
- SetError(error, status);
- return ADBC_STATUS_IO;
- }
- impl->Init(client_, std::move(flight_info));
- return ADBC_STATUS_OK;
- }
-
- //----------------------------------------------------------
- // SQL Semantics
- //----------------------------------------------------------
-
- AdbcStatusCode SqlExecute(const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!out->private_data) {
- SetError(error, "Statement is uninitialized, use AdbcStatementInit");
- return ADBC_STATUS_UNINITIALIZED;
- }
- auto* ptr =
- reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(out->private_data);
- auto* impl = ptr->get();
-
- flight::FlightCallOptions call_options;
- std::unique_ptr<flight::FlightInfo> flight_info;
- auto status = client_->Execute(call_options, std::string(query)).Value(&flight_info);
- if (!status.ok()) {
- SetError(error, status);
- return ADBC_STATUS_IO;
- }
- impl->Init(client_, std::move(flight_info));
- return ADBC_STATUS_OK;
- }
-
- //----------------------------------------------------------
- // Partitioned Results
- //----------------------------------------------------------
-
- AdbcStatusCode DeserializePartitionDesc(const uint8_t* serialized_partition,
- size_t serialized_length,
- struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!out->private_data) {
- SetError(error, "Statement is uninitialized, use AdbcStatementInit");
- return ADBC_STATUS_UNINITIALIZED;
- }
- auto* ptr =
- reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(out->private_data);
- auto* impl = ptr->get();
-
- std::vector<flight::FlightEndpoint> endpoints(1);
- endpoints[0].ticket.ticket = std::string(
- reinterpret_cast<const char*>(serialized_partition), serialized_length);
- auto maybe_info = flight::FlightInfo::Make(
- *arrow::schema({}), flight::FlightDescriptor::Command(""), endpoints,
- /*total_records=*/-1, /*total_bytes=*/-1);
- if (!maybe_info.ok()) {
- SetError(error, maybe_info.status());
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
- std::unique_ptr<flight::FlightInfo> flight_info(
- new flight::FlightInfo(maybe_info.MoveValueUnsafe()));
-
- impl->Init(client_, std::move(flight_info));
- return ADBC_STATUS_OK;
- }
-
- private:
- std::shared_ptr<FlightSqlDatabaseImpl> database_;
- flightsql::FlightSqlClient* client_;
-};
-
} // namespace
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
- struct AdbcDatabase* out, struct AdbcError* error) {
- std::unordered_map<std::string, std::string> option_pairs;
- auto status = adbc::ParseConnectionString(arrow::util::string_view(options->target))
- .Value(&option_pairs);
- if (!status.ok()) {
- SetError(error, status);
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
- auto location_it = option_pairs.find("Location");
- if (location_it == option_pairs.end()) {
- SetError(error, Status::Invalid("Must provide Location option"));
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
-
- flight::Location location;
- status = flight::Location::Parse(location_it->second).Value(&location);
- if (!status.ok()) {
- SetError(error, status);
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
+ auto impl = std::make_shared<FlightSqlDatabaseImpl>();
+ database->private_data = new std::shared_ptr<FlightSqlDatabaseImpl>(impl);
+ return ADBC_STATUS_OK;
+}
- std::unique_ptr<flight::FlightClient> flight_client;
- status = flight::FlightClient::Connect(location).Value(&flight_client);
- if (!status.ok()) {
- SetError(error, status);
- return ADBC_STATUS_IO;
- }
- std::unique_ptr<flightsql::FlightSqlClient> client(
- new flightsql::FlightSqlClient(std::move(flight_client)));
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ if (!database || !database->private_data) return ADBC_STATUS_UNINITIALIZED;
+ auto ptr =
+ reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(database->private_data);
+ return (*ptr)->SetOption(key, value, error);
+}
- auto impl = std::make_shared<FlightSqlDatabaseImpl>(std::move(client));
- out->private_data = new std::shared_ptr<FlightSqlDatabaseImpl>(impl);
- return ADBC_STATUS_OK;
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
+  // Same guard as AdbcDatabaseSetOption: reject a null handle as well as one
+  // that AdbcDatabaseNew has not populated.
+  if (!database || !database->private_data) return ADBC_STATUS_UNINITIALIZED;
+  auto ptr =
+      reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(database->private_data);
+  return (*ptr)->Init(error);
}
ADBC_DRIVER_EXPORT
@@ -378,43 +390,54 @@ AdbcStatusCode AdbcConnectionDeserializePartitionDesc(struct AdbcConnection* con
size_t serialized_length,
struct AdbcStatement* statement,
struct AdbcError* error) {
- if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
auto* ptr =
- reinterpret_cast<std::shared_ptr<AdbcFlightSqlImpl>*>(connection->private_data);
- return (*ptr)->DeserializePartitionDesc(serialized_partition, serialized_length,
- statement, error);
+ reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(statement->private_data);
+ return (*ptr)->DeserializePartitionDesc(serialized_partition, serialized_length, error);
}
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
struct AdbcStatement* statement,
struct AdbcError* error) {
- if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
auto* ptr =
- reinterpret_cast<std::shared_ptr<AdbcFlightSqlImpl>*>(connection->private_data);
- return (*ptr)->GetTableTypes(statement, error);
+ reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(statement->private_data);
+ return (*ptr)->GetTableTypes(error);
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcConnectionInit(const struct AdbcConnectionOptions* options,
- struct AdbcConnection* out, struct AdbcError* error) {
- if (!options->database || !options->database->private_data) {
- SetError(error, "Must provide database");
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
- auto ptr = reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(
- options->database->private_data);
- auto impl = std::make_shared<AdbcFlightSqlImpl>(*ptr);
- out->private_data = new std::shared_ptr<AdbcFlightSqlImpl>(impl);
+AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
+                                 struct AdbcConnection* connection,
+                                 struct AdbcError* error) {
+  // Without these checks a database never created by AdbcDatabaseNew (or a
+  // null connection) would be dereferenced below.
+  if (!database || !database->private_data) {
+    SetError(error, "Must provide database");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  if (!connection) {
+    SetError(error, "Must provide connection");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  auto ptr =
+      reinterpret_cast<std::shared_ptr<FlightSqlDatabaseImpl>*>(database->private_data);
+  auto impl = std::make_shared<FlightSqlConnectionImpl>(*ptr);
+  connection->private_data = new std::shared_ptr<FlightSqlConnectionImpl>(impl);
 return ADBC_STATUS_OK;
}
+ADBC_DRIVER_EXPORT
+// Accepts and ignores every connection option: this driver does not read any
+// connection-level key/value yet, and the arguments are not inspected.
+// NOTE(review): callers cannot tell that an option had no effect; consider
+// returning ADBC_STATUS_NOT_IMPLEMENTED for unrecognized keys.
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error) {
+ return ADBC_STATUS_OK;
+}
+
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+ auto ptr = reinterpret_cast<std::shared_ptr<FlightSqlConnectionImpl>*>(
+ connection->private_data);
+ return (*ptr)->Init(error);
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error) {
if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
- auto* ptr =
- reinterpret_cast<std::shared_ptr<AdbcFlightSqlImpl>*>(connection->private_data);
+ auto* ptr = reinterpret_cast<std::shared_ptr<FlightSqlConnectionImpl>*>(
+ connection->private_data);
auto status = (*ptr)->Close(error);
delete ptr;
connection->private_data = nullptr;
@@ -422,13 +445,12 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcConnectionSqlExecute(struct AdbcConnection* connection,
- const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
auto* ptr =
- reinterpret_cast<std::shared_ptr<AdbcFlightSqlImpl>*>(connection->private_data);
- return (*ptr)->SqlExecute(query, out, error);
+ reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(statement->private_data);
+ return (*ptr)->Execute(*ptr, error);
}
ADBC_DRIVER_EXPORT
@@ -462,10 +484,12 @@ AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcStatementInit(struct AdbcConnection* connection,
- struct AdbcStatement* statement,
- struct AdbcError* error) {
- auto impl = std::make_shared<FlightSqlStatementImpl>();
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+                                struct AdbcStatement* statement,
+                                struct AdbcError* error) {
+  // Guard against a connection that AdbcConnectionNew never populated; the
+  // cast below would otherwise dereference a null pointer.
+  if (!connection || !connection->private_data) {
+    SetError(error, "Must provide connection");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  auto* ptr = reinterpret_cast<std::shared_ptr<FlightSqlConnectionImpl>*>(
+      connection->private_data);
+  auto impl = std::make_shared<FlightSqlStatementImpl>(*ptr);
 statement->private_data = new std::shared_ptr<FlightSqlStatementImpl>(impl);
 return ADBC_STATUS_OK;
}
@@ -482,6 +506,15 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
return status;
}
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
+ auto* ptr =
+ reinterpret_cast<std::shared_ptr<FlightSqlStatementImpl>*>(statement->private_data);
+ return (*ptr)->SetSqlQuery(*ptr, query, error);
+}
+
extern "C" {
ARROW_EXPORT
AdbcStatusCode AdbcFlightSqlDriverInit(size_t count, struct AdbcDriver* driver,
@@ -489,18 +522,25 @@ AdbcStatusCode AdbcFlightSqlDriverInit(size_t count, struct AdbcDriver* driver,
if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
std::memset(driver, 0, sizeof(*driver));
+ driver->DatabaseNew = AdbcDatabaseNew;
+ driver->DatabaseSetOption = AdbcDatabaseSetOption;
driver->DatabaseInit = AdbcDatabaseInit;
driver->DatabaseRelease = AdbcDatabaseRelease;
+
+ driver->ConnectionNew = AdbcConnectionNew;
+ driver->ConnectionSetOption = AdbcConnectionSetOption;
driver->ConnectionInit = AdbcConnectionInit;
driver->ConnectionRelease = AdbcConnectionRelease;
- driver->ConnectionSqlExecute = AdbcConnectionSqlExecute;
driver->ConnectionDeserializePartitionDesc = AdbcConnectionDeserializePartitionDesc;
driver->ConnectionGetTableTypes = AdbcConnectionGetTableTypes;
+
+ driver->StatementExecute = AdbcStatementExecute;
driver->StatementGetPartitionDesc = AdbcStatementGetPartitionDesc;
driver->StatementGetPartitionDescSize = AdbcStatementGetPartitionDescSize;
driver->StatementGetStream = AdbcStatementGetStream;
- driver->StatementInit = AdbcStatementInit;
+ driver->StatementNew = AdbcStatementNew;
driver->StatementRelease = AdbcStatementRelease;
+ driver->StatementSetSqlQuery = AdbcStatementSetSqlQuery;
*initialized = ADBC_VERSION_0_0_1;
return ADBC_STATUS_OK;
}
diff --git a/drivers/flight_sql/flight_sql_test.cc b/drivers/flight_sql/flight_sql_test.cc
index 04c8ce9..499477b 100644
--- a/drivers/flight_sql/flight_sql_test.cc
+++ b/drivers/flight_sql/flight_sql_test.cc
@@ -37,15 +37,12 @@ class AdbcFlightSqlTest : public ::testing::Test {
public:
void SetUp() override {
if (const char* location = std::getenv(kServerEnvVar.c_str())) {
- AdbcDatabaseOptions db_options;
- std::string target = "Location=";
- target += location;
- db_options.target = target.c_str();
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&db_options, &database, &error));
- AdbcConnectionOptions conn_options;
- conn_options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error,
- AdbcConnectionInit(&conn_options, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcDatabaseSetOption(&database, "location", location, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
} else {
FAIL() << "Must provide location of Flight SQL server at " << kServerEnvVar;
}
@@ -64,34 +61,33 @@ class AdbcFlightSqlTest : public ::testing::Test {
};
TEST_F(AdbcFlightSqlTest, Metadata) {
- {
- AdbcStatement statement;
- std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionGetTableTypes(&connection, &statement, &error));
-
- std::shared_ptr<arrow::Schema> schema;
- arrow::RecordBatchVector batches;
- ReadStatement(&statement, &schema, &batches);
- ASSERT_SCHEMA_EQ(
- *schema,
- *arrow::schema({arrow::field("table_type", arrow::utf8(), /*nullable=*/false)}));
- EXPECT_THAT(batches, ::testing::UnorderedPointwise(
- PointeesEqual(),
- {
- adbc::RecordBatchFromJSON(schema, R"([["table"]])"),
- }));
- }
+ AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcConnectionGetTableTypes(&connection, &statement, &error));
+
+ std::shared_ptr<arrow::Schema> schema;
+ arrow::RecordBatchVector batches;
+ ReadStatement(&statement, &schema, &batches);
+ ASSERT_SCHEMA_EQ(
+ *schema,
+ *arrow::schema({arrow::field("table_type", arrow::utf8(), /*nullable=*/false)}));
+ EXPECT_THAT(batches, ::testing::UnorderedPointwise(
+ PointeesEqual(),
+ {
+ adbc::RecordBatchFromJSON(schema, R"([["table"]])"),
+ }));
}
TEST_F(AdbcFlightSqlTest, SqlExecute) {
std::string query = "SELECT 1";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
@@ -104,6 +100,18 @@ TEST_F(AdbcFlightSqlTest, SqlExecute) {
}));
}
+TEST_F(AdbcFlightSqlTest, SqlExecuteInvalid) {
+ std::string query = "INVALID";
+ AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ASSERT_NE(AdbcStatementExecute(&statement, &error), ADBC_STATUS_OK);
+ ADBC_ASSERT_ERROR_THAT(error, ::testing::HasSubstr("syntax error"));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+}
+
TEST_F(AdbcFlightSqlTest, Partitions) {
// Serialize the query result handle into a partition so it can be
// retrieved separately. (With multiple partitions we could
@@ -112,9 +120,10 @@ TEST_F(AdbcFlightSqlTest, Partitions) {
std::string query = "SELECT 42";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
std::vector<std::vector<uint8_t>> descs;
@@ -132,7 +141,7 @@ TEST_F(AdbcFlightSqlTest, Partitions) {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
// Reconstruct the partition
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionDeserializePartitionDesc(
&connection, descs.back().data(),
descs.back().size(), &statement, &error));
@@ -148,15 +157,4 @@ TEST_F(AdbcFlightSqlTest, Partitions) {
}));
}
-TEST_F(AdbcFlightSqlTest, InvalidSql) {
- std::string query = "INVALID";
- AdbcStatement statement;
- std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ASSERT_NE(AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error),
- ADBC_STATUS_OK);
- ADBC_ASSERT_ERROR_THAT(error, ::testing::HasSubstr("syntax error"));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
-}
-
} // namespace adbc
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index e55e0b5..8d8956a 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -21,6 +21,7 @@
#include <memory>
#include <mutex>
#include <string>
+#include <unordered_map>
#include <arrow/builder.h>
#include <arrow/c/bridge.h>
@@ -106,14 +107,44 @@ std::shared_ptr<arrow::Schema> StatementToSchema(sqlite3_stmt* stmt) {
class SqliteDatabaseImpl {
public:
- explicit SqliteDatabaseImpl(sqlite3* db) : db_(db), connection_count_(0) {}
+ explicit SqliteDatabaseImpl() : db_(nullptr), connection_count_(0) {}
sqlite3* Connect() {
std::lock_guard<std::mutex> guard(mutex_);
- ++connection_count_;
+ if (db_) ++connection_count_;
return db_;
}
+  /// \brief Open the SQLite database named by the "filename" option
+  /// (":memory:" when unset).
+  ///
+  /// On failure, any handle sqlite3_open_v2 produced is closed and db_ stays
+  /// null, so the database remains uninitialized and Init() may be retried.
+  /// Options are cleared only on success.
+  AdbcStatusCode Init(struct AdbcError* error) {
+    // Hold the same mutex as Connect()/Disconnect() while publishing db_.
+    std::lock_guard<std::mutex> guard(mutex_);
+    if (db_) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    const char* filename = ":memory:";
+    auto it = options_.find("filename");
+    if (it != options_.end()) filename = it->second.c_str();
+
+    // Open into a local so a failed handle is never stored in db_.
+    sqlite3* db = nullptr;
+    auto status = sqlite3_open_v2(
+        filename, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, /*zVfs=*/nullptr);
+    if (status != SQLITE_OK) {
+      if (db) {
+        // SQLite normally returns a handle even on failure; it carries the
+        // error message and must still be released with sqlite3_close.
+        SetError(db, "sqlite3_open_v2", error);
+        sqlite3_close(db);
+      } else {
+        // A null handle means SQLite could not allocate the connection object.
+        SetError(error, "sqlite3_open_v2: could not allocate database handle");
+      }
+      return ADBC_STATUS_IO;
+    }
+    db_ = db;
+    options_.clear();
+    return ADBC_STATUS_OK;
+  }
+
+  /// \brief Record a database option for Init() (Init() currently reads only
+  /// "filename"). Options cannot be changed after Init().
+  AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error) {
+    // std::string cannot be built from a null pointer.
+    if (!key || !value) {
+      SetError(error, "Option key and value must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    std::lock_guard<std::mutex> guard(mutex_);
+    if (db_) {
+      SetError(error, "Database already initialized");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    options_[key] = value;
+    return ADBC_STATUS_OK;
+  }
+
AdbcStatusCode Disconnect(struct AdbcError* error) {
std::lock_guard<std::mutex> guard(mutex_);
if (--connection_count_ < 0) {
@@ -143,19 +174,46 @@ class SqliteDatabaseImpl {
private:
sqlite3* db_;
int connection_count_;
+ std::unordered_map<std::string, std::string> options_;
std::mutex mutex_;
};
-class SqliteStatementImpl : public arrow::RecordBatchReader {
+class SqliteConnectionImpl {
public:
- SqliteStatementImpl()
- : db_(nullptr), stmt_(nullptr), schema_(nullptr), bind_index_(0), done_(false) {}
+ explicit SqliteConnectionImpl(std::shared_ptr<SqliteDatabaseImpl> database)
+ : database_(std::move(database)), db_(nullptr) {}
- void Init(sqlite3* db, sqlite3_stmt* stmt) {
- db_ = db;
- stmt_ = stmt;
+ //----------------------------------------------------------
+ // Common Functions
+ //----------------------------------------------------------
+
+ sqlite3* db() const { return db_; }
+
+  /// \brief Borrow the database's sqlite3 handle for this connection.
+  ///
+  /// Rejects a second Init(): each successful SqliteDatabaseImpl::Connect()
+  /// increments the connection count, while Release() disconnects once per call.
+  AdbcStatusCode Init(struct AdbcError* error) {
+    if (db_) {
+      SetError(error, "Connection already initialized");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    db_ = database_->Connect();
+    if (!db_) {
+      SetError(error, "Database not yet initialized!");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    return ADBC_STATUS_OK;
 }
+  /// \brief Release this connection's reference on the database.
+  ///
+  /// SqliteDatabaseImpl::Connect() only counts a connection when it returns a
+  /// non-null handle, so a connection whose Init() never succeeded must not
+  /// call Disconnect(); doing so would decrement a count held by some other
+  /// live connection.
+  AdbcStatusCode Release(struct AdbcError* error) {
+    if (!db_) return ADBC_STATUS_OK;
+    return database_->Disconnect(error);
+  }
+
+ private:
+ std::shared_ptr<SqliteDatabaseImpl> database_;
+ sqlite3* db_;
+};
+
+class SqliteStatementImpl : public arrow::RecordBatchReader {
+ public:
+ SqliteStatementImpl(std::shared_ptr<SqliteConnectionImpl> connection)
+ : connection_(std::move(connection)),
+ stmt_(nullptr),
+ schema_(nullptr),
+ bind_index_(0),
+ done_(false) {}
+
//----------------------------------------------------------
// arrow::RecordBatchReader
//----------------------------------------------------------
@@ -167,7 +225,6 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch) override {
constexpr int64_t kBatchSize = 1024;
-
if (done_) {
*batch = nullptr;
return Status::OK();
@@ -180,6 +237,8 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
schema_->field(i)->type(), &builders[i]));
}
+ sqlite3* db = connection_->db();
+
// The statement was stepped once at the start, so step at the end of the loop
int64_t num_rows = 0;
for (int64_t row = 0; row < kBatchSize; row++) {
@@ -219,7 +278,7 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
if (bind_parameters_ && bind_index_ < bind_parameters_->num_rows()) {
status = sqlite3_reset(stmt_);
if (status != SQLITE_OK) {
- return Status::IOError("[SQLite3] sqlite3_reset: ", sqlite3_errmsg(db_));
+ return Status::IOError("[SQLite3] sqlite3_reset: ", sqlite3_errmsg(db));
}
RETURN_NOT_OK(BindNext());
status = sqlite3_step(stmt_);
@@ -230,7 +289,7 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
}
break;
}
- return Status::IOError("[SQLite3] sqlite3_step: ", sqlite3_errmsg(db_));
+ return Status::IOError("[SQLite3] sqlite3_step: ", sqlite3_errmsg(db));
}
arrow::ArrayVector arrays(builders.size());
@@ -249,11 +308,12 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
if (stmt_) {
auto status = sqlite3_finalize(stmt_);
if (status != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
+ SetError(connection_->db(), "sqlite3_finalize", error);
return ADBC_STATUS_UNKNOWN;
}
stmt_ = nullptr;
bind_parameters_.reset();
+ connection_.reset();
}
return ADBC_STATUS_OK;
}
@@ -275,24 +335,25 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
AdbcStatusCode Execute(const std::shared_ptr<SqliteStatementImpl>& self,
struct AdbcError* error) {
+ sqlite3* db = connection_->db();
int rc = 0;
if (schema_) {
rc = sqlite3_clear_bindings(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_reset_bindings", error);
+ SetError(db, "sqlite3_reset_bindings", error);
rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
+ SetError(db, "sqlite3_finalize", error);
}
return ADBC_STATUS_IO;
}
rc = sqlite3_reset(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_reset", error);
+ SetError(db, "sqlite3_reset", error);
rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
+ SetError(db, "sqlite3_finalize", error);
}
return ADBC_STATUS_IO;
}
@@ -308,19 +369,19 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
SetError(error, status);
return ADBC_STATUS_IO;
} else if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_bind", error);
+ SetError(db, "sqlite3_bind", error);
rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
+ SetError(db, "sqlite3_finalize", error);
}
return ADBC_STATUS_IO;
}
rc = sqlite3_step(stmt_);
if (rc == SQLITE_ERROR) {
- SetError(db_, "sqlite3_step", error);
+ SetError(db, "sqlite3_step", error);
rc = sqlite3_finalize(stmt_);
if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
+ SetError(db, "sqlite3_finalize", error);
}
return ADBC_STATUS_IO;
}
@@ -343,6 +404,24 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
return ADBC_STATUS_OK;
}
+  /// \brief Prepare `query` as this statement's SQL, replacing any statement
+  /// prepared by an earlier call.
+  AdbcStatusCode SetSqlQuery(const std::shared_ptr<SqliteStatementImpl>& self,
+                             const char* query, struct AdbcError* error) {
+    // db() is null until AdbcConnectionInit succeeds.
+    sqlite3* db = connection_->db();
+    if (!db) {
+      SetError(error, "Connection is not initialized, use AdbcConnectionInit");
+      return ADBC_STATUS_UNINITIALIZED;
+    }
+    if (!query) {
+      SetError(error, "Query must not be null");
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+    if (stmt_) {
+      // Free the previously prepared statement instead of leaking it. The
+      // sqlite3_finalize return code only reports the old statement's last
+      // evaluation error, which does not concern the new query.
+      sqlite3_finalize(stmt_);
+      stmt_ = nullptr;
+      // Execute() uses a non-null schema_ to decide to reset and re-run the
+      // same statement; restore the per-statement state the constructor sets.
+      schema_.reset();
+      bind_index_ = 0;
+      done_ = false;
+    }
+    int rc = sqlite3_prepare_v2(db, query, static_cast<int>(std::strlen(query)), &stmt_,
+                                /*pzTail=*/nullptr);
+    if (rc != SQLITE_OK) {
+      if (stmt_) {
+        rc = sqlite3_finalize(stmt_);
+        stmt_ = nullptr;
+        if (rc != SQLITE_OK) {
+          SetError(db, "sqlite3_finalize", error);
+        }
+      }
+      SetError(db, "sqlite3_prepare_v2", error);
+      return ADBC_STATUS_UNKNOWN;
+    }
+    return ADBC_STATUS_OK;
+  }
+
private:
arrow::Result<int> BindNext() {
if (!bind_parameters_ || bind_index_ >= bind_parameters_->num_rows()) {
@@ -384,7 +463,7 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
return SQLITE_OK;
}
- sqlite3* db_;
+ std::shared_ptr<SqliteConnectionImpl> connection_;
sqlite3_stmt* stmt_;
std::shared_ptr<arrow::Schema> schema_;
std::shared_ptr<arrow::RecordBatch> bind_parameters_;
@@ -392,83 +471,32 @@ class SqliteStatementImpl : public arrow::RecordBatchReader {
bool done_;
};
-class SqliteConnectionImpl {
- public:
- explicit SqliteConnectionImpl(std::shared_ptr<SqliteDatabaseImpl> database)
- : database_(std::move(database)), db_(database_->Connect()) {}
-
- //----------------------------------------------------------
- // Common Functions
- //----------------------------------------------------------
-
- AdbcStatusCode Release(struct AdbcError* error) { return database_->Disconnect(error); }
-
- //----------------------------------------------------------
- // SQL Semantics
- //----------------------------------------------------------
-
- AdbcStatusCode SqlExecute(const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- auto status = SqlPrepare(query, out, error);
- if (status != ADBC_STATUS_OK) return status;
-
- return AdbcStatementExecute(out, error);
- }
-
- AdbcStatusCode SqlPrepare(const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!out->private_data) {
- SetError(error, "Statement is uninitialized, use AdbcStatementInit");
- return ADBC_STATUS_UNINITIALIZED;
- }
- auto* ptr =
- reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(out->private_data);
- auto* impl = ptr->get();
-
- sqlite3_stmt* stmt = nullptr;
- auto rc = sqlite3_prepare_v2(db_, query, static_cast<int>(std::strlen(query)), &stmt,
- /*pzTail=*/nullptr);
- if (rc != SQLITE_OK) {
- if (stmt) {
- rc = sqlite3_finalize(stmt);
- if (rc != SQLITE_OK) {
- SetError(db_, "sqlite3_finalize", error);
- return ADBC_STATUS_UNKNOWN;
- }
- }
- SetError(db_, "sqlite3_prepare_v2", error);
- return ADBC_STATUS_UNKNOWN;
- }
-
- impl->Init(db_, stmt);
- return ADBC_STATUS_OK;
- }
-
- private:
- std::shared_ptr<SqliteDatabaseImpl> database_;
- sqlite3* db_;
-};
-
} // namespace
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
- struct AdbcDatabase* out, struct AdbcError* error) {
- sqlite3* db = nullptr;
- auto status = sqlite3_open_v2(
- ":memory:", &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, /*zVfs=*/nullptr);
- if (status != SQLITE_OK) {
- if (db) {
- SetError(db, "sqlite3_open_v2", error);
- }
- return ADBC_STATUS_UNKNOWN;
- }
-
- auto impl = std::make_shared<SqliteDatabaseImpl>(db);
- out->private_data = new std::shared_ptr<SqliteDatabaseImpl>(impl);
+/// \brief Allocate the driver's database state and store it in
+/// `database->private_data`.
+///
+/// Nothing is opened here: options are applied afterwards with
+/// AdbcDatabaseSetOption and setup is finished by AdbcDatabaseInit
+/// (SqliteDatabaseImpl::Init).
+/// NOTE(review): any existing private_data is overwritten without being released.
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
+ auto impl = std::make_shared<SqliteDatabaseImpl>();
+ database->private_data = new std::shared_ptr<SqliteDatabaseImpl>(impl);
return ADBC_STATUS_OK;
}
+/// \brief Forward a database option (`key` = `value`) to SqliteDatabaseImpl::SetOption.
+/// Meant to be called between AdbcDatabaseNew and AdbcDatabaseInit.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ const bool has_state = database != nullptr && database->private_data != nullptr;
+ if (!has_state) {
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ auto& impl = *static_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
+ return impl->SetOption(key, value, error);
+}
+
+/// \brief Finish setting up a database created by AdbcDatabaseNew by delegating
+/// to SqliteDatabaseImpl::Init.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
+ void* state = database->private_data;
+ if (state == nullptr) {
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ auto& impl = *static_cast<std::shared_ptr<SqliteDatabaseImpl>*>(state);
+ return impl->Init(error);
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
struct AdbcError* error) {
@@ -482,19 +510,31 @@ AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcConnectionInit(const struct AdbcConnectionOptions* options,
- struct AdbcConnection* out, struct AdbcError* error) {
- if (!options->database || !options->database->private_data) {
- SetError(error, "Must provide database");
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
- auto ptr = reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(
- options->database->private_data);
+/// \brief Allocate connection state bound to `database` and store it in
+/// `connection->private_data`. The connection is finished by AdbcConnectionInit.
+///
+/// Returns ADBC_STATUS_INVALID_ARGUMENT if `database` is missing or was never
+/// created with AdbcDatabaseNew.
+AdbcStatusCode AdbcConnectionNew(struct AdbcDatabase* database,
+ struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ // Dereferencing private_data below would crash on a missing database; keep the
+ // validation the previous AdbcConnectionInit performed.
+ if (!database || !database->private_data) {
+ SetError(error, "Must provide database");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ auto ptr =
+ reinterpret_cast<std::shared_ptr<SqliteDatabaseImpl>*>(database->private_data);
auto impl = std::make_shared<SqliteConnectionImpl>(*ptr);
- out->private_data = new std::shared_ptr<SqliteConnectionImpl>(impl);
+ connection->private_data = new std::shared_ptr<SqliteConnectionImpl>(impl);
return ADBC_STATUS_OK;
}
+/// \brief Set a connection option.
+///
+/// This driver does not use any connection options, so every key/value pair is
+/// accepted and ignored. Returns ADBC_STATUS_UNINITIALIZED if `connection` was not
+/// created with AdbcConnectionNew.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error) {
+ // Validate the handle like the other entry points do, so calling this before
+ // AdbcConnectionNew is reported instead of silently succeeding.
+ if (!connection || !connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+ // NOTE(review): unknown options are ignored rather than rejected; consider
+ // returning ADBC_STATUS_NOT_IMPLEMENTED once callers are ready for it.
+ return ADBC_STATUS_OK;
+}
+
+/// \brief Finish setting up a connection created by AdbcConnectionNew by
+/// delegating to SqliteConnectionImpl::Init.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ void* state = connection->private_data;
+ if (state == nullptr) {
+ return ADBC_STATUS_UNINITIALIZED;
+ }
+ auto& impl = *static_cast<std::shared_ptr<SqliteConnectionImpl>*>(state);
+ return impl->Init(error);
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
struct AdbcError* error) {
@@ -508,23 +548,22 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcConnectionSqlExecute(struct AdbcConnection* connection,
- const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+/// \brief Bind Arrow data (`values`, described by `schema`) as the statement's
+/// parameters by forwarding to SqliteStatementImpl::Bind.
+///
+/// Returns ADBC_STATUS_UNINITIALIZED if the statement has no private data.
+/// NOTE(review): whether the driver takes ownership of `values`/`schema` is
+/// decided in SqliteStatementImpl::Bind (not visible here) - confirm.
+AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
auto* ptr =
- reinterpret_cast<std::shared_ptr<SqliteConnectionImpl>*>(connection->private_data);
- return (*ptr)->SqlExecute(query, out, error);
+ reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
+ return (*ptr)->Bind(*ptr, values, schema, error);
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcConnectionSqlPrepare(struct AdbcConnection* connection,
- const char* query, struct AdbcStatement* out,
- struct AdbcError* error) {
- if (!connection->private_data) return ADBC_STATUS_UNINITIALIZED;
+/// \brief Execute the statement's prepared query by forwarding to
+/// SqliteStatementImpl::Execute, passing the owning shared_ptr as `self`.
+///
+/// Returns ADBC_STATUS_UNINITIALIZED if the statement has no private data.
+AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
auto* ptr =
- reinterpret_cast<std::shared_ptr<SqliteConnectionImpl>*>(connection->private_data);
- return (*ptr)->SqlPrepare(query, out, error);
+ reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
+ return (*ptr)->Execute(*ptr, error);
}
ADBC_DRIVER_EXPORT
@@ -541,25 +580,6 @@ AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement
return ADBC_STATUS_NOT_IMPLEMENTED;
}
-ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
- struct ArrowArray* values, struct ArrowSchema* schema,
- struct AdbcError* error) {
- if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
- auto* ptr =
- reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
- return (*ptr)->Bind(*ptr, values, schema, error);
-}
-
-ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
- struct AdbcError* error) {
- if (!statement->private_data) return ADBC_STATUS_UNINITIALIZED;
- auto* ptr =
- reinterpret_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
- return (*ptr)->Execute(*ptr, error);
-}
-
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
@@ -571,14 +591,24 @@ AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
}
ADBC_DRIVER_EXPORT
-AdbcStatusCode AdbcStatementInit(struct AdbcConnection* connection,
- struct AdbcStatement* statement,
- struct AdbcError* error) {
- auto impl = std::make_shared<SqliteStatementImpl>();
+/// \brief Allocate statement state bound to `connection` and store it in
+/// `statement->private_data`. The query is supplied with AdbcStatementSetSqlQuery.
+///
+/// Returns ADBC_STATUS_INVALID_ARGUMENT if `connection` is missing or was never
+/// created with AdbcConnectionNew.
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ // The statement now keeps a reference to its connection, so the connection
+ // handle must be valid before it is dereferenced below.
+ if (!connection || !connection->private_data) {
+ SetError(error, "Must provide connection");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ auto conn_ptr =
+ reinterpret_cast<std::shared_ptr<SqliteConnectionImpl>*>(connection->private_data);
+ auto impl = std::make_shared<SqliteStatementImpl>(*conn_ptr);
statement->private_data = new std::shared_ptr<SqliteStatementImpl>(impl);
return ADBC_STATUS_OK;
}
+/// \brief Prepare the statement. The query is already compiled by
+/// sqlite3_prepare_v2 inside AdbcStatementSetSqlQuery, so this only validates
+/// the handle.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return statement->private_data ? ADBC_STATUS_OK : ADBC_STATUS_UNINITIALIZED;
+}
+
ADBC_DRIVER_EXPORT
AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
struct AdbcError* error) {
@@ -591,6 +621,15 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
return status;
}
+/// \brief Set the SQL text of `statement`; SqliteStatementImpl::SetSqlQuery
+/// prepares it with sqlite3_prepare_v2 right away.
+ADBC_DRIVER_EXPORT
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
+ auto* holder =
+ static_cast<std::shared_ptr<SqliteStatementImpl>*>(statement->private_data);
+ if (holder == nullptr) return ADBC_STATUS_UNINITIALIZED;
+ const auto& impl = *holder;
+ return impl->SetSqlQuery(impl, query, error);
+}
+
extern "C" {
ARROW_EXPORT
AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver,
@@ -599,18 +638,24 @@ AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver,
std::memset(driver, 0, sizeof(*driver));
driver->DatabaseInit = AdbcDatabaseInit;
+ driver->DatabaseNew = AdbcDatabaseNew;
driver->DatabaseRelease = AdbcDatabaseRelease;
+ driver->DatabaseSetOption = AdbcDatabaseSetOption;
+
driver->ConnectionInit = AdbcConnectionInit;
+ driver->ConnectionNew = AdbcConnectionNew;
driver->ConnectionRelease = AdbcConnectionRelease;
- driver->ConnectionSqlExecute = AdbcConnectionSqlExecute;
- driver->ConnectionSqlPrepare = AdbcConnectionSqlPrepare;
+ driver->ConnectionSetOption = AdbcConnectionSetOption;
+
driver->StatementBind = AdbcStatementBind;
driver->StatementExecute = AdbcStatementExecute;
driver->StatementGetPartitionDesc = AdbcStatementGetPartitionDesc;
driver->StatementGetPartitionDescSize = AdbcStatementGetPartitionDescSize;
driver->StatementGetStream = AdbcStatementGetStream;
- driver->StatementInit = AdbcStatementInit;
+ driver->StatementNew = AdbcStatementNew;
+ driver->StatementPrepare = AdbcStatementPrepare;
driver->StatementRelease = AdbcStatementRelease;
+ driver->StatementSetSqlQuery = AdbcStatementSetSqlQuery;
*initialized = ADBC_VERSION_0_0_1;
return ADBC_STATUS_OK;
}
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index 5296c8a..e3b4f6e 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -31,188 +31,141 @@ namespace adbc {
using arrow::PointeesEqual;
-TEST(Adbc, Basics) {
- AdbcDatabase database;
- AdbcConnection connection;
- AdbcError error = {};
-
- {
- AdbcDatabaseOptions options;
- std::memset(&options, 0, sizeof(options));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&options, &database, &error));
+/// Fixture for the SQLite driver tests. SetUp() creates a database (with the
+/// "filename" option set to ":memory:") and one connection to it; TearDown()
+/// releases both, asserting that each release clears private_data.
+class Sqlite : public ::testing::Test {
+ public:
+ void SetUp() override {
+ // Database lifecycle: New, then SetOption, then Init.
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcDatabaseSetOption(&database, "filename", ":memory:", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
ASSERT_NE(database.private_data, nullptr);
- }
- {
- AdbcConnectionOptions options;
- std::memset(&options, 0, sizeof(options));
- options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&options, &connection, &error));
+ // Connection lifecycle: New (bound to the database above), then Init.
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection, &error));
ASSERT_NE(connection.private_data, nullptr);
}
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
- ASSERT_EQ(connection.private_data, nullptr);
+ // Release the connection before the database it was created from.
+ void TearDown() override {
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
+ ASSERT_EQ(connection.private_data, nullptr);
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
- ASSERT_EQ(database.private_data, nullptr);
-}
-
-TEST(AdbcSqlite, SqlExecute) {
- AdbcDatabase database;
- AdbcConnection connection;
- AdbcError error = {};
-
- {
- AdbcDatabaseOptions options;
- std::memset(&options, 0, sizeof(options));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&options, &database, &error));
- }
- {
- AdbcConnectionOptions options;
- std::memset(&options, 0, sizeof(options));
- options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&options, &connection, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+ ASSERT_EQ(database.private_data, nullptr);
}
- {
- std::string query = "SELECT 1";
- AdbcStatement statement;
- std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error));
-
- std::shared_ptr<arrow::Schema> schema;
- arrow::RecordBatchVector batches;
- ReadStatement(&statement, &schema, &batches);
- ASSERT_SCHEMA_EQ(*schema, *arrow::schema({arrow::field("1", arrow::int64())}));
- EXPECT_THAT(batches,
- ::testing::UnorderedPointwise(
- PointeesEqual(), {
- adbc::RecordBatchFromJSON(schema, "[[1]]"),
- }));
- }
-
- {
- std::string query = "INVALID";
- AdbcStatement statement;
- std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ASSERT_NE(AdbcConnectionSqlExecute(&connection, query.c_str(), &statement, &error),
- ADBC_STATUS_OK);
- ADBC_ASSERT_ERROR_THAT(
- error, ::testing::AllOf(::testing::HasSubstr("[SQLite3] sqlite3_prepare_v2:"),
- ::testing::HasSubstr("syntax error")));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
- }
-
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
-}
-
-TEST(AdbcSqlite, SqlPrepare) {
+ protected:
+ // NOTE(review): database/connection are not value-initialized; SetUp relies on
+ // AdbcDatabaseNew/AdbcConnectionNew to fill in private_data.
AdbcDatabase database;
AdbcConnection connection;
AdbcError error = {};
+};
+
+// Execute a query directly (no AdbcStatementPrepare call) and read back its result.
+TEST_F(Sqlite, SqlExecute) {
+ std::string query = "SELECT 1";
+ AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+ std::shared_ptr<arrow::Schema> schema;
+ arrow::RecordBatchVector batches;
+ // Stop the test if reading fails: `schema` would still be null below.
+ ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+ ASSERT_SCHEMA_EQ(*schema, *arrow::schema({arrow::field("1", arrow::int64())}));
+ EXPECT_THAT(batches,
+ ::testing::UnorderedPointwise(
+ PointeesEqual(), {
+ adbc::RecordBatchFromJSON(schema, "[[1]]"),
+ }));
+}
- {
- AdbcDatabaseOptions options;
- std::memset(&options, 0, sizeof(options));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&options, &database, &error));
- }
- {
- AdbcConnectionOptions options;
- std::memset(&options, 0, sizeof(options));
- options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&options, &connection, &error));
- }
-
- {
- std::string query = "SELECT 1";
- AdbcStatement statement;
- std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlPrepare(&connection, query.c_str(), &statement, &error));
-
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
-
- std::shared_ptr<arrow::Schema> schema;
- arrow::RecordBatchVector batches;
- ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
- ASSERT_SCHEMA_EQ(*schema, *arrow::schema({arrow::field("1", arrow::int64())}));
- EXPECT_THAT(batches,
- ::testing::UnorderedPointwise(
- PointeesEqual(), {
- adbc::RecordBatchFromJSON(schema, "[[1]]"),
- }));
- }
-
- {
- auto param_schema = arrow::schema(
- {arrow::field("1", arrow::int64()), arrow::field("2", arrow::utf8())});
- std::string query = "SELECT ?, ?";
- AdbcStatement statement;
- ArrowArray export_params;
- ArrowSchema export_schema;
- std::memset(&statement, 0, sizeof(statement));
-
- ASSERT_OK(ExportRecordBatch(
- *adbc::RecordBatchFromJSON(param_schema, R"([[1, "foo"], [2, "bar"]])"),
- &export_params));
- ASSERT_OK(ExportSchema(*param_schema, &export_schema));
-
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection, &statement, &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlPrepare(&connection, query.c_str(), &statement, &error));
-
- ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcStatementBind(&statement, &export_params, &export_schema, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+// A syntax error is reported by AdbcStatementSetSqlQuery (where the query gets
+// prepared), with SQLite's message propagated into `error`.
+TEST_F(Sqlite, SqlExecuteInvalid) {
+ const char* bad_query = "INVALID";
+ AdbcStatement statement = {};
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ const AdbcStatusCode status = AdbcStatementSetSqlQuery(&statement, bad_query, &error);
+ ASSERT_NE(status, ADBC_STATUS_OK);
+ ADBC_ASSERT_ERROR_THAT(
+ error, ::testing::AllOf(::testing::HasSubstr("[SQLite3] sqlite3_prepare_v2:"),
+ ::testing::HasSubstr("syntax error")));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+}
- std::shared_ptr<arrow::Schema> schema;
- arrow::RecordBatchVector batches;
- ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
- ASSERT_SCHEMA_EQ(*schema, *arrow::schema({arrow::field("?", arrow::int64()),
- arrow::field("?", arrow::utf8())}));
- EXPECT_THAT(batches,
- ::testing::UnorderedPointwise(
- PointeesEqual(),
- {
- adbc::RecordBatchFromJSON(schema, R"([[1, "foo"], [2, "bar"]])"),
- }));
- }
+// Same query as SqlExecute, but with an explicit AdbcStatementPrepare step.
+TEST_F(Sqlite, SqlPrepare) {
+ const char* sql = "SELECT 1";
+ AdbcStatement stmt = {};
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &stmt, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementSetSqlQuery(&stmt, sql, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementPrepare(&stmt, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&stmt, &error));
+
+ std::shared_ptr<arrow::Schema> result_schema;
+ arrow::RecordBatchVector result_batches;
+ ASSERT_NO_FATAL_FAILURE(ReadStatement(&stmt, &result_schema, &result_batches));
+ auto expected_schema = arrow::schema({arrow::field("1", arrow::int64())});
+ ASSERT_SCHEMA_EQ(*result_schema, *expected_schema);
+ EXPECT_THAT(result_batches,
+ ::testing::UnorderedPointwise(
+ PointeesEqual(), {
+ adbc::RecordBatchFromJSON(result_schema, "[[1]]"),
+ }));
+}
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+// Bind a two-row batch to a query with two positional parameters; the result
+// is expected to echo the bound rows unchanged.
+TEST_F(Sqlite, SqlPrepareMultipleParams) {
+ auto bind_schema = arrow::schema(
+ {arrow::field("1", arrow::int64()), arrow::field("2", arrow::utf8())});
+ const char* sql = "SELECT ?, ?";
+ AdbcStatement statement = {};
+ ArrowArray c_params;
+ ArrowSchema c_schema;
+
+ ASSERT_OK(ExportRecordBatch(
+ *adbc::RecordBatchFromJSON(bind_schema, R"([[1, "foo"], [2, "bar"]])"),
+ &c_params));
+ ASSERT_OK(ExportSchema(*bind_schema, &c_schema));
+
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementSetSqlQuery(&statement, sql, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementPrepare(&statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error,
+ AdbcStatementBind(&statement, &c_params, &c_schema, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+ std::shared_ptr<arrow::Schema> schema;
+ arrow::RecordBatchVector batches;
+ ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+ auto expected_schema = arrow::schema(
+ {arrow::field("?", arrow::int64()), arrow::field("?", arrow::utf8())});
+ ASSERT_SCHEMA_EQ(*schema, *expected_schema);
+ EXPECT_THAT(batches,
+ ::testing::UnorderedPointwise(
+ PointeesEqual(),
+ {
+ adbc::RecordBatchFromJSON(schema, R"([[1, "foo"], [2, "bar"]])"),
+ }));
}
-TEST(AdbcSqlite, MultipleConnections) {
- AdbcDatabase database;
- AdbcConnection connection1, connection2;
- AdbcError error = {};
+// Create a table through the fixture's connection and read it back through a
+// second connection opened on the same database.
+TEST_F(Sqlite, MultipleConnections) {
+ struct AdbcConnection connection2 = {};
{
- AdbcDatabaseOptions options;
- std::memset(&options, 0, sizeof(options));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&options, &database, &error));
- }
- {
- AdbcConnectionOptions options;
- std::memset(&options, 0, sizeof(options));
- options.database = &database;
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&options, &connection1, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&options, &connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionNew(&database, &connection2, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionInit(&connection2, &error));
+ ASSERT_NE(connection2.private_data, nullptr);
}
{
std::string query = "CREATE TABLE foo (bar INTEGER)";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection1, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection1, query.c_str(), &statement, &error));
+ error, AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
@@ -223,15 +176,14 @@ TEST(AdbcSqlite, MultipleConnections) {
PointeesEqual(), std::vector<std::shared_ptr<arrow::RecordBatch>>{}));
}
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection1, &error));
-
+ // Query through the second connection: the table created above must be visible
+ // from another connection to the same database (the point of this test).
{
std::string query = "SELECT * FROM foo";
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementInit(&connection2, &statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection2, &statement, &error));
ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcConnectionSqlExecute(&connection2, query.c_str(), &statement, &error));
+ error, AdbcStatementSetSqlQuery(&statement, query.c_str(), &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
std::shared_ptr<arrow::Schema> schema;
arrow::RecordBatchVector batches;
@@ -243,7 +195,6 @@ TEST(AdbcSqlite, MultipleConnections) {
}
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcConnectionRelease(&connection2, &error));
- ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
}
} // namespace adbc
diff --git a/drivers/test_util.h b/drivers/test_util.h
index 0123a5b..8cc06af 100644
--- a/drivers/test_util.h
+++ b/drivers/test_util.h
@@ -49,7 +49,7 @@ namespace adbc {
do { \
ASSERT_NE(ERROR.message, nullptr); \
std::string errmsg_ = ERROR.message ? ERROR.message : "(unknown error)"; \
- if (ERROR.message) error.release(&error); \
+ if (ERROR.message) error.release(&error); \
ASSERT_THAT(errmsg_, PATTERN) << errmsg_; \
} while (false)
diff --git a/drivers/util.h b/drivers/util.h
index 7552dfd..df49f9f 100644
--- a/drivers/util.h
+++ b/drivers/util.h
@@ -34,6 +34,12 @@
#define ADBC_DRIVER_EXPORT
#endif // ifdef __linux__
+/// Evaluate `expr` exactly once and, unless it yields ADBC_STATUS_OK, return that
+/// status from the enclosing function.
+#define ADBC_RETURN_NOT_OK(expr) \
+ do { \
+ auto adbc_return_not_ok_status = (expr); \
+ if (adbc_return_not_ok_status != ADBC_STATUS_OK) { \
+ return adbc_return_not_ok_status; \
+ } \
+ } while (false)
+
namespace adbc {
/// \brief Parse a connection string.