You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/08/26 18:33:02 UTC
[arrow-adbc] branch main updated: [C] Add nanoarrow-based libpq driver (#65)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 35a9a22 [C] Add nanoarrow-based libpq driver (#65)
35a9a22 is described below
commit 35a9a22a0fdb22f4f09650f31dd5bee183369bc1
Author: David Li <li...@gmail.com>
AuthorDate: Fri Aug 26 14:32:57 2022 -0400
[C] Add nanoarrow-based libpq driver (#65)
---
c/drivers/postgres/CMakeLists.txt | 73 ++++
c/drivers/postgres/README.md | 25 ++
c/drivers/postgres/connection.cc | 118 ++++++
c/drivers/postgres/connection.h | 50 +++
c/drivers/postgres/database.cc | 124 ++++++
c/drivers/postgres/database.h | 51 +++
c/drivers/postgres/postgres.cc | 479 ++++++++++++++++++++++
c/drivers/postgres/postgres_test.cc | 70 ++++
c/drivers/postgres/statement.cc | 792 ++++++++++++++++++++++++++++++++++++
c/drivers/postgres/statement.h | 113 +++++
c/drivers/postgres/type.cc | 80 ++++
c/drivers/postgres/type.h | 63 +++
c/drivers/postgres/util.h | 96 +++++
c/validation/adbc_validation.c | 90 ++--
14 files changed, 2195 insertions(+), 29 deletions(-)
diff --git a/c/drivers/postgres/CMakeLists.txt b/c/drivers/postgres/CMakeLists.txt
new file mode 100644
index 0000000..40b8f66
--- /dev/null
+++ b/c/drivers/postgres/CMakeLists.txt
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+cmake_minimum_required(VERSION 3.14)
+# Resolve the repository root (three directories up) so that the shared CMake
+# modules and the vendored nanoarrow sources can be referenced below.
+get_filename_component(REPOSITORY_ROOT "../../../" ABSOLUTE)
+list(APPEND CMAKE_MODULE_PATH "${REPOSITORY_ROOT}/c/cmake_modules/")
+include(AdbcDefines)
+include(BuildUtils)
+
+# C is listed explicitly because the source lists below contain .c files
+# (nanoarrow.c, adbc_validation.c) alongside the C++ driver sources.
+project(adbc_driver_postgres
+        VERSION "${ADBC_BASE_VERSION}"
+        LANGUAGES C CXX)
+include(CTest)
+
+# pkg_check_modules() only exists once PkgConfig has been found, so require it
+# here to get a clear error instead of an "unknown command" failure below.
+find_package(PkgConfig REQUIRED)
+pkg_check_modules(LIBPQ REQUIRED libpq)
+
+# Build the driver library; the created target names are returned in
+# ADBC_LIBRARIES.
+add_arrow_lib(adbc_driver_postgres
+              SOURCES
+              connection.cc
+              database.cc
+              postgres.cc
+              statement.cc
+              type.cc
+              ${REPOSITORY_ROOT}/c/vendor/nanoarrow/nanoarrow.c
+              OUTPUTS
+              ADBC_LIBRARIES
+              SHARED_LINK_LIBS
+              ${LIBPQ_LIBRARIES}
+              STATIC_LINK_LIBS
+              ${LIBPQ_STATIC_LIBRARIES})
+# Directory-scoped on purpose: both the library targets above and the test
+# target created by add_test_case() below need these include paths.
+include_directories(SYSTEM ${REPOSITORY_ROOT})
+include_directories(SYSTEM ${REPOSITORY_ROOT}/c/)
+include_directories(SYSTEM ${REPOSITORY_ROOT}/c/vendor/nanoarrow/)
+include_directories(SYSTEM ${LIBPQ_INCLUDE_DIRS})
+foreach(LIB_TARGET ${ADBC_LIBRARIES})
+  target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING)
+endforeach()
+
+# Quote the variable so an unset/empty value is compared as a string.
+if("${ADBC_TEST_LINKAGE}" STREQUAL "shared")
+  set(TEST_LINK_LIBS adbc_driver_postgres_shared)
+else()
+  set(TEST_LINK_LIBS adbc_driver_postgres_static)
+endif()
+
+if(ADBC_BUILD_TESTS)
+  add_test_case(driver_postgres_test
+                PREFIX
+                adbc
+                SOURCES
+                postgres_test.cc
+                ../../validation/adbc_validation.c
+                ${REPOSITORY_ROOT}/c/vendor/nanoarrow/nanoarrow.c
+                EXTRA_LINK_LIBS
+                ${TEST_LINK_LIBS})
+endif()
+
+validate_config()
+config_summary_message()
diff --git a/c/drivers/postgres/README.md b/c/drivers/postgres/README.md
new file mode 100644
index 0000000..8138536
--- /dev/null
+++ b/c/drivers/postgres/README.md
@@ -0,0 +1,25 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# libpq ADBC Driver
+
+With credit to 0x0L's [pgeon](https://github.com/0x0L/pgeon) for the overall approach.
+
+This implements an ADBC driver that wraps [libpq](https://www.postgresql.org/docs/14/libpq.html).
+This is still a work in progress.
diff --git a/c/drivers/postgres/connection.cc b/c/drivers/postgres/connection.cc
new file mode 100644
index 0000000..38cba57
--- /dev/null
+++ b/c/drivers/postgres/connection.cc
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "connection.h"
+
+#include <cstring>
+#include <memory>
+
+#include <adbc.h>
+
+#include "database.h"
+#include "util.h"
+
+namespace adbcpq {
+// Commit the current transaction and immediately open the next one.
+//
+// Only valid while autocommit is disabled; otherwise sets an error and returns
+// ADBC_STATUS_INVALID_STATE. Returns ADBC_STATUS_IO (with the libpq error
+// message) if the server rejects the command, ADBC_STATUS_OK on success.
+AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) {
+  if (autocommit_) {
+    SetError(error, "Cannot commit when autocommit is enabled");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  // SetOption() issues BEGIN when autocommit is turned off, so the connection
+  // is expected to be inside a transaction for as long as autocommit_ is
+  // false. A bare COMMIT would leave it outside of one, so start the next
+  // transaction in the same round trip.
+  PGresult* result = PQexec(conn_, "COMMIT; BEGIN TRANSACTION");
+  const bool succeeded = PQresultStatus(result) == PGRES_COMMAND_OK;
+  if (!succeeded) {
+    SetError(error, "Failed to commit: ", PQerrorMessage(conn_));
+  }
+  PQclear(result);
+  return succeeded ? ADBC_STATUS_OK : ADBC_STATUS_IO;
+}
+
+// Placeholder: looking up a table's schema is not implemented yet. All
+// arguments are ignored and ADBC_STATUS_NOT_IMPLEMENTED is always returned;
+// neither `schema` nor `error` is written.
+AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog,
+ const char* db_schema,
+ const char* table_name,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Attach this connection to `database` and open the underlying libpq
+// connection through it.
+//
+// Returns ADBC_STATUS_INVALID_ARGUMENT if `database` is null or carries no
+// private data; otherwise returns whatever PostgresDatabase::Connect returns.
+AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database,
+                                        struct AdbcError* error) {
+  const bool usable = database != nullptr && database->private_data != nullptr;
+  if (!usable) {
+    SetError(error, "Must provide an initialized AdbcDatabase");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  // private_data holds a heap-allocated shared_ptr; copy it so the database
+  // object is shared with this connection.
+  auto* shared_database =
+      reinterpret_cast<std::shared_ptr<PostgresDatabase>*>(database->private_data);
+  database_ = *shared_database;
+  type_mapping_ = database_->type_mapping();
+  return database_->Connect(&conn_, error);
+}
+
+// Close the libpq connection, if one is open, via the owning database.
+// Returns ADBC_STATUS_OK when there is nothing to close; otherwise returns
+// the status of PostgresDatabase::Disconnect.
+AdbcStatusCode PostgresConnection::Release(struct AdbcError* error) {
+  if (!conn_) {
+    return ADBC_STATUS_OK;
+  }
+  return database_->Disconnect(&conn_, error);
+}
+
+// Roll back the current transaction and immediately open the next one.
+//
+// Only valid while autocommit is disabled; otherwise sets an error and returns
+// ADBC_STATUS_INVALID_STATE. Returns ADBC_STATUS_IO (with the libpq error
+// message) if the server rejects the command, ADBC_STATUS_OK on success.
+AdbcStatusCode PostgresConnection::Rollback(struct AdbcError* error) {
+  if (autocommit_) {
+    SetError(error, "Cannot rollback when autocommit is enabled");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  // As with Commit(): keep the connection inside a transaction while
+  // autocommit_ is false by beginning a new one right after the rollback.
+  PGresult* result = PQexec(conn_, "ROLLBACK; BEGIN TRANSACTION");
+  const bool succeeded = PQresultStatus(result) == PGRES_COMMAND_OK;
+  if (!succeeded) {
+    SetError(error, "Failed to rollback: ", PQerrorMessage(conn_));
+  }
+  PQclear(result);
+  return succeeded ? ADBC_STATUS_OK : ADBC_STATUS_IO;
+}
+
+// Set a connection option. Only ADBC_CONNECTION_OPTION_AUTOCOMMIT is handled.
+//
+// Turning autocommit off issues BEGIN TRANSACTION; turning it back on issues
+// COMMIT. Setting the value it already has is a no-op.
+//
+// Returns:
+//   ADBC_STATUS_OK               on success
+//   ADBC_STATUS_INVALID_ARGUMENT if key/value is null or the value is not
+//                                one of the enabled/disabled constants
+//   ADBC_STATUS_NOT_IMPLEMENTED  for any other key
+//   ADBC_STATUS_IO               if the server rejects the command
+AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value,
+                                             struct AdbcError* error) {
+  // strcmp() on a null pointer is undefined behaviour; reject it up front.
+  if (!key) {
+    SetError(error, "Option key must not be NULL");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  if (std::strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) != 0) {
+    SetError(error, "Unknown option ", key);
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  if (!value) {
+    SetError(error, "Invalid value for option ", key, ": (NULL)");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  bool autocommit = true;
+  if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
+    autocommit = true;
+  } else if (std::strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
+    autocommit = false;
+  } else {
+    SetError(error, "Invalid value for option ", key, ": ", value);
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  if (autocommit == autocommit_) {
+    return ADBC_STATUS_OK;
+  }
+
+  // Enabling autocommit ends the open transaction; disabling it starts one.
+  const char* query = autocommit ? "COMMIT" : "BEGIN TRANSACTION";
+  PGresult* result = PQexec(conn_, query);
+  if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+    SetError(error, "Failed to update autocommit: ", PQerrorMessage(conn_));
+    PQclear(result);
+    return ADBC_STATUS_IO;
+  }
+  PQclear(result);
+  autocommit_ = autocommit;
+  return ADBC_STATUS_OK;
+}
+} // namespace adbcpq
diff --git a/c/drivers/postgres/connection.h b/c/drivers/postgres/connection.h
new file mode 100644
index 0000000..87b14ca
--- /dev/null
+++ b/c/drivers/postgres/connection.h
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Guard against multiple inclusion: this header defines a class, so including
+// it twice in one translation unit would otherwise be a redefinition error.
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include <adbc.h>
+#include <libpq-fe.h>
+
+#include "type.h"
+
+namespace adbcpq {
+class PostgresDatabase;
+
+// One libpq connection (PGconn) together with its autocommit state.
+//
+// The object starts unconnected (null PGconn, autocommit enabled); Init()
+// attaches it to a PostgresDatabase and opens the connection.
+class PostgresConnection {
+ public:
+  PostgresConnection() : database_(nullptr), conn_(nullptr), autocommit_(true) {}
+
+  // ADBC operations; each returns an AdbcStatusCode and may fill `error`.
+  AdbcStatusCode Commit(struct AdbcError* error);
+  AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema,
+                                const char* table_name, struct ArrowSchema* schema,
+                                struct AdbcError* error);
+  AdbcStatusCode Init(struct AdbcDatabase* database, struct AdbcError* error);
+  AdbcStatusCode Release(struct AdbcError* error);
+  AdbcStatusCode Rollback(struct AdbcError* error);
+  AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+
+  // Raw libpq handle; null until Init() has connected successfully.
+  PGconn* conn() const { return conn_; }
+  // Type mapping copied from the owning database during Init().
+  const std::shared_ptr<TypeMapping>& type_mapping() const { return type_mapping_; }
+
+ private:
+  std::shared_ptr<PostgresDatabase> database_;
+  std::shared_ptr<TypeMapping> type_mapping_;
+  PGconn* conn_;
+  bool autocommit_;
+};
+}  // namespace adbcpq
diff --git a/c/drivers/postgres/database.cc b/c/drivers/postgres/database.cc
new file mode 100644
index 0000000..0931d97
--- /dev/null
+++ b/c/drivers/postgres/database.cc
@@ -0,0 +1,124 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "database.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <string>
+
+#include <adbc.h>
+#include <libpq-fe.h>
+#include <nanoarrow.h>
+
+#include "util.h"
+
+namespace adbcpq {
+
+// A new database has no open connections and an empty type mapping table.
+PostgresDatabase::PostgresDatabase()
+    : open_connections_(0), type_mapping_(std::make_shared<TypeMapping>()) {}
+
+PostgresDatabase::~PostgresDatabase() = default;
+
+// Validate the connection parameters and fill the type mapping table.
+//
+// Opens a temporary connection, inserts (oid, typname, typreceive) for every
+// row of pg_catalog.pg_type into type_mapping_, then closes the connection.
+//
+// Returns the Connect() status if connecting fails. Otherwise returns
+// ADBC_STATUS_IO if the catalog query fails; a failing Disconnect() overrides
+// the status with its own.
+AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) {
+  // Connect to validate the parameters.
+  PGconn* conn = nullptr;
+  AdbcStatusCode final_status = Connect(&conn, error);
+  if (final_status != ADBC_STATUS_OK) {
+    return final_status;
+  }
+
+  // Build the type mapping table.
+  const std::string kTypeQuery = R"(
+SELECT
+ oid,
+ typname,
+ typreceive
+FROM
+ pg_catalog.pg_type
+)";
+
+  PGresult* result = PQexec(conn, kTypeQuery.c_str());
+  if (PQresultStatus(result) == PGRES_TUPLES_OK) {
+    const int num_rows = PQntuples(result);
+    for (int row = 0; row < num_rows; row++) {
+      // OIDs are unsigned 32-bit values. Parse them as unsigned so values
+      // above LONG_MAX are not saturated where `long` is only 32 bits wide.
+      const uint32_t oid = static_cast<uint32_t>(
+          std::strtoul(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10));
+      const char* typname = PQgetvalue(result, row, 1);
+      const char* typreceive = PQgetvalue(result, row, 2);
+
+      type_mapping_->Insert(oid, typname, typreceive);
+    }
+  } else {
+    SetError(error, "Failed to execute build type mapping table: ", PQerrorMessage(conn));
+    final_status = ADBC_STATUS_IO;
+  }
+  PQclear(result);
+
+  // Disconnect since Postgres connections can be heavy.
+  const AdbcStatusCode disconnect_status = Disconnect(&conn, error);
+  if (disconnect_status != ADBC_STATUS_OK) {
+    final_status = disconnect_status;
+  }
+  return final_status;
+}
+
+// Check that the database can be released: succeeds only when the open
+// connection counter is zero, otherwise reports the count and returns
+// ADBC_STATUS_INVALID_STATE.
+AdbcStatusCode PostgresDatabase::Release(struct AdbcError* error) {
+  if (open_connections_ == 0) {
+    return ADBC_STATUS_OK;
+  }
+  SetError(error, "Database released with ", open_connections_, " open connections");
+  return ADBC_STATUS_INVALID_STATE;
+}
+
+// Set a database option. Only "uri" (the libpq connection string/URI passed to
+// PQconnectdb by Connect()) is recognized.
+//
+// Returns ADBC_STATUS_INVALID_ARGUMENT for a null key or value,
+// ADBC_STATUS_NOT_IMPLEMENTED for an unknown key, ADBC_STATUS_OK otherwise.
+AdbcStatusCode PostgresDatabase::SetOption(const char* key, const char* value,
+                                           struct AdbcError* error) {
+  // Both strcmp() and std::string assignment are undefined for null pointers.
+  if (!key) {
+    SetError(error, "Option key must not be NULL");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  if (std::strcmp(key, "uri") != 0) {
+    SetError(error, "Unknown database option ", key);
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  if (!value) {
+    SetError(error, "Value for database option ", key, " must not be NULL");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  uri_ = value;
+  return ADBC_STATUS_OK;
+}
+
+// Open a new libpq connection using the configured URI.
+//
+// On success stores the handle in *conn, bumps the open connection counter and
+// returns ADBC_STATUS_OK. Returns ADBC_STATUS_INVALID_STATE (leaving *conn
+// untouched) when no URI has been set, and ADBC_STATUS_IO (with *conn set to
+// null) when the connection attempt fails.
+AdbcStatusCode PostgresDatabase::Connect(PGconn** conn, struct AdbcError* error) {
+  if (uri_.empty()) {
+    SetError(error, "Must set database option 'uri' before creating a connection");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  PGconn* candidate = PQconnectdb(uri_.c_str());
+  if (PQstatus(candidate) == CONNECTION_OK) {
+    *conn = candidate;
+    open_connections_++;
+    return ADBC_STATUS_OK;
+  }
+
+  // Read the error message before the handle is freed.
+  SetError(error, "Failed to connect: ", PQerrorMessage(candidate));
+  PQfinish(candidate);
+  *conn = nullptr;
+  return ADBC_STATUS_IO;
+}
+
+// Close *conn, null it out, and decrement the open connection counter.
+// Returns ADBC_STATUS_INTERNAL if the counter drops below zero, otherwise
+// ADBC_STATUS_OK.
+AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* error) {
+  PQfinish(*conn);
+  *conn = nullptr;
+
+  open_connections_ -= 1;
+  if (open_connections_ >= 0) {
+    return ADBC_STATUS_OK;
+  }
+  SetError(error, "Open connection count underflowed");
+  return ADBC_STATUS_INTERNAL;
+}
+} // namespace adbcpq
diff --git a/c/drivers/postgres/database.h b/c/drivers/postgres/database.h
new file mode 100644
index 0000000..4db8e34
--- /dev/null
+++ b/c/drivers/postgres/database.h
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Guard against multiple inclusion: this header defines a class, so including
+// it twice in one translation unit would otherwise be a redefinition error.
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+#include <adbc.h>
+#include <libpq-fe.h>
+
+#include "type.h"
+
+namespace adbcpq {
+// Holds the connection URI and the type mapping table, and hands out libpq
+// connections while counting how many are open.
+class PostgresDatabase {
+ public:
+  PostgresDatabase();
+  ~PostgresDatabase();
+
+  // Public ADBC API
+
+  AdbcStatusCode Init(struct AdbcError* error);
+  AdbcStatusCode Release(struct AdbcError* error);
+  AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+
+  // Internal implementation
+
+  // Open a connection into *conn / close the connection in *conn. Both keep
+  // the open connection counter up to date.
+  AdbcStatusCode Connect(PGconn** conn, struct AdbcError* error);
+  AdbcStatusCode Disconnect(PGconn** conn, struct AdbcError* error);
+
+  const std::shared_ptr<TypeMapping>& type_mapping() const { return type_mapping_; }
+
+ private:
+  // Number of connections handed out by Connect() and not yet passed to
+  // Disconnect().
+  int32_t open_connections_;
+  // Connection string passed to PQconnectdb(); set via the "uri" option.
+  std::string uri_;
+  std::shared_ptr<TypeMapping> type_mapping_;
+};
+}  // namespace adbcpq
diff --git a/c/drivers/postgres/postgres.cc b/c/drivers/postgres/postgres.cc
new file mode 100644
index 0000000..afa56fc
--- /dev/null
+++ b/c/drivers/postgres/postgres.cc
@@ -0,0 +1,479 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// A libpq-based Postgres driver for ADBC.
+
+#include <cstring>
+#include <memory>
+
+#include <adbc.h>
+
+#include "connection.h"
+#include "database.h"
+#include "statement.h"
+
+using adbcpq::PostgresConnection;
+using adbcpq::PostgresDatabase;
+using adbcpq::PostgresStatement;
+
+// ---------------------------------------------------------------------
+// ADBC interface implementation - as private functions so that these
+// don't get replaced by the dynamic linker. If we implemented these
+// under the Adbc* names, then in DriverInit, the linker may resolve
+// functions to the address of the functions provided by the driver
+// manager instead of our functions.
+//
+// We could also:
+// - Play games with RTLD_DEEPBIND - but this doesn't work with ASan
+// - Use __attribute__((visibility("protected"))) - but this is
+// apparently poorly supported by some linkers
+// - Play with -Bsymbolic(-functions) - but this has other
+// consequences and complicates the build setup
+//
+// So in the end some manual effort here was chosen.
+
+// ---------------------------------------------------------------------
+// AdbcDatabase
+
+namespace {
+// Each function below recovers the std::shared_ptr<PostgresDatabase> stored in
+// AdbcDatabase::private_data and forwards to the matching member function.
+// A null `database`, or one in the wrong state, yields
+// ADBC_STATUS_INVALID_STATE.
+
+// Forward to PostgresDatabase::Init.
+AdbcStatusCode PostgresDatabaseInit(struct AdbcDatabase* database,
+                                    struct AdbcError* error) {
+  if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE;
+  auto ptr = reinterpret_cast<std::shared_ptr<PostgresDatabase>*>(database->private_data);
+  return (*ptr)->Init(error);
+}
+
+// Allocate a PostgresDatabase and store it in `database`. Fails if `database`
+// already carries private data.
+AdbcStatusCode PostgresDatabaseNew(struct AdbcDatabase* database,
+                                   struct AdbcError* error) {
+  if (!database || database->private_data) return ADBC_STATUS_INVALID_STATE;
+  auto impl = std::make_shared<PostgresDatabase>();
+  database->private_data = new std::shared_ptr<PostgresDatabase>(impl);
+  return ADBC_STATUS_OK;
+}
+
+// Forward to PostgresDatabase::Release, then free the stored handle and clear
+// private_data regardless of the returned status.
+AdbcStatusCode PostgresDatabaseRelease(struct AdbcDatabase* database,
+                                       struct AdbcError* error) {
+  // Check `database` itself as well, matching the other functions here.
+  if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE;
+  auto ptr = reinterpret_cast<std::shared_ptr<PostgresDatabase>*>(database->private_data);
+  AdbcStatusCode status = (*ptr)->Release(error);
+  delete ptr;
+  database->private_data = nullptr;
+  return status;
+}
+
+// Forward to PostgresDatabase::SetOption.
+AdbcStatusCode PostgresDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+                                         const char* value, struct AdbcError* error) {
+  if (!database || !database->private_data) return ADBC_STATUS_INVALID_STATE;
+  auto ptr = reinterpret_cast<std::shared_ptr<PostgresDatabase>*>(database->private_data);
+  return (*ptr)->SetOption(key, value, error);
+}
+}  // namespace
+
+// Public AdbcDatabase* entry points. Each one only forwards its arguments to
+// the file-local PostgresDatabase* function of the same name and returns that
+// function's status unchanged.
+
+// Forwards to PostgresDatabaseInit.
+AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) {
+ return PostgresDatabaseInit(database, error);
+}
+
+// Forwards to PostgresDatabaseNew.
+AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) {
+ return PostgresDatabaseNew(database, error);
+}
+
+// Forwards to PostgresDatabaseRelease.
+AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ return PostgresDatabaseRelease(database, error);
+}
+
+// Forwards to PostgresDatabaseSetOption.
+AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key,
+ const char* value, struct AdbcError* error) {
+ return PostgresDatabaseSetOption(database, key, value, error);
+}
+
+// ---------------------------------------------------------------------
+// AdbcConnection
+
+namespace {
+// Handle type stored behind AdbcConnection::private_data.
+using ConnectionHandle = std::shared_ptr<PostgresConnection>;
+
+// Reinterpret private_data as the connection handle; yields nullptr when
+// private_data is unset.
+ConnectionHandle* ConnectionHandleOf(struct AdbcConnection* connection) {
+  return reinterpret_cast<ConnectionHandle*>(connection->private_data);
+}
+
+// Forward to PostgresConnection::Commit.
+AdbcStatusCode PostgresConnectionCommit(struct AdbcConnection* connection,
+                                        struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Commit(error);
+}
+
+// Not implemented yet for this driver.
+AdbcStatusCode PostgresConnectionGetInfo(struct AdbcConnection* connection,
+                                         uint32_t* info_codes, size_t info_codes_length,
+                                         struct AdbcStatement* statement,
+                                         struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Not implemented yet for this driver.
+AdbcStatusCode PostgresConnectionGetObjects(
+    struct AdbcConnection* connection, int depth, const char* catalog,
+    const char* db_schema, const char* table_name, const char** table_types,
+    const char* column_name, struct AdbcStatement* statement, struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Forward to PostgresConnection::GetTableSchema.
+AdbcStatusCode PostgresConnectionGetTableSchema(
+    struct AdbcConnection* connection, const char* catalog, const char* db_schema,
+    const char* table_name, struct ArrowSchema* schema, struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->GetTableSchema(catalog, db_schema, table_name, schema, error);
+}
+
+// Not implemented yet for this driver.
+AdbcStatusCode PostgresConnectionGetTableTypes(struct AdbcConnection* connection,
+                                               struct AdbcStatement* statement,
+                                               struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Forward to PostgresConnection::Init.
+AdbcStatusCode PostgresConnectionInit(struct AdbcConnection* connection,
+                                      struct AdbcDatabase* database,
+                                      struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Init(database, error);
+}
+
+// Allocate a PostgresConnection and store its handle in `connection`.
+AdbcStatusCode PostgresConnectionNew(struct AdbcConnection* connection,
+                                     struct AdbcError* error) {
+  ConnectionHandle created = std::make_shared<PostgresConnection>();
+  connection->private_data = new ConnectionHandle(created);
+  return ADBC_STATUS_OK;
+}
+
+// Forward to PostgresConnection::SetOption.
+AdbcStatusCode PostgresConnectionSetOption(struct AdbcConnection* connection,
+                                           const char* key, const char* value,
+                                           struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->SetOption(key, value, error);
+}
+
+// Forward to PostgresConnection::Release, then free the handle and clear
+// private_data regardless of the returned status.
+AdbcStatusCode PostgresConnectionRelease(struct AdbcConnection* connection,
+                                         struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  const AdbcStatusCode release_status = (*handle)->Release(error);
+  delete handle;
+  connection->private_data = nullptr;
+  return release_status;
+}
+
+// Forward to PostgresConnection::Rollback.
+AdbcStatusCode PostgresConnectionRollback(struct AdbcConnection* connection,
+                                          struct AdbcError* error) {
+  ConnectionHandle* handle = ConnectionHandleOf(connection);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Rollback(error);
+}
+}  // namespace
+// Public AdbcConnection* entry points. Each one only forwards its arguments to
+// the file-local PostgresConnection* function of the same name and returns
+// that function's status unchanged.
+
+// Forwards to PostgresConnectionCommit.
+AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return PostgresConnectionCommit(connection, error);
+}
+
+// Forwards to PostgresConnectionGetInfo (currently a not-implemented stub).
+AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection,
+ uint32_t* info_codes, size_t info_codes_length,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresConnectionGetInfo(connection, info_codes, info_codes_length, statement,
+ error);
+}
+
+// Forwards to PostgresConnectionGetObjects (currently a not-implemented stub).
+AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth,
+ const char* catalog, const char* db_schema,
+ const char* table_name, const char** table_types,
+ const char* column_name,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresConnectionGetObjects(connection, depth, catalog, db_schema, table_name,
+ table_types, column_name, statement, error);
+}
+
+// Forwards to PostgresConnectionGetTableSchema.
+AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
+ const char* catalog, const char* db_schema,
+ const char* table_name,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return PostgresConnectionGetTableSchema(connection, catalog, db_schema, table_name,
+ schema, error);
+}
+
+// Forwards to PostgresConnectionGetTableTypes (currently a not-implemented
+// stub).
+AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresConnectionGetTableTypes(connection, statement, error);
+}
+
+// Forwards to PostgresConnectionInit.
+AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
+ struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ return PostgresConnectionInit(connection, database, error);
+}
+
+// Forwards to PostgresConnectionNew.
+AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return PostgresConnectionNew(connection, error);
+}
+
+// Forwards to PostgresConnectionRelease.
+AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return PostgresConnectionRelease(connection, error);
+}
+
+// Forwards to PostgresConnectionRollback.
+AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection,
+ struct AdbcError* error) {
+ return PostgresConnectionRollback(connection, error);
+}
+
+// Forwards to PostgresConnectionSetOption.
+AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key,
+ const char* value, struct AdbcError* error) {
+ return PostgresConnectionSetOption(connection, key, value, error);
+}
+
+// ---------------------------------------------------------------------
+// AdbcStatement
+
+namespace {
+// Handle type stored behind AdbcStatement::private_data.
+using StatementHandle = std::shared_ptr<PostgresStatement>;
+
+// Reinterpret private_data as the statement handle; yields nullptr when
+// private_data is unset.
+StatementHandle* StatementHandleOf(struct AdbcStatement* statement) {
+  return reinterpret_cast<StatementHandle*>(statement->private_data);
+}
+
+// Forward to PostgresStatement::Bind (array + schema overload).
+AdbcStatusCode PostgresStatementBind(struct AdbcStatement* statement,
+                                     struct ArrowArray* values,
+                                     struct ArrowSchema* schema,
+                                     struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Bind(values, schema, error);
+}
+
+// Forward to PostgresStatement::Bind (stream overload).
+AdbcStatusCode PostgresStatementBindStream(struct AdbcStatement* statement,
+                                           struct ArrowArrayStream* stream,
+                                           struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Bind(stream, error);
+}
+
+// Forward to PostgresStatement::ExecuteQuery.
+AdbcStatusCode PostgresStatementExecuteQuery(struct AdbcStatement* statement,
+                                             struct ArrowArrayStream* output,
+                                             int64_t* rows_affected,
+                                             struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->ExecuteQuery(output, rows_affected, error);
+}
+
+// Forward to PostgresStatement::ExecuteUpdate.
+AdbcStatusCode PostgresStatementExecuteUpdate(struct AdbcStatement* statement,
+                                              int64_t* rows_affected,
+                                              struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->ExecuteUpdate(rows_affected, error);
+}
+
+// Partitioned results are not implemented by this driver.
+AdbcStatusCode PostgresStatementGetPartitionDesc(struct AdbcStatement* statement,
+                                                 uint8_t* partition_desc,
+                                                 struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Partitioned results are not implemented by this driver.
+AdbcStatusCode PostgresStatementGetPartitionDescSize(struct AdbcStatement* statement,
+                                                     size_t* length,
+                                                     struct AdbcError* error) {
+  return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+// Forward to PostgresStatement::GetParameterSchema.
+AdbcStatusCode PostgresStatementGetParameterSchema(struct AdbcStatement* statement,
+                                                   struct ArrowSchema* schema,
+                                                   struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->GetParameterSchema(schema, error);
+}
+
+// Allocate a PostgresStatement, store its handle in `statement`, then let the
+// new object bind itself to `connection` via PostgresStatement::New.
+AdbcStatusCode PostgresStatementNew(struct AdbcConnection* connection,
+                                    struct AdbcStatement* statement,
+                                    struct AdbcError* error) {
+  StatementHandle created = std::make_shared<PostgresStatement>();
+  statement->private_data = new StatementHandle(created);
+  return created->New(connection, error);
+}
+
+// Forward to PostgresStatement::Prepare.
+AdbcStatusCode PostgresStatementPrepare(struct AdbcStatement* statement,
+                                        struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->Prepare(error);
+}
+
+// Forward to PostgresStatement::Release, then free the handle and clear
+// private_data regardless of the returned status.
+AdbcStatusCode PostgresStatementRelease(struct AdbcStatement* statement,
+                                        struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  const AdbcStatusCode release_status = (*handle)->Release(error);
+  delete handle;
+  statement->private_data = nullptr;
+  return release_status;
+}
+
+// Forward to PostgresStatement::SetOption.
+AdbcStatusCode PostgresStatementSetOption(struct AdbcStatement* statement,
+                                          const char* key, const char* value,
+                                          struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->SetOption(key, value, error);
+}
+
+// Forward to PostgresStatement::SetSqlQuery.
+AdbcStatusCode PostgresStatementSetSqlQuery(struct AdbcStatement* statement,
+                                            const char* query, struct AdbcError* error) {
+  StatementHandle* handle = StatementHandleOf(statement);
+  if (handle == nullptr) return ADBC_STATUS_INVALID_STATE;
+  return (*handle)->SetSqlQuery(query, error);
+}
+}  // namespace
+
+// Public AdbcStatement* entry points. Each one only forwards its arguments to
+// the file-local PostgresStatement* function of the same name and returns that
+// function's status unchanged.
+
+// Forwards to PostgresStatementBind.
+AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement,
+ struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return PostgresStatementBind(statement, values, schema, error);
+}
+
+// Forwards to PostgresStatementBindStream.
+AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement,
+ struct ArrowArrayStream* stream,
+ struct AdbcError* error) {
+ return PostgresStatementBindStream(statement, stream, error);
+}
+
+// Forwards to PostgresStatementExecuteQuery.
+AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement,
+ struct ArrowArrayStream* output,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ return PostgresStatementExecuteQuery(statement, output, rows_affected, error);
+}
+
+// Forwards to PostgresStatementExecuteUpdate.
+AdbcStatusCode AdbcStatementExecuteUpdate(struct AdbcStatement* statement,
+ int64_t* rows_affected,
+ struct AdbcError* error) {
+ return PostgresStatementExecuteUpdate(statement, rows_affected, error);
+}
+
+// Forwards to PostgresStatementGetPartitionDesc (a not-implemented stub).
+AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement,
+ uint8_t* partition_desc,
+ struct AdbcError* error) {
+ return PostgresStatementGetPartitionDesc(statement, partition_desc, error);
+}
+
+// Forwards to PostgresStatementGetPartitionDescSize (a not-implemented stub).
+AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement,
+ size_t* length,
+ struct AdbcError* error) {
+ return PostgresStatementGetPartitionDescSize(statement, length, error);
+}
+
+// Forwards to PostgresStatementGetParameterSchema.
+AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return PostgresStatementGetParameterSchema(statement, schema, error);
+}
+
+// Forwards to PostgresStatementNew.
+AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresStatementNew(connection, statement, error);
+}
+
+// Forwards to PostgresStatementPrepare.
+AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresStatementPrepare(statement, error);
+}
+
+// Forwards to PostgresStatementRelease.
+AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ return PostgresStatementRelease(statement, error);
+}
+
+// Forwards to PostgresStatementSetOption.
+AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key,
+ const char* value, struct AdbcError* error) {
+ return PostgresStatementSetOption(statement, key, value, error);
+}
+
+// Forwards to PostgresStatementSetSqlQuery.
+AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement,
+ const char* query, struct AdbcError* error) {
+ return PostgresStatementSetSqlQuery(statement, query, error);
+}
+
+// Exported driver entry point: fills `driver` with this driver's function
+// table.
+//
+// Returns ADBC_STATUS_NOT_IMPLEMENTED without touching `driver` when `count`
+// is smaller than ADBC_VERSION_0_0_1. Otherwise zeroes `*driver`, assigns the
+// file-local Postgres* functions, stores ADBC_VERSION_0_0_1 in `*initialized`
+// and returns ADBC_STATUS_OK.
+//
+// The ConnectionGetInfo/GetObjects/GetTableTypes slots are deliberately left
+// commented out and the partition-descriptor functions are not registered, so
+// those entries stay null after the memset.
+// NOTE(review): `driver` and `initialized` are dereferenced without a null
+// check - presumably the caller guarantees they are valid; confirm.
+extern "C" {
+ADBC_EXPORT
+AdbcStatusCode AdbcPostgresDriverInit(size_t count, struct AdbcDriver* driver,
+ size_t* initialized, struct AdbcError* error) {
+ if (count < ADBC_VERSION_0_0_1) return ADBC_STATUS_NOT_IMPLEMENTED;
+
+ // Start from an all-null table so unassigned slots are well defined.
+ std::memset(driver, 0, sizeof(*driver));
+ driver->DatabaseInit = PostgresDatabaseInit;
+ driver->DatabaseNew = PostgresDatabaseNew;
+ driver->DatabaseRelease = PostgresDatabaseRelease;
+ driver->DatabaseSetOption = PostgresDatabaseSetOption;
+
+ driver->ConnectionCommit = PostgresConnectionCommit;
+ // driver->ConnectionGetInfo = PostgresConnectionGetInfo;
+ // driver->ConnectionGetObjects = PostgresConnectionGetObjects;
+ driver->ConnectionGetTableSchema = PostgresConnectionGetTableSchema;
+ // driver->ConnectionGetTableTypes = PostgresConnectionGetTableTypes;
+ driver->ConnectionInit = PostgresConnectionInit;
+ driver->ConnectionNew = PostgresConnectionNew;
+ driver->ConnectionRelease = PostgresConnectionRelease;
+ driver->ConnectionRollback = PostgresConnectionRollback;
+ driver->ConnectionSetOption = PostgresConnectionSetOption;
+
+ driver->StatementBind = PostgresStatementBind;
+ driver->StatementBindStream = PostgresStatementBindStream;
+ driver->StatementExecuteQuery = PostgresStatementExecuteQuery;
+ driver->StatementExecuteUpdate = PostgresStatementExecuteUpdate;
+ driver->StatementGetParameterSchema = PostgresStatementGetParameterSchema;
+ driver->StatementNew = PostgresStatementNew;
+ driver->StatementPrepare = PostgresStatementPrepare;
+ driver->StatementRelease = PostgresStatementRelease;
+ driver->StatementSetOption = PostgresStatementSetOption;
+ driver->StatementSetSqlQuery = PostgresStatementSetSqlQuery;
+ *initialized = ADBC_VERSION_0_0_1;
+ return ADBC_STATUS_OK;
+}
+}
diff --git a/c/drivers/postgres/postgres_test.cc b/c/drivers/postgres/postgres_test.cc
new file mode 100644
index 0000000..92e0d2e
--- /dev/null
+++ b/c/drivers/postgres/postgres_test.cc
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdlib>
+#include <cstring>
+
+#include <adbc.h>
+#include <gtest/gtest.h>
+
+#include "validation/adbc_validation.h"
+
+// Test fixture that runs the shared ADBC validation helpers against the
+// Postgres driver.  Each test gets a fresh, zeroed validation context.
+class Postgres : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Reset the context, then register the callback used to configure the database.
+ std::memset(&ctx, 0, sizeof(ctx));
+ ctx.setup_database = SetupDatabase;
+ }
+
+ void TearDown() override {
+ // The validation helpers record their outcome in the context rather than
+ // through gtest directly (presumably; see adbc_validation.h), so check the
+ // counters here: nothing failed and every counted check passed.
+ ASSERT_EQ(ctx.failed, 0);
+ ASSERT_EQ(ctx.total, ctx.passed);
+ }
+
+ struct AdbcValidateTestContext ctx;
+
+ // Point the database at the server named by the ADBC_POSTGRES_TEST_URI
+ // environment variable.  Records a gtest failure and returns
+ // ADBC_STATUS_INVALID_ARGUMENT when the variable is not set.
+ static AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,
+ struct AdbcError* error) {
+ const char* uri = std::getenv("ADBC_POSTGRES_TEST_URI");
+ if (!uri) {
+ ADD_FAILURE() << "Must provide env var ADBC_POSTGRES_TEST_URI";
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ return AdbcDatabaseSetOption(database, "uri", uri, error);
+ }
+};
+
+// Each test delegates to one helper of the shared validation suite; the
+// fixture's TearDown() checks the counters the helper left in `ctx`.
+TEST_F(Postgres, ValidationDatabase) { AdbcValidateDatabaseNewRelease(&ctx); }
+
+TEST_F(Postgres, ValidationConnectionNewRelease) {
+ AdbcValidateConnectionNewRelease(&ctx);
+}
+
+TEST_F(Postgres, ValidationConnectionAutocommit) {
+ AdbcValidateConnectionAutocommit(&ctx);
+}
+
+TEST_F(Postgres, ValidationStatementNewRelease) { AdbcValidateStatementNewRelease(&ctx); }
+
+TEST_F(Postgres, ValidationStatementSqlExecute) { AdbcValidateStatementSqlExecute(&ctx); }
+
+TEST_F(Postgres, ValidationStatementSqlIngest) { AdbcValidateStatementSqlIngest(&ctx); }
+
+TEST_F(Postgres, ValidationStatementSqlPrepare) {
+ // Skipped until the driver implements Prepare; the call below is kept so the
+ // test only needs the GTEST_SKIP() line removed to be enabled.
+ GTEST_SKIP() << "Not yet implemented";
+ AdbcValidateStatementSqlPrepare(&ctx);
+}
diff --git a/c/drivers/postgres/statement.cc b/c/drivers/postgres/statement.cc
new file mode 100644
index 0000000..10e765f
--- /dev/null
+++ b/c/drivers/postgres/statement.cc
@@ -0,0 +1,792 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "statement.h"
+
+#include <endian.h>
+#include <netinet/in.h>
+#include <array>
+#include <cerrno>
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include <adbc.h>
+#include <libpq-fe.h>
+#include <nanoarrow.h>
+
+#include "connection.h"
+#include "util.h"
+
+namespace adbcpq {
+
+namespace {
+/// The header that comes at the start of a binary COPY stream.
+/// All 11 bytes are significant, including the trailing NUL byte.
+constexpr std::array<char, 11> kPgCopyBinarySignature = {
+ 'P', 'G', 'C', 'O', 'P', 'Y', '\n', '\377', '\r', '\n', '\0'};
+/// The flag indicating to Postgres that we want binary-format values.
+/// Passed as the result format to PQexecParams and as each parameter's format.
+constexpr int kPgBinaryFormat = 1;
+
+/// ArrowArrayStream implementation that yields exactly one array, so that
+/// binding a single array and binding a stream share one code path.
+struct OneValueStream {
+  struct ArrowSchema schema;
+  struct ArrowArray array;
+
+  /// get_schema callback: deep-copy the held schema into `out`.
+  static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) {
+    auto* impl = static_cast<OneValueStream*>(self->private_data);
+    return ArrowSchemaDeepCopy(&impl->schema, out);
+  }
+
+  /// get_next callback: move the held array into `out`.  Once it has been
+  /// handed out, later calls yield an array whose `release` is null.
+  static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) {
+    auto* impl = static_cast<OneValueStream*>(self->private_data);
+    struct ArrowArray* held = &impl->array;
+    *out = *held;
+    held->release = nullptr;
+    return 0;
+  }
+
+  /// get_last_error callback: this stream never records an error message.
+  static const char* GetLastError(struct ArrowArrayStream* self) { return nullptr; }
+
+  /// release callback: release whatever is still held, then free the stream.
+  static void Release(struct ArrowArrayStream* self) {
+    auto* impl = static_cast<OneValueStream*>(self->private_data);
+    struct ArrowSchema* held_schema = &impl->schema;
+    struct ArrowArray* held_array = &impl->array;
+    if (held_schema->release != nullptr) {
+      held_schema->release(held_schema);
+      held_schema->release = nullptr;
+    }
+    if (held_array->release != nullptr) {
+      held_array->release(held_array);
+      held_array->release = nullptr;
+    }
+    delete impl;
+    self->release = nullptr;
+  }
+};
+
+/// Helpers to manage resources with RAII
+
+/// Default policy: invoke the object's own `release` callback if one is set.
+template <typename T>
+struct Releaser {
+  static void Release(T* value) {
+    if (value->release == nullptr) return;
+    value->release(value);
+  }
+};
+
+/// ArrowArrayView has no `release` callback; reset it unless it was never
+/// initialized (storage type still UNINITIALIZED).
+template <>
+struct Releaser<struct ArrowArrayView> {
+  static void Release(struct ArrowArrayView* value) {
+    if (value->storage_type == NANOARROW_TYPE_UNINITIALIZED) return;
+    ArrowArrayViewReset(value);
+  }
+};
+
+/// Owns a C resource by value and releases it (via Releaser<Resource>) when
+/// the handle goes out of scope.  The resource starts out zero-filled.
+///
+/// Copying is disabled: a copy would hold the same resource and release it a
+/// second time on destruction.
+template <typename Resource>
+struct Handle {
+  Resource value;
+
+  Handle() { std::memset(&value, 0, sizeof(value)); }
+  Handle(const Handle&) = delete;
+  Handle& operator=(const Handle&) = delete;
+
+  ~Handle() { Releaser<Resource>::Release(&value); }
+
+  Resource* operator->() { return &value; }
+};
+
+/// Build an Arrow struct schema describing the columns of a Postgres result.
+///
+/// Each column's type OID is looked up in `type_mapping`; columns whose OID is
+/// unknown, or whose Postgres type has no Arrow mapping here yet, produce
+/// ADBC_STATUS_NOT_IMPLEMENTED with a message in `error`.
+AdbcStatusCode InferSchema(const TypeMapping& type_mapping, PGresult* result,
+                           struct ArrowSchema* out, struct AdbcError* error) {
+  const int num_fields = PQnfields(result);
+  CHECK_NA_ADBC(ArrowSchemaInit(out, NANOARROW_TYPE_STRUCT), error);
+  CHECK_NA_ADBC(ArrowSchemaAllocateChildren(out, num_fields), error);
+
+  for (int col = 0; col < num_fields; ++col) {
+    const Oid pg_type = PQftype(result, col);
+    const char* column_name = PQfname(result, col);
+
+    const auto entry = type_mapping.type_mapping.find(pg_type);
+    if (entry == type_mapping.type_mapping.end()) {
+      SetError(error, "Column #", col + 1, " (\"", column_name,
+               "\") has unknown type code ", pg_type);
+      return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+
+    // TODO: this mapping will eventually have to become dynamic,
+    // because of complex types like arrays/records
+    ArrowType field_type = NANOARROW_TYPE_NA;
+    switch (entry->second) {
+      case PgType::kBool:
+        field_type = NANOARROW_TYPE_BOOL;
+        break;
+      case PgType::kInt2:
+        field_type = NANOARROW_TYPE_INT16;
+        break;
+      case PgType::kInt4:
+        field_type = NANOARROW_TYPE_INT32;
+        break;
+      case PgType::kInt8:
+        field_type = NANOARROW_TYPE_INT64;
+        break;
+      default:
+        SetError(error, "Column #", col + 1, " (\"", column_name,
+                 "\") has unimplemented type code ", pg_type);
+        return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+
+    struct ArrowSchema* child = out->children[col];
+    CHECK_NA_ADBC(ArrowSchemaInit(child, field_type), error);
+    CHECK_NA_ADBC(ArrowSchemaSetName(child, column_name), error);
+  }
+  return ADBC_STATUS_OK;
+}
+
+/// Read a big-endian (network order) unsigned 32-bit integer from `buf`.
+/// `buf` need not be aligned; at least 4 bytes must be readable.
+uint32_t LoadNetworkUInt32(const char* buf) {
+  uint32_t v = 0;
+  std::memcpy(&v, buf, sizeof(uint32_t));
+  return ntohl(v);
+}
+
+/// Read a big-endian (network order) unsigned 64-bit integer from `buf`.
+/// `buf` need not be aligned; at least 8 bytes must be readable.
+uint64_t LoadNetworkUInt64(const char* buf) {
+  uint64_t v = 0;
+  std::memcpy(&v, buf, sizeof(uint64_t));
+  return be64toh(v);
+}
+
+/// Signed variant of LoadNetworkUInt32 (same bits, reinterpreted).
+int32_t LoadNetworkInt32(const char* buf) {
+  return static_cast<int32_t>(LoadNetworkUInt32(buf));
+}
+
+/// Signed variant of LoadNetworkUInt64 (same bits, reinterpreted).
+int64_t LoadNetworkInt64(const char* buf) {
+  return static_cast<int64_t>(LoadNetworkUInt64(buf));
+}
+
+/// Convert a host-order 64-bit integer to big-endian (network order) bytes.
+uint64_t ToNetworkInt64(int64_t v) { return htobe64(static_cast<uint64_t>(v)); }
+} // namespace
+
+// Copy the schema of the pending result set into `out`.
+// Returns EINVAL (and records a message retrievable via last_error()) when
+// there is no pending result; otherwise returns 0 once the deep copy succeeded.
+int TupleReader::GetSchema(struct ArrowSchema* out) {
+ if (!result_) {
+ last_error_ = "[libpq] Result set was already consumed or freed";
+ return EINVAL;
+ }
+
+ // Zero the output first, then deep-copy so the caller owns an independent schema.
+ std::memset(out, 0, sizeof(*out));
+ CHECK_NA(ArrowSchemaDeepCopy(&schema_, out));
+ return 0;
+}
+
+/// Read the whole COPY (FORMAT binary) response from the connection and build
+/// a single Arrow struct array from it.
+///
+/// The first call after a query consumes all rows; once the result has been
+/// consumed, later calls signal end-of-stream by setting `out->release` to
+/// null and returning 0.  On failure a nonzero errno-style code is returned,
+/// a message is available through last_error(), and `out` is released so the
+/// caller does not leak a half-built array.
+int TupleReader::GetNext(struct ArrowArray* out) {
+  if (!result_) {
+    out->release = nullptr;
+    return 0;
+  }
+
+  // Clear the result, since the data is actually read from the connection
+  PQclear(result_);
+  result_ = nullptr;
+  last_error_.clear();
+
+  struct ArrowError error;
+  int na_res = ArrowArrayInitFromSchema(out, &schema_, &error);
+  if (na_res != 0) {
+    last_error_ = StringBuilder("[libpq] Failed to init output array: ", na_res,
+                                std::strerror(na_res), ": ", error.message);
+    if (out->release) out->release(out);
+    return na_res;
+  }
+
+  std::vector<ArrowSchemaView> fields(schema_.n_children);
+  for (int col = 0; col < schema_.n_children; col++) {
+    na_res = ArrowSchemaViewInit(&fields[col], schema_.children[col], &error);
+    if (na_res != 0) {
+      last_error_ = StringBuilder("[libpq] Failed to init schema view: ", na_res,
+                                  std::strerror(na_res), ": ", error.message);
+      if (out->release) out->release(out);
+      return na_res;
+    }
+
+    struct ArrowBitmap validity_bitmap;
+    ArrowBitmapInit(&validity_bitmap);
+    ArrowArraySetValidityBitmap(out->children[col], &validity_bitmap);
+  }
+
+  // TODO: we need to always PQgetResult
+  // NOTE(review): the early returns in this function leave the connection in
+  // COPY OUT state; draining it is not handled here.
+
+  char* buf = nullptr;
+  int buf_size = 0;
+
+  // Get the header
+  {
+    // Signed on purpose: PQgetCopyData returns -1/-2 when no row is available,
+    // and comparing those against an unsigned length would wrap around.
+    constexpr int kPqHeaderLength = static_cast<int>(
+        kPgCopyBinarySignature.size() + sizeof(uint32_t) + sizeof(uint32_t));
+    // https://www.postgresql.org/docs/14/sql-copy.html#id-1.9.3.55.9.4.5
+    const int size = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
+
+    std::string header_error;
+    uint32_t extension_length = 0;
+    if (size < 0) {
+      header_error = StringBuilder("no COPY data available (", size, "): ",
+                                   PQerrorMessage(conn_));
+    } else if (size < kPqHeaderLength) {
+      header_error = StringBuilder("expected at least ", kPqHeaderLength,
+                                   " bytes but got ", size);
+    } else if (std::memcmp(pgbuf_, kPgCopyBinarySignature.data(),
+                           kPgCopyBinarySignature.size()) != 0) {
+      // memcmp rather than strcmp: the signature is binary data and its
+      // trailing NUL byte is part of what must match.
+      header_error = "invalid signature";
+    } else if (LoadNetworkUInt32(pgbuf_ + kPgCopyBinarySignature.size()) != 0) {
+      header_error = "unsupported flags";
+    } else {
+      // XXX: is this signed or unsigned? not stated by the docs
+      extension_length = LoadNetworkUInt32(pgbuf_ + kPgCopyBinarySignature.size() +
+                                           sizeof(uint32_t));
+      if (extension_length > static_cast<uint32_t>(size - kPqHeaderLength)) {
+        header_error = "header extension is longer than the received data";
+      }
+    }
+
+    if (!header_error.empty()) {
+      last_error_ = StringBuilder("[libpq] Failed to read COPY header: ", header_error);
+      if (pgbuf_) {
+        PQfreemem(pgbuf_);
+        pgbuf_ = nullptr;
+      }
+      if (out->release) out->release(out);
+      return EIO;
+    }
+
+    buf = pgbuf_ + kPqHeaderLength + extension_length;
+    buf_size = size - kPqHeaderLength - static_cast<int>(extension_length);
+  }
+
+  // Append each row
+  int result_code = 0;
+  int64_t num_rows = 0;
+  do {
+    result_code = AppendNext(fields.data(), buf, buf_size, &num_rows, out);
+    PQfreemem(pgbuf_);
+    // Forget the freed buffer immediately so Release() cannot free it again
+    // if we leave the loop on an error.
+    pgbuf_ = buf = nullptr;
+    if (result_code != 0) break;
+
+    buf_size = PQgetCopyData(conn_, &pgbuf_, /*async=*/0);
+    if (buf_size < 0) {
+      pgbuf_ = buf = nullptr;
+      break;
+    }
+    buf = pgbuf_;
+  } while (true);
+
+  // Finish the result array
+  for (int col = 0; col < schema_.n_children; col++) {
+    out->children[col]->length = num_rows;
+  }
+  out->length = num_rows;
+  na_res = ArrowArrayFinishBuilding(out, 0);
+  if (na_res != 0) {
+    result_code = na_res;
+    if (!last_error_.empty()) last_error_ += '\n';
+    last_error_ += StringBuilder("[libpq] Failed to build result array");
+  }
+
+  // Check the server-side response
+  result_ = PQgetResult(conn_);
+  const int pq_status = PQresultStatus(result_);
+  if (pq_status != PGRES_COMMAND_OK) {
+    if (!last_error_.empty()) last_error_ += '\n';
+    last_error_ += StringBuilder("[libpq] Query failed: (", pq_status, ") ",
+                                 PQresultErrorMessage(result_));
+    result_code = EIO;
+  }
+  PQclear(result_);
+  result_ = nullptr;
+
+  // The consumer only owns `out` on success, so drop it on any failure.
+  if (result_code != 0 && out->release) out->release(out);
+  return result_code;
+}
+
+/// Free everything the reader still holds: the pending libpq result, the last
+/// COPY buffer, and the cached schema.  Safe to call repeatedly.
+void TupleReader::Release() {
+  if (result_ != nullptr) {
+    PQclear(result_);
+    result_ = nullptr;
+  }
+  if (pgbuf_ != nullptr) {
+    PQfreemem(pgbuf_);
+    pgbuf_ = nullptr;
+  }
+  if (schema_.release != nullptr) {
+    schema_.release(&schema_);
+  }
+}
+
+/// Wire `stream` up to this reader: the stream's callbacks are the static
+/// trampolines, which recover the reader from `private_data`.
+void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
+  stream->private_data = this;
+  stream->get_schema = &GetSchemaTrampoline;
+  stream->get_next = &GetNextTrampoline;
+  stream->get_last_error = &GetLastErrorTrampoline;
+  stream->release = &ReleaseTrampoline;
+}
+
+/// Decode one tuple of a binary COPY stream from `buf` (holding `buf_size`
+/// bytes) and append its values to the children of `out`.
+///
+/// Returns 0 on success, incrementing `*row_count`; also returns 0 without
+/// appending when the end-of-stream trailer (field count -1) is found.
+/// Returns EIO for malformed or truncated data and ENOTSUP for column types
+/// that cannot be decoded yet; last_error() then describes the problem.
+int TupleReader::AppendNext(struct ArrowSchemaView* fields, const char* buf, int buf_size,
+                            int64_t* row_count, struct ArrowArray* out) {
+  // https://www.postgresql.org/docs/14/sql-copy.html#id-1.9.3.55.9.4.6
+  int remaining = buf_size;
+  if (remaining < static_cast<int>(sizeof(int16_t))) {
+    last_error_ = "[libpq] Buffer too short to contain field count";
+    return EIO;
+  }
+
+  int16_t field_count = 0;
+  std::memcpy(&field_count, buf, sizeof(int16_t));
+  buf += sizeof(int16_t);
+  remaining -= static_cast<int>(sizeof(int16_t));
+  field_count = ntohs(field_count);
+
+  if (field_count == -1) {
+    // end-of-stream
+    return 0;
+  } else if (field_count != schema_.n_children) {
+    last_error_ = StringBuilder("[libpq] Expected ", schema_.n_children,
+                                " fields but found ", field_count);
+    return EIO;
+  }
+
+  for (int col = 0; col < schema_.n_children; col++) {
+    if (remaining < static_cast<int>(sizeof(int32_t))) {
+      last_error_ = StringBuilder("[libpq] Buffer too short to contain length of field #",
+                                  col + 1);
+      return EIO;
+    }
+    const int32_t field_length = LoadNetworkInt32(buf);
+    buf += sizeof(int32_t);
+    remaining -= static_cast<int>(sizeof(int32_t));
+
+    // A negative length marks NULL; no value bytes follow in that case, so
+    // the read position must not move past a payload that is not there.
+    const bool is_null = field_length < 0;
+    if (!is_null && field_length > remaining) {
+      last_error_ = StringBuilder("[libpq] Field #", col + 1, " claims ", field_length,
+                                  " bytes but only ", remaining, " remain");
+      return EIO;
+    }
+
+    struct ArrowBitmap* bitmap = ArrowArrayValidityBitmap(out->children[col]);
+
+    CHECK_NA(ArrowBitmapAppend(bitmap, !is_null, 1));
+
+    switch (fields[col].data_type) {
+      case NANOARROW_TYPE_INT32: {
+        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
+        int32_t value = 0;  // placeholder slot for NULL
+        if (!is_null) {
+          if (field_length != static_cast<int32_t>(sizeof(int32_t))) {
+            last_error_ = StringBuilder("[libpq] Field #", col + 1,
+                                        " has unexpected length ", field_length);
+            return EIO;
+          }
+          value = LoadNetworkInt32(buf);
+          buf += sizeof(int32_t);
+          remaining -= static_cast<int>(sizeof(int32_t));
+        }
+        CHECK_NA(ArrowBufferAppendInt32(buffer, value));
+        break;
+      }
+      case NANOARROW_TYPE_INT64: {
+        struct ArrowBuffer* buffer = ArrowArrayBuffer(out->children[col], 1);
+        int64_t value = 0;  // placeholder slot for NULL
+        if (!is_null) {
+          if (field_length != static_cast<int32_t>(sizeof(int64_t))) {
+            last_error_ = StringBuilder("[libpq] Field #", col + 1,
+                                        " has unexpected length ", field_length);
+            return EIO;
+          }
+          value = LoadNetworkInt64(buf);
+          buf += sizeof(int64_t);
+          remaining -= static_cast<int>(sizeof(int64_t));
+        }
+        CHECK_NA(ArrowBufferAppendInt64(buffer, value));
+        break;
+      }
+      default:
+        last_error_ = StringBuilder("[libpq] Column #", col + 1, " (\"",
+                                    schema_.children[col]->name,
+                                    "\") has unsupported type ", fields[col].data_type);
+        return ENOTSUP;
+    }
+  }
+  (*row_count)++;
+  return 0;
+}
+
+/// C callback for ArrowArrayStream::get_schema; forwards to GetSchema().
+/// Returns EINVAL when the stream or its reader pointer is missing.
+int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self,
+                                     struct ArrowSchema* out) {
+  if (self == nullptr || self->private_data == nullptr) {
+    return EINVAL;
+  }
+  return static_cast<TupleReader*>(self->private_data)->GetSchema(out);
+}
+
+/// C callback for ArrowArrayStream::get_next; forwards to GetNext().
+/// Returns EINVAL when the stream or its reader pointer is missing.
+int TupleReader::GetNextTrampoline(struct ArrowArrayStream* self,
+                                   struct ArrowArray* out) {
+  if (self == nullptr || self->private_data == nullptr) {
+    return EINVAL;
+  }
+  return static_cast<TupleReader*>(self->private_data)->GetNext(out);
+}
+
+/// C callback for ArrowArrayStream::get_last_error; forwards to last_error().
+/// Returns null when the stream or its reader pointer is missing.
+const char* TupleReader::GetLastErrorTrampoline(struct ArrowArrayStream* self) {
+  if (self == nullptr || self->private_data == nullptr) {
+    return nullptr;
+  }
+  return static_cast<TupleReader*>(self->private_data)->last_error();
+}
+
+/// C callback for ArrowArrayStream::release; releases the reader's resources
+/// and marks the stream as released.  Does nothing for a missing stream/reader.
+void TupleReader::ReleaseTrampoline(struct ArrowArrayStream* self) {
+  if (self == nullptr || self->private_data == nullptr) {
+    return;
+  }
+  static_cast<TupleReader*>(self->private_data)->Release();
+  self->private_data = nullptr;
+  self->release = nullptr;
+}
+
+/// Attach this statement to an initialized connection: share ownership of the
+/// connection and take over its type mapping and libpq handle.
+AdbcStatusCode PostgresStatement::New(struct AdbcConnection* connection,
+                                      struct AdbcError* error) {
+  const bool usable = connection != nullptr && connection->private_data != nullptr;
+  if (!usable) {
+    SetError(error, "Must provide an initialized AdbcConnection");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  // private_data holds a heap-allocated shared_ptr to the connection.
+  auto* shared_connection =
+      reinterpret_cast<std::shared_ptr<PostgresConnection>*>(connection->private_data);
+  connection_ = *shared_connection;
+  type_mapping_ = connection_->type_mapping();
+  reader_.conn_ = connection_->conn();
+  return ADBC_STATUS_OK;
+}
+
+/// Bind a single array of parameters.  Ownership of `values` and `schema`
+/// moves into the statement (both are zeroed on return); any previously bound
+/// stream is released first.
+AdbcStatusCode PostgresStatement::Bind(struct ArrowArray* values,
+                                       struct ArrowSchema* schema,
+                                       struct AdbcError* error) {
+  if (values == nullptr || values->release == nullptr) {
+    SetError(error, "Must provide non-NULL array");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  if (schema == nullptr || schema->release == nullptr) {
+    SetError(error, "Must provide non-NULL schema");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  if (bind_.release != nullptr) {
+    bind_.release(&bind_);
+  }
+
+  // Wrap the array in a one-value stream so execution only deals with streams.
+  bind_.get_schema = &OneValueStream::GetSchema;
+  bind_.get_next = &OneValueStream::GetNext;
+  bind_.get_last_error = &OneValueStream::GetLastError;
+  bind_.release = &OneValueStream::Release;
+  bind_.private_data = new OneValueStream{*schema, *values};
+
+  // The inputs have been moved from; clear them so the caller cannot release
+  // them a second time.
+  std::memset(schema, 0, sizeof(*schema));
+  std::memset(values, 0, sizeof(*values));
+  return ADBC_STATUS_OK;
+}
+
+/// Bind a stream of parameters.  Ownership of `stream` moves into the
+/// statement (it is zeroed on return); any previously bound stream is released.
+AdbcStatusCode PostgresStatement::Bind(struct ArrowArrayStream* stream,
+                                       struct AdbcError* error) {
+  const bool has_stream = stream != nullptr && stream->release != nullptr;
+  if (!has_stream) {
+    SetError(error, "Must provide non-NULL stream");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  if (bind_.release != nullptr) {
+    bind_.release(&bind_);
+  }
+  // Move the stream into the statement and leave the source empty.
+  bind_ = *stream;
+  std::memset(stream, 0, sizeof(*stream));
+  return ADBC_STATUS_OK;
+}
+
+/// Create the bulk-ingestion target table with one column per field of
+/// `source_schema`, mapping INT16/INT32/INT64 to SMALLINT/INTEGER/BIGINT.
+///
+/// Returns ADBC_STATUS_NOT_IMPLEMENTED for any other field type and
+/// ADBC_STATUS_IO when the server rejects the statement.  `error` is only
+/// populated on failure.
+AdbcStatusCode PostgresStatement::CreateBulkTable(
+    const struct ArrowSchema& source_schema,
+    const std::vector<struct ArrowSchemaView>& source_schema_fields,
+    struct AdbcError* error) {
+  // NOTE(review): the table and column names are concatenated as-is, without
+  // identifier quoting/escaping - confirm that callers only pass trusted names.
+  std::string create = "CREATE TABLE ";
+  create += ingest_.target;
+  create += " (";
+
+  for (size_t i = 0; i < source_schema_fields.size(); i++) {
+    if (i > 0) create += ", ";
+    create += source_schema.children[i]->name;
+    switch (source_schema_fields[i].data_type) {
+      case ArrowType::NANOARROW_TYPE_INT16:
+        create += " SMALLINT";
+        break;
+      case ArrowType::NANOARROW_TYPE_INT32:
+        create += " INTEGER";
+        break;
+      case ArrowType::NANOARROW_TYPE_INT64:
+        create += " BIGINT";
+        break;
+      default:
+        // TODO: data type to string
+        SetError(error, "Field #", i + 1, " ('", source_schema.children[i]->name,
+                 "') has unsupported type for ingestion ",
+                 source_schema_fields[i].data_type);
+        return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+  }
+
+  create += ")";
+  PGresult* result = PQexecParams(connection_->conn(), create.c_str(), /*nParams=*/0,
+                                  /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
+                                  /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
+                                  /*resultFormat=*/1 /*(binary)*/);
+  if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+    SetError(error, "Failed to create table: ", PQerrorMessage(connection_->conn()));
+    SetError(error, "Query: ", create);
+    PQclear(result);
+    return ADBC_STATUS_IO;
+  }
+  PQclear(result);
+  return ADBC_STATUS_OK;
+}
+
+/// Execute the current SQL query and export its result set through `stream`.
+///
+/// The query is run twice: once wrapped in `LIMIT 0` purely to learn the
+/// column types, then wrapped in `COPY ... TO STDOUT (FORMAT binary)` so the
+/// reader can decode binary tuples straight from the connection.
+/// `*rows_affected` (if provided) is set to -1 because the row count is not
+/// known up front.
+AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
+                                               int64_t* rows_affected,
+                                               struct AdbcError* error) {
+  if (query_.empty()) {
+    SetError(error, "Must SetSqlQuery before ExecuteQuery");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+  if (!stream) {
+    SetError(error, "Must provide output for ExecuteQuery");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  ClearResult();
+
+  // 1. Execute the query with LIMIT 0 to get the schema
+  {
+    // Run the wrapped query, not query_ itself: otherwise the full result set
+    // would be computed and transferred here and then again by COPY below.
+    std::string schema_query = "SELECT * FROM (" + query_ + ") AS ignored LIMIT 0";
+    PGresult* result =
+        PQexecParams(connection_->conn(), schema_query.c_str(), /*nParams=*/0,
+                     /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
+                     /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat);
+    if (PQresultStatus(result) != PGRES_TUPLES_OK) {
+      SetError(error, "Failed to execute query: ", PQerrorMessage(connection_->conn()));
+      PQclear(result);
+      return ADBC_STATUS_IO;
+    }
+    AdbcStatusCode status = InferSchema(*type_mapping_, result, &reader_.schema_, error);
+    PQclear(result);
+    if (status != ADBC_STATUS_OK) return status;
+  }
+
+  // 2. Execute the query with COPY to get binary tuples
+  {
+    std::string copy_query = "COPY (" + query_ + ") TO STDOUT (FORMAT binary)";
+    reader_.result_ =
+        PQexecParams(connection_->conn(), copy_query.c_str(), /*nParams=*/0,
+                     /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
+                     /*paramLengths=*/nullptr, /*paramFormats=*/nullptr, kPgBinaryFormat);
+    if (PQresultStatus(reader_.result_) != PGRES_COPY_OUT) {
+      SetError(error, "Failed to execute query: ", PQerrorMessage(connection_->conn()));
+      ClearResult();
+      return ADBC_STATUS_IO;
+    }
+    // Result is read from the connection, not the result, but we won't clear it here
+  }
+
+  reader_.ExportTo(stream);
+  if (rows_affected) *rows_affected = -1;
+  return ADBC_STATUS_OK;
+}
+
+/// Execute a statement that produces no result set.  When a bulk-ingestion
+/// target table is configured the bound data is ingested; otherwise the
+/// current SQL query is run.
+AdbcStatusCode PostgresStatement::ExecuteUpdate(int64_t* rows_affected,
+                                                struct AdbcError* error) {
+  ClearResult();
+  const bool is_bulk_ingest = !ingest_.target.empty();
+  return is_bulk_ingest ? ExecuteUpdateBulk(rows_affected, error)
+                        : ExecuteUpdateQuery(rows_affected, error);
+}
+
+/// Ingest the bound stream into the target table, one INSERT per row.
+///
+/// Steps: read the stream's schema (must be a struct), optionally create the
+/// table, prepare a parameterized INSERT, then execute it for every row of
+/// every batch with binary-format parameters.  On success `*rows_affected`
+/// (if provided) receives the number of rows inserted.  The bound stream is
+/// released on success and on most error paths.
+AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
+                                                    struct AdbcError* error) {
+  if (!bind_.release) {
+    SetError(error, "Must Bind() before Execute() for bulk ingestion");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  Handle<struct ArrowSchema> source_schema;
+  struct ArrowSchemaView source_schema_view;
+  CHECK_NA_ADBC(bind_.get_schema(&bind_, &source_schema.value), error);
+  CHECK_NA_ADBC(
+      ArrowSchemaViewInit(&source_schema_view, &source_schema.value, /*error*/ nullptr),
+      error);
+
+  if (source_schema_view.data_type != ArrowType::NANOARROW_TYPE_STRUCT) {
+    SetError(error, "Bind parameters must have type STRUCT");
+    return ADBC_STATUS_INVALID_STATE;
+  }
+
+  std::vector<struct ArrowSchemaView> source_schema_fields(source_schema->n_children);
+  for (size_t i = 0; i < source_schema_fields.size(); i++) {
+    CHECK_NA_ADBC(ArrowSchemaViewInit(&source_schema_fields[i],
+                                      source_schema->children[i], /*error*/ nullptr),
+                  error);
+  }
+
+  if (!ingest_.append) {
+    // CREATE TABLE
+    AdbcStatusCode status =
+        CreateBulkTable(source_schema.value, source_schema_fields, error);
+    if (status != ADBC_STATUS_OK) return status;
+  }
+
+  // Prepare and insert
+  std::vector<uint32_t> param_types(source_schema_fields.size());
+  std::vector<char*> param_values(source_schema_fields.size());
+  std::vector<int> param_lengths(source_schema_fields.size());
+  std::vector<int> param_formats(source_schema_fields.size(), /*value=*/kPgBinaryFormat);
+  // XXX: this assumes fixed-length fields only - will need more
+  // consideration to deal with variable-length fields
+
+  std::string insert = "INSERT INTO ";
+  insert += ingest_.target;
+  insert += " VALUES (";
+  for (size_t i = 0; i < source_schema_fields.size(); i++) {
+    if (i > 0) insert += ", ";
+    insert += "$";
+    insert += std::to_string(i + 1);
+
+    PgType pg_type;
+    switch (source_schema_fields[i].data_type) {
+      case ArrowType::NANOARROW_TYPE_INT16:
+        pg_type = PgType::kInt2;
+        param_lengths[i] = 2;
+        break;
+      case ArrowType::NANOARROW_TYPE_INT32:
+        pg_type = PgType::kInt4;
+        param_lengths[i] = 4;
+        break;
+      case ArrowType::NANOARROW_TYPE_INT64:
+        pg_type = PgType::kInt8;
+        param_lengths[i] = 8;
+        break;
+      default:
+        // TODO: data type to string
+        SetError(error, "Field #", i + 1, " ('", source_schema->children[i]->name,
+                 "') has unsupported type for ingestion ",
+                 source_schema_fields[i].data_type);
+        return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+
+    param_types[i] = type_mapping_->GetOid(pg_type);
+    if (param_types[i] == 0) {
+      // TODO: data type to string
+      SetError(error, "Field #", i + 1, " ('", source_schema->children[i]->name,
+               "') has unsupported type for ingestion ",
+               source_schema_fields[i].data_type);
+      return ADBC_STATUS_NOT_IMPLEMENTED;
+    }
+  }
+  insert += ")";
+
+  // All parameter values of one row live back to back in a single scratch
+  // buffer; param_values_offsets[i] is where column i's bytes start.
+  size_t param_values_length = 0;
+  std::vector<size_t> param_values_offsets;
+  for (int length : param_lengths) {
+    param_values_offsets.push_back(param_values_length);
+    param_values_length += length;
+  }
+  std::vector<char> param_values_buffer(param_values_length);
+
+  PGresult* result = PQprepare(connection_->conn(), /*stmtName=*/"", insert.c_str(),
+                               /*nParams=*/source_schema->n_children, param_types.data());
+  if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+    SetError(error, "Failed to prepare query: ", PQerrorMessage(connection_->conn()));
+    SetError(error, "Query: ", insert);
+    PQclear(result);
+    return ADBC_STATUS_IO;
+  }
+  PQclear(result);
+
+  int64_t num_rows_inserted = 0;
+
+  // TODO: wrap this in BEGIN/END TRANSACTION (unless not in auto-commit mode?)
+  while (true) {
+    Handle<struct ArrowArray> array;
+    int res = bind_.get_next(&bind_, &array.value);
+    if (res != 0) {
+      // Fetch the message BEFORE releasing the stream: the callbacks must not
+      // be used after release.  A stream may also report no message at all.
+      const char* stream_error = bind_.get_last_error(&bind_);
+      // TODO: include errno
+      SetError(error, "Failed to read next batch from stream of bind parameters: ",
+               stream_error ? stream_error : "(no error message provided)");
+      bind_.release(&bind_);
+      return ADBC_STATUS_IO;
+    }
+    // A released (null) array marks the end of the stream
+    if (!array->release) break;
+
+    Handle<struct ArrowArrayView> array_view;
+    // TODO: include error messages
+    CHECK_NA_ADBC(
+        ArrowArrayViewInitFromSchema(&array_view.value, &source_schema.value, nullptr),
+        error);
+    CHECK_NA_ADBC(ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr),
+                  error);
+
+    for (int64_t row = 0; row < array->length; row++) {
+      for (int64_t col = 0; col < array_view->n_children; col++) {
+        if (ArrowArrayViewIsNull(array_view->children[col], row)) {
+          // libpq treats a null parameter pointer as SQL NULL
+          param_values[col] = nullptr;
+          continue;
+        } else {
+          param_values[col] = param_values_buffer.data() + param_values_offsets[col];
+        }
+        switch (source_schema_fields[col].data_type) {
+          case ArrowType::NANOARROW_TYPE_INT64: {
+            const int64_t value = ToNetworkInt64(
+                array_view->children[col]->buffer_views[1].data.as_int64[row]);
+            std::memcpy(param_values[col], &value, sizeof(int64_t));
+            break;
+          }
+          default:
+            bind_.release(&bind_);
+            // TODO: data type to string
+            SetError(error, "Field #", col + 1, " ('", source_schema->children[col]->name,
+                     "') has unsupported type for ingestion ",
+                     source_schema_fields[col].data_type);
+            return ADBC_STATUS_NOT_IMPLEMENTED;
+        }
+      }
+
+      result = PQexecPrepared(connection_->conn(), /*stmtName=*/"",
+                              /*nParams=*/source_schema->n_children, param_values.data(),
+                              param_lengths.data(), param_formats.data(),
+                              /*resultFormat=*/0 /*text*/);
+
+      if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+        SetError(error, "Failed to insert row: ", PQerrorMessage(connection_->conn()));
+        PQclear(result);
+        bind_.release(&bind_);
+        return ADBC_STATUS_IO;
+      }
+
+      PQclear(result);
+      num_rows_inserted++;
+    }
+  }
+
+  if (rows_affected) *rows_affected = num_rows_inserted;
+  bind_.release(&bind_);
+  return ADBC_STATUS_OK;
+}
+
+/// Run the current SQL query as a command that returns no rows.
+///
+/// On success `*rows_affected` (if provided) is set to the count the server
+/// reports for the command, or -1 when the command does not report one.
+/// Returns ADBC_STATUS_IO when the server does not answer with COMMAND_OK.
+AdbcStatusCode PostgresStatement::ExecuteUpdateQuery(int64_t* rows_affected,
+                                                     struct AdbcError* error) {
+  PGresult* result = PQexecParams(connection_->conn(), query_.c_str(), /*nParams=*/0,
+                                  /*paramTypes=*/nullptr, /*paramValues=*/nullptr,
+                                  /*paramLengths=*/nullptr, /*paramFormats=*/nullptr,
+                                  /*resultFormat=*/1 /*(binary)*/);
+  if (PQresultStatus(result) != PGRES_COMMAND_OK) {
+    SetError(error, "Failed to execute query: ", PQerrorMessage(connection_->conn()));
+    PQclear(result);
+    return ADBC_STATUS_IO;
+  }
+  if (rows_affected) {
+    // Read the count from THIS command's result (not the reader's, which was
+    // cleared before execution).  PQcmdTuples yields a decimal string, or an
+    // empty string for commands that do not report a row count.
+    const char* tuples = PQcmdTuples(result);
+    const bool has_count = tuples != nullptr && tuples[0] != '\0';
+    *rows_affected = has_count ? std::strtoll(tuples, /*str_end=*/nullptr, 10) : -1;
+  }
+  PQclear(result);
+  return ADBC_STATUS_OK;
+}
+
+// Not supported yet: always returns ADBC_STATUS_NOT_IMPLEMENTED and leaves
+// `schema` and `error` untouched.
+AdbcStatusCode PostgresStatement::GetParameterSchema(struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+/// Preparing statements is not supported yet.  Reports INVALID_STATE when no
+/// query has been set, otherwise NOT_IMPLEMENTED.
+AdbcStatusCode PostgresStatement::Prepare(struct AdbcError* error) {
+  if (!query_.empty()) {
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  SetError(error, "Must SetSqlQuery() before Prepare()");
+  return ADBC_STATUS_INVALID_STATE;
+}
+
+/// Release the statement's resources: any pending result held by the reader
+/// and any bound parameter stream.  Always returns ADBC_STATUS_OK.
+AdbcStatusCode PostgresStatement::Release(struct AdbcError* error) {
+  ClearResult();
+  if (bind_.release != nullptr) bind_.release(&bind_);
+  return ADBC_STATUS_OK;
+}
+
+/// Set the SQL text to execute.  Setting a query switches the statement out
+/// of bulk-ingestion mode by clearing the ingestion target.
+/// Returns ADBC_STATUS_INVALID_ARGUMENT when `query` is NULL.
+AdbcStatusCode PostgresStatement::SetSqlQuery(const char* query,
+                                              struct AdbcError* error) {
+  // Assigning a null pointer to std::string is undefined behaviour, so reject
+  // it explicitly and leave the statement unchanged.
+  if (!query) {
+    SetError(error, "Must provide non-NULL query");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+  ingest_.target.clear();
+  query_ = query;
+  return ADBC_STATUS_OK;
+}
+
+/// Set a statement option.
+///
+/// Supported keys:
+///  - ADBC_INGEST_OPTION_TARGET_TABLE: table to bulk-ingest into; selecting a
+///    target clears any SQL query previously set.
+///  - ADBC_INGEST_OPTION_MODE: ADBC_INGEST_OPTION_MODE_CREATE (create the
+///    table first) or ADBC_INGEST_OPTION_MODE_APPEND (insert into an existing
+///    table).
+/// Returns INVALID_ARGUMENT for a NULL key/value or an unknown mode value and
+/// NOT_IMPLEMENTED for an unknown key.
+AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
+                                            struct AdbcError* error) {
+  if (!key || !value) {
+    SetError(error, "Must provide non-NULL option key and value");
+    return ADBC_STATUS_INVALID_ARGUMENT;
+  }
+
+  if (std::strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) {
+    query_.clear();
+    ingest_.target = value;
+  } else if (std::strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) {
+    if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) {
+      ingest_.append = false;
+    } else if (std::strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) {
+      // Compare against APPEND here; testing CREATE a second time made append
+      // mode impossible to select.
+      ingest_.append = true;
+    } else {
+      SetError(error, "Invalid value ", value, " for option ", key);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
+  } else {
+    SetError(error, "Unknown statement option ", key);
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+  return ADBC_STATUS_OK;
+}
+
+// Drop any state left over from a previous execution by releasing the
+// reader's pending result, COPY buffer and schema.
+void PostgresStatement::ClearResult() {
+ // TODO: we may want to synchronize here for safety
+ reader_.Release();
+}
+} // namespace adbcpq
diff --git a/c/drivers/postgres/statement.h b/c/drivers/postgres/statement.h
new file mode 100644
index 0000000..305f8f0
--- /dev/null
+++ b/c/drivers/postgres/statement.h
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <adbc.h>
+#include <libpq-fe.h>
+#include <nanoarrow.h>
+
+#include "type.h"
+
+namespace adbcpq {
+class PostgresConnection;
+class PostgresStatement;
+
+/// \brief An ArrowArrayStream that reads tuples from a PGresult.
+///
+/// The reader is exported to C through ExportTo(), which points the stream's
+/// callbacks at the static trampolines below.  The row data itself is pulled
+/// from the connection with PQgetCopyData (binary COPY format); the PGresult
+/// mainly signals that a query is pending.
+class TupleReader final {
+ public:
+ /// Create a reader on `conn` with no pending result and an empty schema.
+ TupleReader(PGconn* conn) : conn_(conn), result_(nullptr), pgbuf_(nullptr) {
+ std::memset(&schema_, 0, sizeof(schema_));
+ }
+
+ /// Copy the result schema into `out`; EINVAL if no result is pending.
+ int GetSchema(struct ArrowSchema* out);
+ /// Build one array from the pending result; yields a released (null) array
+ /// once the result has been consumed.
+ int GetNext(struct ArrowArray* out);
+ /// Message describing the most recent failure (empty string if none).
+ const char* last_error() const { return last_error_.c_str(); }
+ /// Free the pending result, the COPY buffer and the schema.
+ void Release();
+
+ /// Decode one binary COPY tuple from `buf` and append it to `out`.
+ int AppendNext(struct ArrowSchemaView* fields, const char* buf, int buf_size,
+ int64_t* row_count, struct ArrowArray* out);
+ /// Populate `stream` so that it is backed by this reader.
+ void ExportTo(struct ArrowArrayStream* stream);
+
+ private:
+ // PostgresStatement sets conn_, schema_ and result_ directly when executing.
+ friend class PostgresStatement;
+
+ // C-callable adapters that recover the reader from stream->private_data.
+ static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out);
+ static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out);
+ static const char* GetLastErrorTrampoline(struct ArrowArrayStream* self);
+ static void ReleaseTrampoline(struct ArrowArrayStream* self);
+
+ PGconn* conn_; // connection the COPY data is read from
+ PGresult* result_; // pending result, or nullptr once consumed/freed
+ char* pgbuf_; // buffer most recently returned by PQgetCopyData
+ struct ArrowSchema schema_; // schema of the pending result set
+ std::string last_error_; // message for the most recent failure
+};
+
+/// \brief Implementation of an ADBC statement on top of libpq.
+///
+/// A statement is either in "query" mode (SQL text set with SetSqlQuery) or in
+/// "bulk ingest" mode (target table set through SetOption); setting one
+/// clears the other.
+class PostgresStatement {
+ public:
+ /// Create a statement that is not yet attached to a connection; New()
+ /// performs the attachment.
+ PostgresStatement() : connection_(nullptr), query_(), reader_(nullptr) {
+ std::memset(&bind_, 0, sizeof(bind_));
+ }
+
+ // ---------------------------------------------------------------------
+ // ADBC API implementation
+
+ AdbcStatusCode Bind(struct ArrowArray* values, struct ArrowSchema* schema,
+ struct AdbcError* error);
+ AdbcStatusCode Bind(struct ArrowArrayStream* stream, struct AdbcError* error);
+ AdbcStatusCode ExecuteQuery(struct ArrowArrayStream* stream, int64_t* rows_affected,
+ struct AdbcError* error);
+ AdbcStatusCode ExecuteUpdate(int64_t* rows_affected, struct AdbcError* error);
+ AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error);
+ AdbcStatusCode New(struct AdbcConnection* connection, struct AdbcError* error);
+ AdbcStatusCode Prepare(struct AdbcError* error);
+ AdbcStatusCode Release(struct AdbcError* error);
+ AdbcStatusCode SetOption(const char* key, const char* value, struct AdbcError* error);
+ AdbcStatusCode SetSqlQuery(const char* query, struct AdbcError* error);
+
+ // ---------------------------------------------------------------------
+ // Helper methods
+
+ /// Release whatever the reader holds from a previous execution.
+ void ClearResult();
+ /// Issue CREATE TABLE for the ingestion target based on the bound schema.
+ AdbcStatusCode CreateBulkTable(
+ const struct ArrowSchema& source_schema,
+ const std::vector<struct ArrowSchemaView>& source_schema_fields,
+ struct AdbcError* error);
+ /// ExecuteUpdate() path used when an ingestion target is set.
+ AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error);
+ /// ExecuteUpdate() path used for a plain SQL query.
+ AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error);
+
+ private:
+ // Both are obtained from the AdbcConnection in New().
+ std::shared_ptr<TypeMapping> type_mapping_;
+ std::shared_ptr<PostgresConnection> connection_;
+
+ // Query state
+ std::string query_;
+ // Bound parameters; bind_.release is null when nothing is bound.
+ struct ArrowArrayStream bind_;
+
+ // Bulk ingest state
+ struct {
+ // Target table name; empty when not in bulk-ingest mode.
+ std::string target;
+ // false: create the table before inserting; true: insert into an existing one.
+ bool append = false;
+ } ingest_;
+
+ // Reads query results from the connection and backs the exported stream.
+ TupleReader reader_;
+};
+} // namespace adbcpq
diff --git a/c/drivers/postgres/type.cc b/c/drivers/postgres/type.cc
new file mode 100644
index 0000000..0c4a731
--- /dev/null
+++ b/c/drivers/postgres/type.cc
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "type.h"
+
+#include <cstring>
+
+namespace adbcpq {
+/// Register one Postgres type, described by its OID, its type name, and the
+/// name of its binary receive function.
+///
+/// The OID is added to type_mapping only when FromPgTypreceive recognizes the
+/// receive function. Independently of that, a type named "int8" is recorded as
+/// the canonical OID for PgType::kInt8.
+void TypeMapping::Insert(uint32_t oid, const char* typname, const char* typreceive) {
+  PgType parsed;
+  const bool recognized = FromPgTypreceive(typreceive, &parsed);
+  if (recognized) type_mapping[oid] = parsed;
+
+  // Record 'canonical' types
+  // TODO: fill in remainder (only int8 is handled so far)
+  const bool is_int8 = std::strcmp(typname, "int8") == 0;
+  if (!is_int8) return;
+  // DCHECK_EQ(parsed, PgType::kInt8);
+  canonical_types[PgType::kInt8] = oid;
+}
+
+/// Look up the canonical Postgres OID recorded for a standardized type.
+///
+/// \return the recorded OID, or 0 if none has been recorded for `type`
+uint32_t TypeMapping::GetOid(PgType type) const {
+  const auto found = canonical_types.find(type);
+  return (found != canonical_types.end()) ? found->second : 0;
+}
+
+/// Translate the name of a Postgres binary receive function into a PgType.
+///
+/// \param[in] typreceive NUL-terminated receive function name, e.g. "int8recv"
+/// \param[out] out set to the matching type; left untouched if nothing matches
+/// \return true if the name was recognized, false otherwise
+bool FromPgTypreceive(const char* typreceive, PgType* out) {
+  struct ReceiveFunction {
+    const char* name;
+    PgType type;
+  };
+  // Every name is distinct, so lookup order does not affect the result.
+  static const ReceiveFunction kKnown[] = {
+      {"bitrecv", PgType::kBit},
+      {"boolrecv", PgType::kBool},
+      {"date_recv", PgType::kDate},
+      {"float4recv", PgType::kFloat4},
+      {"float8recv", PgType::kFloat8},
+      {"int2recv", PgType::kInt2},
+      {"int4recv", PgType::kInt4},
+      {"int8recv", PgType::kInt8},
+      {"textrecv", PgType::kText},
+      {"time_recv", PgType::kTime},
+      {"timestamp_recv", PgType::kTimestamp},
+      {"timestamptz_recv", PgType::kTimestampTz},
+      {"timetz_recv", PgType::kTimeTz},
+      {"varcharrecv", PgType::kVarChar},
+  };
+
+  for (const ReceiveFunction& candidate : kKnown) {
+    if (std::strcmp(typreceive, candidate.name) != 0) continue;
+    *out = candidate.type;
+    return true;
+  }
+  return false;
+}
+
+} // namespace adbcpq
diff --git a/c/drivers/postgres/type.h b/c/drivers/postgres/type.h
new file mode 100644
index 0000000..b93bf99
--- /dev/null
+++ b/c/drivers/postgres/type.h
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <unordered_map>
+
+#include <nanoarrow.h>
+
+namespace adbcpq {
+
+/// Standardized identifiers for the Postgres types this driver recognizes.
+/// The trailing comment on each enumerator is the receive function name that
+/// FromPgTypreceive() maps to it.
+enum class PgType : uint8_t {
+ // TODO: is there a good null type?
+ kBit,  // bitrecv
+ kBool,  // boolrecv
+ kDate,  // date_recv
+ kFloat4,  // float4recv
+ kFloat8,  // float8recv
+ kInt2,  // int2recv
+ kInt4,  // int4recv
+ kInt8,  // int8recv
+ kText,  // textrecv
+ kTime,  // time_recv
+ kTimestamp,  // timestamp_recv
+ kTimestampTz,  // timestamptz_recv
+ kTimeTz,  // timetz_recv
+ kVarChar,  // varcharrecv
+};
+
+/// Lookup tables between Postgres type OIDs and PgType values, in both
+/// directions. Both tables are filled through Insert().
+struct TypeMapping {
+ // Maps Postgres type OIDs to a standardized type (PgType)
+ // Example: int8 == 20
+ std::unordered_map<uint32_t, PgType> type_mapping;
+ // Maps standardized types to the Postgres type OID to use for them
+ // Example: kInt8 == 20
+ std::unordered_map<PgType, uint32_t> canonical_types;
+
+ /// Register a Postgres type.
+ ///
+ /// \param[in] oid the type's OID
+ /// \param[in] typname the type's name; "int8" is recorded as the canonical
+ ///   OID for PgType::kInt8
+ /// \param[in] typreceive the type's receive function name; the OID is added
+ ///   to type_mapping only if FromPgTypreceive() recognizes it
+ void Insert(uint32_t oid, const char* typname, const char* typreceive);
+ /// Look up the canonical OID recorded for a standardized type.
+ /// \return 0 if not found
+ uint32_t GetOid(PgType type) const;
+};
+
+/// Convert a receive function name such as "int8recv" to a PgType.
+///
+/// \param[in] typreceive NUL-terminated receive function name
+/// \param[out] out receives the type when the name is recognized
+/// \return true if the name was recognized, false otherwise
+bool FromPgTypreceive(const char* typreceive, PgType* out);
+
+// TODO: this should be upstream
+// const char* ArrowTypeToString(ArrowType type);
+
+} // namespace adbcpq
diff --git a/c/drivers/postgres/util.h b/c/drivers/postgres/util.h
new file mode 100644
index 0000000..8778ad8
--- /dev/null
+++ b/c/drivers/postgres/util.h
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstring>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "adbc.h"
+
+namespace adbcpq {
+
+// Token-pasting helpers. MAKE_NAME goes through CONCAT so that its arguments
+// are macro-expanded first (e.g. __COUNTER__ becomes a number) before being
+// pasted into one identifier; the CHECK_NA* macros below use this to name
+// their temporary status variable.
+#define CONCAT(x, y) x##y
+#define MAKE_NAME(x, y) CONCAT(x, y)
+
+// see arrow/util/string_builder.h
+
+/// Base case: write a single value to the stream with operator<<.
+template <typename Value>
+static inline void StringBuilderRecursive(std::stringstream& stream, Value&& value) {
+  stream << value;
+}
+
+/// Recursive case: write the first value, then the remaining ones in order.
+template <typename First, typename... Rest>
+static inline void StringBuilderRecursive(std::stringstream& stream, First&& first,
+                                          Rest&&... rest) {
+  stream << first;
+  StringBuilderRecursive(stream, std::forward<Rest>(rest)...);
+}
+
+/// Concatenate the operator<< representations of all arguments, left to right.
+///
+/// \return the resulting string
+template <typename... Args>
+static inline std::string StringBuilder(Args&&... args) {
+  std::stringstream buffer;
+  StringBuilderRecursive(buffer, std::forward<Args>(args)...);
+  return buffer.str();
+}
+
+/// Release callback installed by SetError: frees the new[]-allocated message
+/// and clears both the message and the release pointer.
+static inline void ReleaseError(struct AdbcError* error) {
+  char* owned = error->message;
+  error->message = nullptr;
+  error->release = nullptr;
+  delete[] owned;
+}
+
+template <typename... Args>
+static inline void SetError(struct AdbcError* error, Args&&... args) {
+ if (!error) return;
+ std::string message = StringBuilder("[libpq] ", std::forward<Args>(args)...);
+ if (error->message) {
+ message.reserve(message.size() + 1 + std::strlen(error->message));
+ message.append(1, '\n');
+ message.append(error->message);
+ delete[] error->message;
+ }
+ error->message = new char[message.size() + 1];
+ message.copy(error->message, message.size());
+ error->message[message.size()] = '\0';
+ error->release = ReleaseError;
+}
+
+// Implementation of CHECK_NA_ADBC. NAME is the local variable that stores the
+// result of EXPR, so EXPR is evaluated exactly once. A nonzero result is
+// reported via SetError (the stringified expression plus std::strerror's text
+// for the code) and makes the enclosing function return ADBC_STATUS_INTERNAL.
+#define CHECK_NA_ADBC_IMPL(NAME, EXPR, ERROR) \
+ do { \
+ const int NAME = (EXPR); \
+ if (NAME) { \
+ SetError((ERROR), #EXPR " failed: ", std::strerror(NAME)); \
+ return ADBC_STATUS_INTERNAL; \
+ } \
+ } while (false)
+/// Check an errno-style code and return an ADBC code if necessary.
+///
+/// EXPR must yield an int where 0 means success; ERROR is the AdbcError* to
+/// fill in (SetError ignores a null pointer). The enclosing function must be
+/// able to return ADBC_STATUS_INTERNAL.
+#define CHECK_NA_ADBC(EXPR, ERROR) \
+ CHECK_NA_ADBC_IMPL(MAKE_NAME(errno_status_, __COUNTER__), EXPR, ERROR)
+
+// Implementation of CHECK_NA. NAME is the local variable that stores the
+// result of EXPR, so EXPR is evaluated exactly once; a nonzero result is
+// returned unchanged from the enclosing function.
+#define CHECK_NA_IMPL(NAME, EXPR) \
+ do { \
+ const int NAME = (EXPR); \
+ if (NAME) return NAME; \
+ } while (false)
+
+/// Check an errno-style code and return it if necessary.
+///
+/// EXPR must yield an int where 0 means success; the enclosing function must
+/// itself return an int-compatible error code.
+#define CHECK_NA(EXPR) CHECK_NA_IMPL(MAKE_NAME(errno_status_, __COUNTER__), EXPR)
+
+} // namespace adbcpq
diff --git a/c/validation/adbc_validation.c b/c/validation/adbc_validation.c
index 5345ec4..7a0aa46 100644
--- a/c/validation/adbc_validation.c
+++ b/c/validation/adbc_validation.c
@@ -76,13 +76,16 @@ void AdbcValidatePass(struct AdbcValidateTestContext* ctx) {
}
void AdbcValidateFail(struct AdbcValidateTestContext* ctx, const char* file, int lineno,
- struct AdbcError* error) {
+ struct AdbcError* error, const char* message) {
ctx->failed++;
printf("\n%s:%d: FAIL\n", file, lineno);
if (error && error->release) {
printf("%s\n", error->message);
error->release(error);
}
+ if (message) {
+ printf("%s\n", message);
+ }
}
int AdbcValidationIsSet(struct ArrowArray* array, int64_t i) {
@@ -100,7 +103,7 @@ int AdbcValidationIsSet(struct ArrowArray* array, int64_t i) {
AdbcStatusCode NAME = (EXPR); \
if (ADBC_STATUS_##STATUS != NAME) { \
printf("\nActual value: %s\n", AdbcValidateStatusCodeMessage(NAME)); \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, ERROR); \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, ERROR, NULL); \
return; \
} \
AdbcValidatePass(adbc_context);
@@ -110,44 +113,62 @@ int AdbcValidationIsSet(struct ArrowArray* array, int64_t i) {
#define ADBCV_ASSERT_EQ(EXPECTED, ACTUAL) \
AdbcValidateBeginAssert(adbc_context, "%s == %s", #ACTUAL, #EXPECTED); \
if ((EXPECTED) != (ACTUAL)) { \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL); \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
return; \
} \
AdbcValidatePass(adbc_context);
#define ADBCV_ASSERT_NE(EXPECTED, ACTUAL) \
- AdbcValidateBeginAssert(adbc_context, "%s == %s", #ACTUAL, #EXPECTED); \
+ AdbcValidateBeginAssert(adbc_context, "%s != %s", #ACTUAL, #EXPECTED); \
if ((EXPECTED) == (ACTUAL)) { \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL); \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
return; \
} \
AdbcValidatePass(adbc_context);
-#define ADBCV_ASSERT_TRUE(ACTUAL) \
- AdbcValidateBeginAssert(adbc_context, "%s is true", #ACTUAL); \
- if (!(ACTUAL)) { \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL); \
- return; \
- } \
+#define ADBCV_ASSERT_TRUE(ACTUAL) \
+ AdbcValidateBeginAssert(adbc_context, "%s is true", #ACTUAL); \
+ if (!(ACTUAL)) { \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
+ return; \
+ } \
AdbcValidatePass(adbc_context);
-#define ADBCV_ASSERT_FALSE(ACTUAL) \
- AdbcValidateBeginAssert(adbc_context, "%s is false", #ACTUAL); \
- if (ACTUAL) { \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL); \
- return; \
- } \
+#define ADBCV_ASSERT_FALSE(ACTUAL) \
+ AdbcValidateBeginAssert(adbc_context, "%s is false", #ACTUAL); \
+ if (ACTUAL) { \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
+ return; \
+ } \
AdbcValidatePass(adbc_context);
+#define ADBCV_FAIL() \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
+ return
+
+#define NA_ASSERT_OK_IMPL(ERROR_NAME, EXPR) \
+ do { \
+ AdbcValidateBeginAssert(adbc_context, "%s is OK (0)", #EXPR); \
+ ArrowErrorCode ERROR_NAME = (EXPR); \
+ if (ERROR_NAME) { \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, NULL); \
+ return; \
+ } \
+ AdbcValidatePass(adbc_context); \
+ } while (0)
+
+#define NA_ASSERT_OK(EXPR) NA_ASSERT_OK_IMPL(ADBCV_NAME(na_status_, __COUNTER__), EXPR)
-#define NA_ASSERT_OK_IMPL(ERROR_NAME, EXPR) \
+#define AAS_ASSERT_OK_IMPL(ERROR_NAME, STREAM, EXPR) \
do { \
AdbcValidateBeginAssert(adbc_context, "%s is OK (0)", #EXPR); \
ArrowErrorCode ERROR_NAME = (EXPR); \
if (ERROR_NAME) { \
- AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL); \
+ AdbcValidateFail(adbc_context, __FILE__, __LINE__, NULL, \
+ (STREAM)->get_last_error((STREAM))); \
return; \
} \
AdbcValidatePass(adbc_context); \
} while (0)
-#define NA_ASSERT_OK(EXPR) NA_ASSERT_OK_IMPL(ADBCV_NAME(na_status_, __COUNTER__), EXPR)
+#define AAS_ASSERT_OK(STREAM, EXPR) \
+ AAS_ASSERT_OK_IMPL(ADBCV_NAME(na_status_, __COUNTER__), STREAM, EXPR)
void AdbcValidateDatabaseNewRelease(struct AdbcValidateTestContext* adbc_context) {
struct AdbcError error;
@@ -351,20 +372,31 @@ void AdbcValidateStatementSqlExecute(struct AdbcValidateTestContext* adbc_contex
struct ArrowSchema schema;
struct ArrowSchemaView schema_view;
- ADBCV_ASSERT_EQ(0, out.get_schema(&out, &schema));
+ AAS_ASSERT_OK(&out, out.get_schema(&out, &schema));
ADBCV_ASSERT_EQ(1, schema.n_children);
- ADBCV_ASSERT_EQ(0, ArrowSchemaViewInit(&schema_view, schema.children[0], NULL));
- ADBCV_ASSERT_EQ(NANOARROW_TYPE_INT64, schema_view.data_type);
+ NA_ASSERT_OK(ArrowSchemaViewInit(&schema_view, schema.children[0], NULL));
struct ArrowArray array;
- ADBCV_ASSERT_EQ(0, out.get_next(&out, &array));
+ AAS_ASSERT_OK(&out, out.get_next(&out, &array));
ADBCV_ASSERT_NE(NULL, array.release);
+ ADBCV_ASSERT_EQ(1, array.length);
ADBCV_ASSERT_TRUE(AdbcValidationIsSet(array.children[0], 0));
- ADBCV_ASSERT_EQ(42, ((int64_t*)array.children[0]->buffers[1])[0]);
+
+ switch (schema_view.data_type) {
+ case NANOARROW_TYPE_INT32:
+ ADBCV_ASSERT_EQ(42, ((int32_t*)array.children[0]->buffers[1])[0]);
+ break;
+ case NANOARROW_TYPE_INT64:
+ ADBCV_ASSERT_EQ(42, ((int64_t*)array.children[0]->buffers[1])[0]);
+ break;
+ default:
+ printf("FAIL: Unexpected data type: %d\n", schema_view.data_type);
+ ADBCV_FAIL();
+ }
array.release(&array);
- ADBCV_ASSERT_EQ(0, out.get_next(&out, &array));
+ AAS_ASSERT_OK(&out, out.get_next(&out, &array));
ADBCV_ASSERT_EQ(NULL, array.release);
schema.release(&schema);
@@ -443,13 +475,13 @@ void AdbcValidateStatementSqlIngest(struct AdbcValidateTestContext* adbc_context
struct ArrowSchema schema;
struct ArrowSchemaView schema_view;
- NA_ASSERT_OK(out.get_schema(&out, &schema));
+ AAS_ASSERT_OK(&out, out.get_schema(&out, &schema));
ADBCV_ASSERT_EQ(1, schema.n_children);
NA_ASSERT_OK(ArrowSchemaViewInit(&schema_view, schema.children[0], NULL));
ADBCV_ASSERT_EQ(NANOARROW_TYPE_INT64, schema_view.data_type);
struct ArrowArray array;
- NA_ASSERT_OK(out.get_next(&out, &array));
+ AAS_ASSERT_OK(&out, out.get_next(&out, &array));
ADBCV_ASSERT_NE(NULL, array.release);
ADBCV_ASSERT_EQ(5, array.length);
@@ -464,7 +496,7 @@ void AdbcValidateStatementSqlIngest(struct AdbcValidateTestContext* adbc_context
ADBCV_ASSERT_EQ(42, data[4]);
array.release(&array);
- NA_ASSERT_OK(out.get_next(&out, &array));
+ AAS_ASSERT_OK(&out, out.get_next(&out, &array));
ADBCV_ASSERT_EQ(NULL, array.release);
ADBCV_ASSERT_NE(NULL, schema.release);