You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/05/24 22:07:54 UTC
[arrow-adbc] branch main updated: feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new dcc681e feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)
dcc681e is described below
commit dcc681e236af9706f9d9e631efe1c9eb10ab2496
Author: William Ayd <wi...@icloud.com>
AuthorDate: Wed May 24 15:07:49 2023 -0700
feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)
---
c/driver/postgresql/connection.cc | 137 ++++++++++++++++++++++++++++-----
c/driver/postgresql/postgresql_test.cc | 56 +++++++++++++-
2 files changed, 172 insertions(+), 21 deletions(-)
diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index d9a5d13..2cee78e 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -216,6 +216,85 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection,
return BatchToArrayStream(&array, &schema, out, error);
}
+// Appends one db_schemas list element, describing catalog `db_name`, to
+// `db_schemas_list`.
+//
+// Schemas are only listed when `db_name` is the database `conn` is connected
+// to; for any other catalog the element is finished empty. For the connected
+// database, every namespace that does not start with "pg_" and is not
+// "information_schema" is appended (restricted to an exact match on
+// `db_schema` when it is non-NULL), each with a null table list.
+//
+// Returns ADBC_STATUS_OK on success, ADBC_STATUS_INVALID_ARGUMENT if
+// `db_schema` cannot be escaped, ADBC_STATUS_NOT_IMPLEMENTED if a schema is
+// found while `depth` asks for tables or deeper, and ADBC_STATUS_INTERNAL for
+// query, allocation, or array-building failures.
+AdbcStatusCode PostgresConnectionGetSchemasImpl(PGconn* conn, int depth,
+                                                const char* db_name,
+                                                const char* db_schema,
+                                                struct ArrowArray* db_schemas_list,
+                                                struct AdbcError* error) {
+  struct ArrowArray* db_schema_items = db_schemas_list->children[0];
+  struct ArrowArray* db_schema_names = db_schema_items->children[0];
+  struct ArrowArray* db_schema_tables_list = db_schema_items->children[1];
+
+  // inefficient to place here but better localized until we do a class-based refactor
+  std::string curr_db;
+  PqResultHelper curr_db_helper = PqResultHelper{conn, "SELECT current_database()"};
+  if (curr_db_helper.Status() == PGRES_TUPLES_OK) {
+    assert(curr_db_helper.NumRows() == 1);
+    auto curr_iter = curr_db_helper.begin();
+    PqResultRow db_row = *curr_iter;
+    curr_db = std::string(db_row[0].data);
+  } else {
+    SetError(error, "%s%s",
+             "Failed to query current database: ", PQerrorMessage(conn));
+    return ADBC_STATUS_INTERNAL;
+  }
+
+  // postgres only allows you to list schemas for the currently connected db
+  if (curr_db == db_name) {
+    struct StringBuilder query = {0};
+    if (StringBuilderInit(&query, /*initial_size*/ 256)) {
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    const char* stmt =
+        "SELECT nspname FROM pg_catalog.pg_namespace WHERE "
+        "nspname !~ '^pg_' AND nspname <> 'information_schema'";
+
+    if (StringBuilderAppend(&query, "%s", stmt)) {
+      StringBuilderReset(&query);
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    if (db_schema != NULL) {
+      // The filter value is compared against nspname as a string, so it must
+      // be escaped as a literal. PQescapeLiteral returns the value already
+      // wrapped in single quotes; PQescapeIdentifier would return a
+      // double-quoted identifier, which never matches when embedded in '...'
+      // and leaves single quotes in the input unescaped.
+      char* schema_literal = PQescapeLiteral(conn, db_schema, strlen(db_schema));
+      if (schema_literal == NULL) {
+        SetError(error, "%s%s", "Failed to escape schema: ", PQerrorMessage(conn));
+        StringBuilderReset(&query);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+
+      int res = StringBuilderAppend(&query, "%s%s", " AND nspname = ", schema_literal);
+      PQfreemem(schema_literal);
+      if (res) {
+        // Release the partially built query before bailing out.
+        StringBuilderReset(&query);
+        return ADBC_STATUS_INTERNAL;
+      }
+    }
+
+    auto result_helper = PqResultHelper{conn, query.buffer};
+    StringBuilderReset(&query);
+
+    if (result_helper.Status() != PGRES_TUPLES_OK) {
+      // A failed catalog query is a runtime failure, not a missing feature;
+      // this matches the status the catalog query in GetObjects reports.
+      SetError(error, "%s%s", "Failed to query schemas: ", PQerrorMessage(conn));
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    for (PqResultRow row : result_helper) {
+      const char* schema_name = row[0].data;
+      CHECK_NA(INTERNAL,
+               ArrowArrayAppendString(db_schema_names, ArrowCharView(schema_name)),
+               error);
+      if (depth >= ADBC_OBJECT_DEPTH_TABLES) {
+        // Table-level metadata is not available yet.
+        return ADBC_STATUS_NOT_IMPLEMENTED;
+      } else {
+        CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_list, 1), error);
+      }
+      CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_items), error);
+    }
+  }
+
+  // Always close the list element, even when no schema was appended.
+  CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schemas_list), error);
+
+  return ADBC_STATUS_OK;
+}
+
AdbcStatusCode PostgresConnectionGetObjectsImpl(
PGconn* conn, int depth, const char* catalog, const char* db_schema,
const char* table_name, const char** table_types, const char* column_name,
@@ -230,33 +309,51 @@ AdbcStatusCode PostgresConnectionGetObjectsImpl(
struct ArrowArray* catalog_name_col = array->children[0];
struct ArrowArray* catalog_db_schemas_col = array->children[1];
- // TODO: support proper filters
- if (!catalog) {
- struct StringBuilder query = {0};
- if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;
+ struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0];
+ struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0];
+ struct ArrowArray* db_schema_tables_col = catalog_db_schemas_items->children[1];
+
+ struct StringBuilder query = {0};
+ if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;
+
+ if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) {
+ return ADBC_STATUS_INTERNAL;
+ }
+
+ if (catalog != NULL) {
+ char* catalog_name = PQescapeIdentifier(conn, catalog, strlen(catalog));
+ if (catalog_name == NULL) {
+ SetError(error, "%s%s", "Failed to escape catalog: ", PQerrorMessage(conn));
+ StringBuilderReset(&query);
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
- if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) {
+ int res =
+ StringBuilderAppend(&query, "%s%s%s", " WHERE datname = '", catalog_name, "'");
+ PQfreemem(catalog_name);
+ if (res) {
return ADBC_STATUS_INTERNAL;
}
+ }
- PqResultHelper result_helper = PqResultHelper{conn, query.buffer};
- StringBuilderReset(&query);
+ PqResultHelper result_helper = PqResultHelper{conn, query.buffer};
+ StringBuilderReset(&query);
- if (result_helper.Status() == PGRES_TUPLES_OK) {
- for (PqResultRow row : result_helper) {
- const char* db_name = row[0].data;
- CHECK_NA(INTERNAL,
- ArrowArrayAppendString(catalog_name_col, ArrowCharView(db_name)), error);
- if (depth == ADBC_OBJECT_DEPTH_CATALOGS) {
- CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col, 1), error);
- } else {
- return ADBC_STATUS_NOT_IMPLEMENTED;
- }
- CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
+ if (result_helper.Status() == PGRES_TUPLES_OK) {
+ for (PqResultRow row : result_helper) {
+ const char* db_name = row[0].data;
+ CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(db_name)),
+ error);
+ if (depth == ADBC_OBJECT_DEPTH_CATALOGS) {
+ CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col, 1), error);
+ } else {
+ RAISE_ADBC(PostgresConnectionGetSchemasImpl(conn, depth, db_name, db_schema,
+ catalog_db_schemas_col, error));
}
- } else {
- return ADBC_STATUS_NOT_IMPLEMENTED;
+ CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
}
+ } else {
+ return ADBC_STATUS_INTERNAL;
}
CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error,
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 281b83b..c3e8dc7 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -86,7 +86,6 @@ class PostgresConnectionTest : public ::testing::Test,
void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
- void TestMetadataGetObjectsDbSchemas() { GTEST_SKIP() << "Not yet implemented"; }
void TestMetadataGetObjectsTables() { GTEST_SKIP() << "Not yet implemented"; }
void TestMetadataGetObjectsTablesTypes() { GTEST_SKIP() << "Not yet implemented"; }
void TestMetadataGetObjectsColumns() { GTEST_SKIP() << "Not yet implemented"; }
@@ -204,6 +203,61 @@ TEST_F(PostgresConnectionTest, GetObjectsGetCatalogs) {
EXPECT_TRUE(seen_tempalte1_db) << "template1 database does not exist";
}
+// Checks GetObjects at DB_SCHEMAS depth: the "postgres" catalog must carry a
+// non-null schema list that contains "public", and every other catalog must
+// carry an empty schema list.
+TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) {
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+  if (!quirks()->supports_get_objects()) {
+    GTEST_SKIP();
+  }
+
+  adbc_validation::StreamReader reader;
+  ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_DB_SCHEMAS, nullptr,
+                                       nullptr, nullptr, nullptr, nullptr,
+                                       &reader.stream.value, &error),
+              IsOkStatus(&error));
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NE(nullptr, reader.array->release);
+  ASSERT_GT(reader.array->length, 0);
+
+  // Column 1 of the result is the per-catalog list of db_schema structs.
+  struct ArrowArrayView* schemas_list_view = reader.array_view->children[1];
+  struct ArrowArrayView* schemas_struct_view = schemas_list_view->children[0];
+
+  bool found_public_schema = false;
+
+  do {
+    for (int64_t row = 0; row < reader.array->length; row++) {
+      const ArrowStringView catalog_view =
+          ArrowArrayViewGetStringUnsafe(reader.array_view->children[0], row);
+      const std::string catalog_name(catalog_view.data, catalog_view.size_bytes);
+
+      const auto first_schema = ArrowArrayViewListChildOffset(schemas_list_view, row);
+      const auto past_last_schema =
+          ArrowArrayViewListChildOffset(schemas_list_view, row + 1);
+
+      if (catalog_name != "postgres") {
+        // Catalogs other than "postgres" are expected to list no schemas.
+        ASSERT_EQ(first_schema, past_last_schema);
+        continue;
+      }
+
+      ASSERT_FALSE(ArrowArrayViewIsNull(schemas_list_view, row));
+      for (auto schema_idx = first_schema; schema_idx < past_last_schema;
+           schema_idx++) {
+        const ArrowStringView schema_view = ArrowArrayViewGetStringUnsafe(
+            schemas_struct_view->children[0], schema_idx);
+        if (std::string(schema_view.data, schema_view.size_bytes) == "public") {
+          found_public_schema = true;
+        }
+      }
+    }
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+  } while (reader.array->release);
+
+  ASSERT_TRUE(found_public_schema) << "public schema does not exist";
+}
+
TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();