You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/08/29 16:16:02 UTC

[arrow-adbc] branch main updated: [C] Specify a standard entrypoint name (#93)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new aa4e798  [C] Specify a standard entrypoint name (#93)
aa4e798 is described below

commit aa4e798f16d1035b4efc1bf8c90dc3b61b6d6432
Author: David Li <li...@gmail.com>
AuthorDate: Mon Aug 29 12:15:58 2022 -0400

    [C] Specify a standard entrypoint name (#93)
    
    Fixes #86.
---
 adbc.h                                             |  3 ++
 c/driver_manager/adbc_driver_manager.cc            | 28 +++++++--------
 c/driver_manager/adbc_driver_manager_test.cc       | 42 ++++++++++++++--------
 c/drivers/flight_sql/flight_sql.cc                 |  4 +--
 c/drivers/postgres/postgres.cc                     |  4 +--
 c/drivers/sqlite/sqlite.cc                         |  4 +--
 .../adbc_driver_manager/_lib.pyx                   |  4 +--
 .../adbc_driver_manager/dbapi.py                   | 10 +++---
 .../adbc_driver_manager/tests/test_dbapi.py        |  5 +--
 .../adbc_driver_manager/tests/test_lowlevel.py     |  5 +--
 10 files changed, 61 insertions(+), 48 deletions(-)

diff --git a/adbc.h b/adbc.h
index f3f33d3..e1c83b1 100644
--- a/adbc.h
+++ b/adbc.h
@@ -1046,6 +1046,9 @@ struct ADBC_EXPORT AdbcDriver {
 ///   to load a library and call a function of this type to load the
 ///   driver.
 ///
+/// Although drivers may choose any name for this function, the
+/// recommended name is "AdbcDriverInit" (the driver manager's default).
+///
 /// \param[in] count The number of entries to initialize. Provides
 ///   backwards compatibility if the struct definition is changed.
 /// \param[out] driver The table of function pointers to initialize.
diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc
index 5cfcdd3..255486c 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -38,6 +38,7 @@ void ReleaseError(struct AdbcError* error) {
   if (error) {
     delete[] error->message;
     error->message = nullptr;
+    error->release = nullptr;
   }
 }
 
@@ -144,7 +145,8 @@ AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement*, const uint8_t*,
 struct TempDatabase {
   std::unordered_map<std::string, std::string> options;
   std::string driver;
-  std::string entrypoint;
+  // Default name (see adbc.h)
+  std::string entrypoint = "AdbcDriverInit";
 };
 
 /// Temporary state while the database is being configured.
@@ -236,35 +238,30 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
   }
   TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
   if (args->driver.empty()) {
-    delete args;
+    // Don't delete args here; caller should still call AdbcDatabaseRelease
     SetError(error, "Must provide 'driver' parameter");
     return ADBC_STATUS_INVALID_ARGUMENT;
-  } else if (args->entrypoint.empty()) {
-    delete args;
-    SetError(error, "Must provide 'entrypoint' parameter");
-    return ADBC_STATUS_INVALID_ARGUMENT;
   }
 
   database->private_driver = new AdbcDriver;
+  std::memset(database->private_driver, 0, sizeof(AdbcDriver));
   size_t initialized = 0;
   AdbcStatusCode status =
       AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_0_0_1,
                      database->private_driver, &initialized, error);
   if (status != ADBC_STATUS_OK) {
-    delete args;
-
     if (database->private_driver->release) {
       database->private_driver->release(database->private_driver, error);
     }
     delete database->private_driver;
+    database->private_driver = nullptr;
     return status;
   } else if (initialized < ADBC_VERSION_0_0_1) {
-    delete args;
-
     if (database->private_driver->release) {
       database->private_driver->release(database->private_driver, error);
     }
     delete database->private_driver;
+    database->private_driver = nullptr;
 
     std::string message = "Database version is too old, expected ";
     message += std::to_string(ADBC_VERSION_0_0_1);
@@ -275,24 +272,22 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
   }
   status = database->private_driver->DatabaseNew(database, error);
   if (status != ADBC_STATUS_OK) {
-    delete args;
-
     if (database->private_driver->release) {
       database->private_driver->release(database->private_driver, error);
     }
     delete database->private_driver;
+    database->private_driver = nullptr;
     return status;
   }
   for (const auto& option : args->options) {
     status = database->private_driver->DatabaseSetOption(database, option.first.c_str(),
                                                          option.second.c_str(), error);
     if (status != ADBC_STATUS_OK) {
-      delete args;
-
       if (database->private_driver->release) {
         database->private_driver->release(database->private_driver, error);
       }
       delete database->private_driver;
+      database->private_driver = nullptr;
       return status;
     }
   }
@@ -613,6 +608,11 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
   AdbcDriverInitFunc init_func;
   std::string error_message;
 
+  if (!entrypoint) {
+    // Default entrypoint (see adbc.h)
+    entrypoint = "AdbcDriverInit";
+  }
+
 #if defined(_WIN32)
 
   HMODULE handle = LoadLibraryExA(driver_name, NULL, 0);
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 45e7bcf..50330f1 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -45,17 +45,14 @@ class DriverManager : public ::testing::Test {
 
     size_t initialized = 0;
     ADBC_ASSERT_OK_WITH_ERROR(
-        error, AdbcLoadDriver("adbc_driver_sqlite", "AdbcSqliteDriverInit",
-                              ADBC_VERSION_0_0_1, &driver, &initialized, &error));
+        error, AdbcLoadDriver("adbc_driver_sqlite", NULL, ADBC_VERSION_0_0_1, &driver,
+                              &initialized, &error));
     ASSERT_EQ(initialized, ADBC_VERSION_0_0_1);
 
     ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
     ASSERT_NE(database.private_data, nullptr);
     ADBC_ASSERT_OK_WITH_ERROR(
         error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
-    ADBC_ASSERT_OK_WITH_ERROR(
-        error,
-        AdbcDatabaseSetOption(&database, "entrypoint", "AdbcSqliteDriverInit", &error));
     ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
     ASSERT_NE(database.private_data, nullptr);
 
@@ -65,7 +62,7 @@ class DriverManager : public ::testing::Test {
   }
 
   void TearDown() override {
-    if (error.message) {
+    if (error.release) {
       error.release(&error);
     }
 
@@ -90,7 +87,6 @@ class DriverManager : public ::testing::Test {
 };
 
 TEST_F(DriverManager, DatabaseInitRelease) {
-  AdbcError error = {};
   AdbcDatabase database;
   std::memset(&database, 0, sizeof(database));
 
@@ -98,8 +94,31 @@ TEST_F(DriverManager, DatabaseInitRelease) {
   ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
 }
 
+TEST_F(DriverManager, DatabaseCustomInitFunc) {
+  AdbcDatabase database;
+  std::memset(&database, 0, sizeof(database));
+
+  // Explicitly pass the default entrypoint name; Init should still succeed
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(
+      error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
+  ADBC_ASSERT_OK_WITH_ERROR(
+      error, AdbcDatabaseSetOption(&database, "entrypoint", "AdbcDriverInit", &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseInit(&database, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+
+  // Nonexistent entrypoint: Init fails, but Release must still succeed
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseNew(&database, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(
+      error, AdbcDatabaseSetOption(&database, "driver", "adbc_driver_sqlite", &error));
+  ADBC_ASSERT_OK_WITH_ERROR(
+      error,
+      AdbcDatabaseSetOption(&database, "entrypoint", "ThisSymbolDoesNotExist", &error));
+  ASSERT_EQ(ADBC_STATUS_INTERNAL, AdbcDatabaseInit(&database, &error));
+  ADBC_ASSERT_OK_WITH_ERROR(error, AdbcDatabaseRelease(&database, &error));
+}
+
 TEST_F(DriverManager, ConnectionInitRelease) {
-  AdbcError error = {};
   AdbcConnection connection;
   std::memset(&connection, 0, sizeof(connection));
 
@@ -296,12 +315,7 @@ TEST_F(DriverManager, Transactions) {
 }
 
 AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, struct AdbcError* error) {
-  AdbcStatusCode status;
-  if ((status = AdbcDatabaseSetOption(database, "driver", "adbc_driver_sqlite", error)) !=
-      ADBC_STATUS_OK) {
-    return status;
-  }
-  return AdbcDatabaseSetOption(database, "entrypoint", "AdbcSqliteDriverInit", error);
+  return AdbcDatabaseSetOption(database, "driver", "adbc_driver_sqlite", error);
 }
 
 TEST_F(DriverManager, ValidationSuite) {
diff --git a/c/drivers/flight_sql/flight_sql.cc b/c/drivers/flight_sql/flight_sql.cc
index 3def988..66d402c 100644
--- a/c/drivers/flight_sql/flight_sql.cc
+++ b/c/drivers/flight_sql/flight_sql.cc
@@ -677,8 +677,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
 
 extern "C" {
 ADBC_EXPORT
-AdbcStatusCode AdbcFlightSqlDriverInit(size_t count, struct AdbcDriver* driver,
-                                       size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+                              size_t* initialized, struct AdbcError* error) {
   if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
 
   std::memset(driver, 0, sizeof(*driver));
diff --git a/c/drivers/postgres/postgres.cc b/c/drivers/postgres/postgres.cc
index afa56fc..4922596 100644
--- a/c/drivers/postgres/postgres.cc
+++ b/c/drivers/postgres/postgres.cc
@@ -442,8 +442,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
 
 extern "C" {
 ADBC_EXPORT
-AdbcStatusCode AdbcPostgresDriverInit(size_t count, struct AdbcDriver* driver,
-                                      size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+                              size_t* initialized, struct AdbcError* error) {
   if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
 
   std::memset(driver, 0, sizeof(*driver));
diff --git a/c/drivers/sqlite/sqlite.cc b/c/drivers/sqlite/sqlite.cc
index 80afbad..e6d003f 100644
--- a/c/drivers/sqlite/sqlite.cc
+++ b/c/drivers/sqlite/sqlite.cc
@@ -1767,8 +1767,8 @@ AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
 
 extern "C" {
 ADBC_EXPORT
-AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver,
-                                    size_t* initialized, struct AdbcError* error) {
+AdbcStatusCode AdbcDriverInit(size_t count, struct AdbcDriver* driver,
+                              size_t* initialized, struct AdbcError* error) {
   if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
 
   std::memset(driver, 0, sizeof(*driver));
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ab2da54..1a12e0b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -440,8 +440,8 @@ cdef class AdbcDatabase(_AdbcHandle):
     ----------
     kwargs : dict
         String key-value options to pass to the underlying database.
-        Must include at least "driver" and "entrypoint" to identify
-        the underlying database driver to load.
+        Must include at least "driver" to identify the underlying
+        database driver to load; "entrypoint" is optional.
     """
     cdef:
         CAdbcDatabase database
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 33c0b68..440737f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -124,7 +124,7 @@ ROWID = _TypeSet([pyarrow.int64().id])
 def connect(
     *,
     driver: str,
-    entrypoint: str,
+    entrypoint: str = None,
     db_kwargs: Optional[Dict[str, str]] = None,
     conn_kwargs: Optional[Dict[str, str]] = None
 ) -> "Connection":
@@ -149,13 +149,15 @@ def connect(
     db = None
     conn = None
 
-    if db_kwargs is None:
-        db_kwargs = {}
+    db_kwargs = dict(db_kwargs or {})
+    db_kwargs["driver"] = driver
+    if entrypoint:
+        db_kwargs["entrypoint"] = entrypoint
     if conn_kwargs is None:
         conn_kwargs = {}
 
     try:
-        db = _lib.AdbcDatabase(driver=driver, entrypoint=entrypoint, **db_kwargs)
+        db = _lib.AdbcDatabase(**db_kwargs)
         conn = _lib.AdbcConnection(db, **conn_kwargs)
         return Connection(db, conn)
     except Exception:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
index f088021..72b3ec4 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_dbapi.py
@@ -26,10 +26,7 @@ from adbc_driver_manager import dbapi
 @pytest.fixture
 def sqlite():
     """Dynamically load the SQLite driver."""
-    with dbapi.connect(
-        driver="adbc_driver_sqlite",
-        entrypoint="AdbcSqliteDriverInit",
-    ) as conn:
+    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
         yield conn
 
 
diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
index a15585c..5734cc2 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
@@ -24,10 +24,7 @@ import adbc_driver_manager
 @pytest.fixture
 def sqlite():
     """Dynamically load the SQLite driver."""
-    with adbc_driver_manager.AdbcDatabase(
-        driver="adbc_driver_sqlite",
-        entrypoint="AdbcSqliteDriverInit",
-    ) as db:
+    with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
         with adbc_driver_manager.AdbcConnection(db) as conn:
             yield (db, conn)