You are viewing a plain-text version of this content; the hyperlink to the canonical version ("here") was not preserved in this text copy.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/08/29 16:16:02 UTC
[arrow-adbc] branch main updated: [C] Specify a standard entrypoint name (#93)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new aa4e798 [C] Specify a standard entrypoint name (#93)
aa4e798 is described below
commit aa4e798f16d1035b4efc1bf8c90dc3b61b6d6432
Author: David Li <li...@gmail.com>
AuthorDate: Mon Aug 29 12:15:58 2022 -0400
[C] Specify a standard entrypoint name (#93)
Fixes #86.
---
adbc.h | 3 ++
c/driver_manager/adbc_driver_manager.cc | 28 +++++++--------
c/driver_manager/adbc_driver_manager_test.cc | 42 ++++++++++++++--------
c/drivers/flight_sql/flight_sql.cc | 4 +--
c/drivers/postgres/postgres.cc | 4 +--
c/drivers/sqlite/sqlite.cc | 4 +--
.../adbc_driver_manager/_lib.pyx | 4 +--
.../adbc_driver_manager/dbapi.py | 10 +++---
.../adbc_driver_manager/tests/test_dbapi.py | 5 +--
.../adbc_driver_manager/tests/test_lowlevel.py | 5 +--
10 files changed, 61 insertions(+), 48 deletions(-)
diff --git a/adbc.h b/adbc.h
index f3f33d3..e1c83b1 100644
--- a/adbc.h
+++ b/adbc.h
@@ -1046,6 +1046,9 @@ struct ADBC_EXPORT AdbcDriver {
/// to load a library and call a function of this type to load the
/// driver.
///
+/// Although drivers may choose any name for this function, the
+/// recommended name is "AdbcDriverInit".
+///
/// \param[in] count The number of entries to initialize. Provides
/// backwards compatibility if the struct definition is changed.
/// \param[out] driver The table of function pointers to initialize.
diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc
index 5cfcdd3..255486c 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -38,6 +38,7 @@ void ReleaseError(struct AdbcError* error) {
if (error) {
delete[] error->message;
error->message = nullptr;
+ error->release = nullptr;
}
}
@@ -144,7 +145,8 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*,
struct TempDatabase {
std::unordered_map<std::string, std::string> options;
std::string driver;
- std::string entrypoint;
+ // Default name (see adbc.h)
+ std::string entrypoint = "AdbcDriverInit";
};
/// Temporary state while the database is being configured.
@@ -236,35 +238,30 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
}
TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
if (args->driver.empty()) {
- delete args;
+ // Don't delete args here; caller should still call AdbcDatabaseRelease
SetError(error, "Must provide 'driver' parameter");
return ADBC_STATUS_INVALID_ARGUMENT;
- } else if (args->entrypoint.empty()) {
- delete args;
- SetError(error, "Must provide 'entrypoint' parameter");
- return ADBC_STATUS_INVALID_ARGUMENT;
}
database->private_driver = new AdbcDriver;
+ std::memset(database->private_driver, 0, sizeof(AdbcDriver));
size_t initialized = 0;
AdbcStatusCode status =
AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_0_0_1,
database->private_driver, &initialized, error);
if (status != ADBC_STATUS_OK) {
- delete args;
-
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
}
delete database->private_driver;
+ database->private_driver = nullptr;
return status;
} else if (initialized < ADBC_VERSION_0_0_1) {
- delete args;
-
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
}
delete database->private_driver;
+ database->private_driver = nullptr;
std::string message = "Database version is too old, expected ";
message += std::to_string(ADBC_VERSION_0_0_1);
@@ -275,24 +272,22 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
}
status = database->private_driver->DatabaseNew(database, error);
if (status != ADBC_STATUS_OK) {
- delete args;
-
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
}
delete database->private_driver;
+ database->private_driver = nullptr;
return status;
}
for (const auto& option : args->options) {
status = database->private_driver->DatabaseSetOption(database, option.first.c_str(),
option.second.c_str(), error);
if (status != ADBC_STATUS_OK) {
- delete args;
-
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
}
delete database->private_driver;
+ database->private_driver = nullptr;
return status;
}
}
@@ -613,6 +608,11 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
AdbcDriverInitFunc init_func;
std::string error_message;
+ if (!entrypoint) {
+ // Default entrypoint (see adbc.h)
+ entrypoint = "AdbcDriverInit";
+ }
+
#if defined(_WIN32)
HMODULE handle = LoadLibraryExA(driver_name, NULL, 0);
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 45e7bcf..50330f1 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -45,17 +45,14 @@ class DriverManager : public ::testing::Test {
size_t initialized = 0;
ADBC_ASSERT_OK_WITH_ERROR(
- error, AdbcLoadDriver("adbc_driver_sqlite", "AdbcSqliteDriverInit",
- ADBC_VERSION_0_0_1, &driver, &initialized, &error));
+ error, AdbcLoadDriver("adbc_driver_sqlite", NULL, ADBC_VERSION_0_0_1, &driver,
+ &initialized, &error));
ASSERT_EQ(initialized, ADBC_VERSION_0_0_1);
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
ASSERT_NE(database.private_data, nullptr);
ADBC_ASSERT_OK_WITH_ERROR(
error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
- ADBC_ASSERT_OK_WITH_ERROR(
- error,
- AdbcDatabaseSetOption(&database, "entrypoint", "AdbcSqliteDriverInit", &error));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
ASSERT_NE(database.private_data, nullptr);
@@ -65,7 +62,7 @@ class DriverManager : public ::testing::Test {
}
void TearDown() override {
- if (error.message) {
+ if (error.release) {
error.release(&error);
}
@@ -90,7 +87,6 @@ class DriverManager : public ::testing::Test {
};
TEST_F(DriverManager, DatabaseInitRelease) {
- AdbcError error = {};
AdbcDatabase database;
std::memset(&database, 0, sizeof(database));
@@ -98,8 +94,31 @@ TEST_F(DriverManager, DatabaseInitRelease) {
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
}
+TEST_F(DriverManager, DatabaseCustomInitFunc) {
+ AdbcDatabase database;
+ std::memset(&database, 0, sizeof(database));
+
+ // Explicitly set entrypoint
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcDatabaseSetOption(&database, "entrypoint", "AdbcDriverInit", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+
+ // Set invalid entrypoint
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error,
+ AdbcDatabaseSetOption(&database, "entrypoint", "ThisSymbolDoesNotExist", &error));
+ ASSERT_EQ(ADBC_STATUS_INTERNAL, AdbcDatabaseInit(&database, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+}
+
TEST_F(DriverManager, ConnectionInitRelease) {
- AdbcError error = {};
AdbcConnection connection;
std::memset(&connection, 0, sizeof(connection));
@@ -296,12 +315,7 @@ TEST_F(DriverManager, Transactions) {
}
AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, struct AdbcError* error) {
- AdbcStatusCode status;
- if ((status = AdbcDatabaseSetOption(database, "driver", "adbc_driver_sqlite", error)) !=
- ADBC_STATUS_OK) {
- return status;
- }
- return AdbcDatabaseSetOption(database, "entrypoint", "AdbcSqliteDriverInit", error);
+ return AdbcDatabaseSetOption(database, "driver", "adbc_driver_sqlite", error);
}
TEST_F(DriverManager, ValidationSuite) {
diff --git a/c/drivers/flight_sql/flight_sql.cc b/c/drivers/flight_sql/flight_sql.cc
index 3def988..66d402c 100644
--- a/c/drivers/flight_sql/flight_sql.cc
+++ b/c/drivers/flight_sql/flight_sql.cc
@@ -677,8 +677,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
extern "C" {
ADBC_EXPORT
-AdbcStatusCode AdbcFlightSqlDriverInit(size_t count, struct AdbcDriver* driver,
- size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error) {
if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
std::memset(driver, 0, sizeof(*driver));
diff --git a/c/drivers/postgres/postgres.cc b/c/drivers/postgres/postgres.cc
index afa56fc..4922596 100644
--- a/c/drivers/postgres/postgres.cc
+++ b/c/drivers/postgres/postgres.cc
@@ -442,8 +442,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
extern "C" {
ADBC_EXPORT
-AdbcStatusCode AdbcPostgresDriverInit(size_t count, struct AdbcDriver* driver,
- size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error) {
if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
std::memset(driver, 0, sizeof(*driver));
diff --git a/c/drivers/sqlite/sqlite.cc b/c/drivers/sqlite/sqlite.cc
index 80afbad..e6d003f 100644
--- a/c/drivers/sqlite/sqlite.cc
+++ b/c/drivers/sqlite/sqlite.cc
@@ -1767,8 +1767,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
extern "C" {
ADBC_EXPORT
-AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver,
- size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error) {
if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
std::memset(driver, 0, sizeof(*driver));
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ab2da54..1a12e0b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -440,8 +440,8 @@ cdef class AdbcDatabase(_AdbcHandle):
----------
kwargs : dict
String key-value options to pass to the underlying database.
- Must include at least "driver" and "entrypoint" to identify
- the underlying database driver to load.
+ Must include at least "driver" to identify the underlying
+ database driver to load.
"""
cdef:
CAdbcDatabase database
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 33c0b68..440737f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -124,7 +124,7 @@ ROWID = _TypeSet([pyarrow.int64().id])
def connect(
*,
driver: str,
- entrypoint: str,
+ entrypoint: str = None,
db_kwargs: Optional[Dict[str, str]] = None,
conn_kwargs: Optional[Dict[str, str]] = None
) -> "Connection":
@@ -149,13 +149,15 @@ def connect(
db = None
conn = None
- if db_kwargs is None:
- db_kwargs = {}
+ db_kwargs = dict(db_kwargs or {})
+ db_kwargs["driver"] = driver
+ if entrypoint:
+ db_kwargs["entrypoint"] = entrypoint
if conn_kwargs is None:
conn_kwargs = {}
try:
- db = _lib.AdbcDatabase(driver=driver, entrypoint=entrypoint, **db_kwargs)
+ db = _lib.AdbcDatabase(**db_kwargs)
conn = _lib.AdbcConnection(db, **conn_kwargs)
return Connection(db, conn)
except Exception:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
index f088021..72b3ec4 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
@@ -26,10 +26,7 @@ from adbc_driver_manager import dbapi
@pytest.fixture
def sqlite():
"""Dynamically load the SQLite driver."""
- with dbapi.connect(
- driver="adbc_driver_sqlite",
- entrypoint="AdbcSqliteDriverInit",
- ) as conn:
+ with dbapi.connect(driver="adbc_driver_sqlite") as conn:
yield conn
diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
index a15585c..5734cc2 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
@@ -24,10 +24,7 @@ import adbc_driver_manager
@pytest.fixture
def sqlite():
"""Dynamically load the SQLite driver."""
- with adbc_driver_manager.AdbcDatabase(
- driver="adbc_driver_sqlite",
- entrypoint="AdbcSqliteDriverInit",
- ) as db:
+ with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
with adbc_driver_manager.AdbcConnection(db) as conn:
yield (db, conn)