You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/06/09 14:48:57 UTC

[arrow-adbc] branch main updated: Refactor error handling (#5)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 13b8734  Refactor error handling (#5)
13b8734 is described below

commit 13b8734b6b860ebdf94e94a452aa53e4d7b0baa4
Author: David Li <li...@gmail.com>
AuthorDate: Thu Jun 9 10:48:53 2022 -0400

    Refactor error handling (#5)
    
    It's hard for the driver manager to allocate/free errors since it
    would have to manually intercept all calls. Routing error release
    through the driver also makes it hard to provide error messages
    when there is no driver. Move the error release callback to the
    error structure itself to make both cases easier.
---
 adbc.h                                          | 28 ++++++-----------
 adbc_driver_manager/adbc_driver_manager.cc      | 33 +++++++++++++++-----
 adbc_driver_manager/adbc_driver_manager.h       |  5 ++-
 adbc_driver_manager/adbc_driver_manager_test.cc |  6 ++--
 drivers/flight_sql/flight_sql.cc                | 41 +++++++++----------------
 drivers/sqlite/sqlite.cc                        | 16 +++++-----
 drivers/test_util.h                             |  4 +--
 7 files changed, 64 insertions(+), 69 deletions(-)

diff --git a/adbc.h b/adbc.h
index 2e2df69..29aff22 100644
--- a/adbc.h
+++ b/adbc.h
@@ -164,22 +164,13 @@ struct AdbcError {
   /// \brief The error message.
   char* message;
 
-  /// \brief The associated driver (used by the driver manager to help
-  ///   track state).
-  AdbcDriver* private_driver;
-
-  // TODO: go back to just inlining 'release' here? And remove the
-  // global AdbcErrorRelease? It would be slightly inconsistent (and
-  // would make the struct impossible to extend) but would be easier
-  // to manage between the driver manager and driver.
+  /// \brief Release the contained error.
+  ///
+  /// Unlike other structures, this is an embedded callback to make it
+  /// easier for the driver manager and driver to cooperate.
+  void (*release)(struct AdbcError* error);
 };
 
-/// \brief Destroy an error message.
-void AdbcErrorRelease(struct AdbcError* error);
-
-/// \brief Get a human-readable description of a status code.
-const char* AdbcStatusCodeMessage(AdbcStatusCode code);
-
 /// }@
 
 /// \defgroup adbc-database Database initialization.
@@ -576,9 +567,6 @@ struct AdbcDriver {
   void* private_data;
   // TODO: DriverRelease
 
-  void (*ErrorRelease)(struct AdbcError*);
-  const char* (*StatusCodeMessage)(AdbcStatusCode);
-
   AdbcStatusCode (*DatabaseInit)(const struct AdbcDatabaseOptions*, struct AdbcDatabase*,
                                  struct AdbcError*);
   AdbcStatusCode (*DatabaseRelease)(struct AdbcDatabase*, struct AdbcError*);
@@ -631,14 +619,16 @@ struct AdbcDriver {
 /// \param[out] driver The table of function pointers to initialize.
 /// \param[out] initialized How much of the table was actually
 ///   initialized (can be less than count).
+/// \param[out] error An optional location to return an error message
+///   if necessary.
 typedef AdbcStatusCode (*AdbcDriverInitFunc)(size_t count, struct AdbcDriver* driver,
-                                             size_t* initialized);
+                                             size_t* initialized, struct AdbcError* error);
 // TODO: how best to report errors here?
 // TODO: use sizeof() instead of count, or version the
 // struct/entrypoint instead?
 
 // For use with count
-#define ADBC_VERSION_0_0_1 21
+#define ADBC_VERSION_0_0_1 19
 
 /// }@
 
diff --git a/adbc_driver_manager/adbc_driver_manager.cc b/adbc_driver_manager/adbc_driver_manager.cc
index 998c672..a5f899d 100644
--- a/adbc_driver_manager/adbc_driver_manager.cc
+++ b/adbc_driver_manager/adbc_driver_manager.cc
@@ -47,6 +47,20 @@ std::unordered_map<std::string, std::string> ParseConnectionString(
   return option_pairs;
 }
 
+void ReleaseError(struct AdbcError* error) {
+  if (error) {
+    delete[] error->message;
+    error->message = nullptr;
+  }
+}
+
+void SetError(struct AdbcError* error, const std::string& message) {
+  error->message = new char[message.size() + 1];
+  message.copy(error->message, message.size());
+  error->message[message.size()] = '\0';
+  error->release = ReleaseError;
+}
+
 // Default stubs
 AdbcStatusCode ConnectionSqlPrepare(struct AdbcConnection*, const char*,
                                     struct AdbcStatement*, struct AdbcError* error) {
@@ -70,12 +84,6 @@ AdbcStatusCode StatementExecute(struct AdbcStatement*, struct AdbcError* error)
 
 // Direct implementations of API methods
 
-void AdbcErrorRelease(struct AdbcError* error) {
-  if (!error->message) return;
-  // TODO: assert
-  error->private_driver->ErrorRelease(error);
-}
-
 AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
                                 struct AdbcDatabase* out, struct AdbcError* error) {
   if (!options->driver) {
@@ -212,31 +220,40 @@ const char* AdbcStatusCodeMessage(AdbcStatusCode code) {
 }
 
 AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
-                              struct AdbcDriver* driver, size_t* initialized) {
+                              struct AdbcDriver* driver, size_t* initialized,
+                              struct AdbcError* error) {
   auto params = ParseConnectionString(connection);
 
   auto driver_str = params.find("Driver");
   if (driver_str == params.end()) {
+    SetError(error, "Must provide Driver parameter");
     return ADBC_STATUS_INVALID_ARGUMENT;
   }
 
   auto entrypoint_str = params.find("Entrypoint");
   if (entrypoint_str == params.end()) {
+    SetError(error, "Must provide Entrypoint parameter");
     return ADBC_STATUS_INVALID_ARGUMENT;
   }
 
   void* handle = dlopen(driver_str->second.c_str(), RTLD_NOW | RTLD_LOCAL);
   if (!handle) {
+    std::string message = "dlopen() failed: ";
+    message += dlerror();
+    SetError(error, message);
     return ADBC_STATUS_UNKNOWN;
   }
 
   void* load_handle = dlsym(handle, entrypoint_str->second.c_str());
   auto* load = reinterpret_cast<AdbcDriverInitFunc>(load_handle);
   if (!load) {
+    std::string message = "dlsym() failed: ";
+    message += dlerror();
+    SetError(error, message);
     return ADBC_STATUS_INTERNAL;
   }
 
-  auto result = load(count, driver, initialized);
+  auto result = load(count, driver, initialized, error);
   if (result != ADBC_STATUS_OK) {
     return result;
   }
diff --git a/adbc_driver_manager/adbc_driver_manager.h b/adbc_driver_manager/adbc_driver_manager.h
index 074cb52..00f1077 100644
--- a/adbc_driver_manager/adbc_driver_manager.h
+++ b/adbc_driver_manager/adbc_driver_manager.h
@@ -40,8 +40,11 @@ extern "C" {
 /// \param[out] driver The table of function pointers to initialize.
 /// \param[out] initialized How much of the table was actually
 ///   initialized (can be less than count).
+/// \param[out] error An optional location to return an error message
+///   if necessary.
 AdbcStatusCode AdbcLoadDriver(const char* connection, size_t count,
-                              struct AdbcDriver* driver, size_t* initialized);
+                              struct AdbcDriver* driver, size_t* initialized,
+                              struct AdbcError* error);
 
 #endif  // ADBC_DRIVER_MANAGER_H
 
diff --git a/adbc_driver_manager/adbc_driver_manager_test.cc b/adbc_driver_manager/adbc_driver_manager_test.cc
index 5ba09af..65ec7e4 100644
--- a/adbc_driver_manager/adbc_driver_manager_test.cc
+++ b/adbc_driver_manager/adbc_driver_manager_test.cc
@@ -36,9 +36,9 @@ class DriverManager : public ::testing::Test {
  public:
   void SetUp() override {
     size_t initialized = 0;
-    ADBC_ASSERT_OK(
+    ADBC_ASSERT_OK_WITH_ERROR(error,
         AdbcLoadDriver("Driver=libadbc_driver_sqlite.so;Entrypoint=AdbcSqliteDriverInit",
-                       ADBC_VERSION_0_0_1, &driver, &initialized));
+                       ADBC_VERSION_0_0_1, &driver, &initialized, &error));
     ASSERT_EQ(initialized, ADBC_VERSION_0_0_1);
 
     AdbcDatabaseOptions db_options;
@@ -90,8 +90,6 @@ TEST_F(DriverManager, SqlExecute) {
 }
 
 TEST_F(DriverManager, SqlExecuteInvalid) {
-  GTEST_SKIP() << "AdbcError needs refactoring";
-
   std::string query = "INVALID";
   AdbcStatement statement;
   std::memset(&statement, 0, sizeof(statement));
diff --git a/drivers/flight_sql/flight_sql.cc b/drivers/flight_sql/flight_sql.cc
index 5130d21..584abb1 100644
--- a/drivers/flight_sql/flight_sql.cc
+++ b/drivers/flight_sql/flight_sql.cc
@@ -35,18 +35,11 @@ using arrow::Status;
 
 namespace {
 
-void SetError(const Status& status, struct AdbcError* error) {
-  if (!error) return;
-  std::string message = arrow::util::StringBuilder("[Flight SQL] ", status.ToString());
+void ReleaseError(struct AdbcError* error) {
   if (error->message) {
-    message.reserve(message.size() + 1 + std::strlen(error->message));
-    message.append(1, '\n');
-    message.append(error->message);
     delete[] error->message;
+    error->message = nullptr;
   }
-  error->message = new char[message.size() + 1];
-  message.copy(error->message, message.size());
-  error->message[message.size()] = '\0';
 }
 
 template <typename... Args>
@@ -63,6 +56,7 @@ void SetError(struct AdbcError* error, Args&&... args) {
   error->message = new char[message.size() + 1];
   message.copy(error->message, message.size());
   error->message[message.size()] = '\0';
+  error->release = ReleaseError;
 }
 
 class FlightSqlDatabaseImpl {
@@ -96,7 +90,7 @@ class FlightSqlDatabaseImpl {
 
     auto status = client_->Close();
     if (!status.ok()) {
-      SetError(status, error);
+      SetError(error, status);
       return ADBC_STATUS_IO;
     }
     return ADBC_STATUS_OK;
@@ -159,7 +153,7 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
 
     auto status = NextStream();
     if (!status.ok()) {
-      SetError(status, error);
+      SetError(error, status);
       return ADBC_STATUS_IO;
     }
     if (!schema_) {
@@ -168,7 +162,7 @@ class FlightSqlStatementImpl : public arrow::RecordBatchReader {
 
     status = arrow::ExportRecordBatchReader(self, out);
     if (!status.ok()) {
-      SetError(status, error);
+      SetError(error, status);
       return ADBC_STATUS_UNKNOWN;
     }
     return ADBC_STATUS_OK;
@@ -257,7 +251,7 @@ class AdbcFlightSqlImpl {
     std::unique_ptr<flight::FlightInfo> flight_info;
     auto status = client_->GetTableTypes(call_options).Value(&flight_info);
     if (!status.ok()) {
-      SetError(status, error);
+      SetError(error, status);
       return ADBC_STATUS_IO;
     }
     impl->Init(client_, std::move(flight_info));
@@ -282,7 +276,7 @@ class AdbcFlightSqlImpl {
     std::unique_ptr<flight::FlightInfo> flight_info;
     auto status = client_->Execute(call_options, std::string(query)).Value(&flight_info);
     if (!status.ok()) {
-      SetError(status, error);
+      SetError(error, status);
       return ADBC_STATUS_IO;
     }
     impl->Init(client_, std::move(flight_info));
@@ -312,7 +306,7 @@ class AdbcFlightSqlImpl {
         *arrow::schema({}), flight::FlightDescriptor::Command(""), endpoints,
         /*total_records=*/-1, /*total_bytes=*/-1);
     if (!maybe_info.ok()) {
-      SetError(maybe_info.status(), error);
+      SetError(error, maybe_info.status());
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
     std::unique_ptr<flight::FlightInfo> flight_info(
@@ -329,12 +323,6 @@ class AdbcFlightSqlImpl {
 
 }  // namespace
 
-ADBC_DRIVER_EXPORT
-void AdbcErrorRelease(struct AdbcError* error) {
-  delete[] error->message;
-  error->message = nullptr;
-}
-
 ADBC_DRIVER_EXPORT
 AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
                                 struct AdbcDatabase* out, struct AdbcError* error) {
@@ -342,26 +330,26 @@ AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
   auto status = adbc::ParseConnectionString(arrow::util::string_view(options->target))
                     .Value(&option_pairs);
   if (!status.ok()) {
-    SetError(status, error);
+    SetError(error, status);
     return ADBC_STATUS_INVALID_ARGUMENT;
   }
   auto location_it = option_pairs.find("Location");
   if (location_it == option_pairs.end()) {
-    SetError(Status::Invalid("Must provide Location option"), error);
+    SetError(error, Status::Invalid("Must provide Location option"));
     return ADBC_STATUS_INVALID_ARGUMENT;
   }
 
   flight::Location location;
   status = flight::Location::Parse(location_it->second).Value(&location);
   if (!status.ok()) {
-    SetError(status, error);
+    SetError(error, status);
     return ADBC_STATUS_INVALID_ARGUMENT;
   }
 
   std::unique_ptr<flight::FlightClient> flight_client;
   status = flight::FlightClient::Connect(location).Value(&flight_client);
   if (!status.ok()) {
-    SetError(status, error);
+    SetError(error, status);
     return ADBC_STATUS_IO;
   }
   std::unique_ptr<flightsql::FlightSqlClient> client(
@@ -497,11 +485,10 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
 extern "C" {
 ARROW_EXPORT
 AdbcStatusCode AdbcFlightSqlDriverInit(size_t count, struct AdbcDriver* driver,
-                                       size_t* initialized) {
+                                       size_t* initialized, struct AdbcError* error) {
   if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
 
   std::memset(driver, 0, sizeof(*driver));
-  driver->ErrorRelease = AdbcErrorRelease;
   driver->DatabaseInit = AdbcDatabaseInit;
   driver->DatabaseRelease = AdbcDatabaseRelease;
   driver->ConnectionInit = AdbcConnectionInit;
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index e94cb96..e55e0b5 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -36,6 +36,11 @@ namespace {
 
 using arrow::Status;
 
+void ReleaseError(struct AdbcError* error) {
+  delete[] error->message;
+  error->message = nullptr;
+}
+
 void SetError(sqlite3* db, const std::string& source, struct AdbcError* error) {
   if (!error) return;
   std::string message =
@@ -49,6 +54,7 @@ void SetError(sqlite3* db, const std::string& source, struct AdbcError* error) {
   error->message = new char[message.size() + 1];
   message.copy(error->message, message.size());
   error->message[message.size()] = '\0';
+  error->release = ReleaseError;
 }
 
 template <typename... Args>
@@ -65,6 +71,7 @@ void SetError(struct AdbcError* error, Args&&... args) {
   error->message = new char[message.size() + 1];
   message.copy(error->message, message.size());
   error->message[message.size()] = '\0';
+  error->release = ReleaseError;
 }
 
 std::shared_ptr<arrow::Schema> StatementToSchema(sqlite3_stmt* stmt) {
@@ -444,12 +451,6 @@ class SqliteConnectionImpl {
 
 }  // namespace
 
-ADBC_DRIVER_EXPORT
-void AdbcErrorRelease(struct AdbcError* error) {
-  delete[] error->message;
-  error->message = nullptr;
-}
-
 ADBC_DRIVER_EXPORT
 AdbcStatusCode AdbcDatabaseInit(const struct AdbcDatabaseOptions* options,
                                 struct AdbcDatabase* out, struct AdbcError* error) {
@@ -593,11 +594,10 @@ AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
 extern "C" {
 ARROW_EXPORT
 AdbcStatusCode AdbcSqliteDriverInit(size_t count, struct AdbcDriver* driver,
-                                    size_t* initialized) {
+                                    size_t* initialized, struct AdbcError* error) {
   if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
 
   std::memset(driver, 0, sizeof(*driver));
-  driver->ErrorRelease = AdbcErrorRelease;
   driver->DatabaseInit = AdbcDatabaseInit;
   driver->DatabaseRelease = AdbcDatabaseRelease;
   driver->ConnectionInit = AdbcConnectionInit;
diff --git a/drivers/test_util.h b/drivers/test_util.h
index 35f887e..0123a5b 100644
--- a/drivers/test_util.h
+++ b/drivers/test_util.h
@@ -40,7 +40,7 @@ namespace adbc {
     auto code_ = (EXPR);                                                       \
     if (code_ != ADBC_STATUS_OK) {                                             \
       std::string errmsg_ = ERROR.message ? ERROR.message : "(unknown error)"; \
-      AdbcErrorRelease(&ERROR);                                                \
+      if (ERROR.message) error.release(&error);                                \
       ASSERT_EQ(code_, ADBC_STATUS_OK) << errmsg_;                             \
     }                                                                          \
   } while (false)
@@ -49,7 +49,7 @@ namespace adbc {
   do {                                                                       \
     ASSERT_NE(ERROR.message, nullptr);                                       \
     std::string errmsg_ = ERROR.message ? ERROR.message : "(unknown error)"; \
-    AdbcErrorRelease(&ERROR);                                                \
+      if (ERROR.message) error.release(&error);                                \
     ASSERT_THAT(errmsg_, PATTERN) << errmsg_;                                \
   } while (false)