You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/05/24 22:07:54 UTC

[arrow-adbc] branch main updated: feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new dcc681e  feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)
dcc681e is described below

commit dcc681e236af9706f9d9e631efe1c9eb10ab2496
Author: William Ayd <wi...@icloud.com>
AuthorDate: Wed May 24 15:07:49 2023 -0700

    feat(c/driver/postgresql): Implement GetObjectsDbSchemas for Postgres (#679)
---
 c/driver/postgresql/connection.cc      | 137 ++++++++++++++++++++++++++++-----
 c/driver/postgresql/postgresql_test.cc |  56 +++++++++++++-
 2 files changed, 172 insertions(+), 21 deletions(-)

diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index d9a5d13..2cee78e 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -216,6 +216,103 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection,
   return BatchToArrayStream(&array, &schema, out, error);
 }
 
+// Append the db_schemas list for catalog `db_name` as one element of
+// `db_schemas_list`.
+//
+// PostgreSQL only lets us list the schemas of the database the connection is
+// attached to, so the list is populated only when `db_name` equals
+// current_database(); for any other catalog an empty list is appended. If
+// `db_schema` is non-NULL, only the schema with exactly that name is listed
+// (TODO: ADBC defines this argument as a search pattern, i.e. LIKE semantics).
+//
+// Only ADBC_OBJECT_DEPTH_DB_SCHEMAS is supported: each schema's table list is
+// appended as null. Any other depth -- including ADBC_OBJECT_DEPTH_ALL, whose
+// value (0) is below ADBC_OBJECT_DEPTH_TABLES -- yields
+// ADBC_STATUS_NOT_IMPLEMENTED as soon as a schema is found.
+AdbcStatusCode PostgresConnectionGetSchemasImpl(PGconn* conn, int depth,
+                                                const char* db_name,
+                                                const char* db_schema,
+                                                struct ArrowArray* db_schemas_list,
+                                                struct AdbcError* error) {
+  struct ArrowArray* db_schema_items = db_schemas_list->children[0];
+  struct ArrowArray* db_schema_names = db_schema_items->children[0];
+  struct ArrowArray* db_schema_tables_list = db_schema_items->children[1];
+
+  // inefficient to place here but better localized until we do a class-based refactor
+  std::string curr_db;
+  PqResultHelper curr_db_helper = PqResultHelper{conn, "SELECT current_database()"};
+  if (curr_db_helper.Status() != PGRES_TUPLES_OK) {
+    SetError(error, "%s%s", "Failed to query current database: ", PQerrorMessage(conn));
+    return ADBC_STATUS_INTERNAL;
+  }
+  assert(curr_db_helper.NumRows() == 1);
+  auto curr_iter = curr_db_helper.begin();
+  PqResultRow db_row = *curr_iter;
+  curr_db = std::string(db_row[0].data);
+
+  // postgres only allows you to list schemas for the currently connected db
+  if (strcmp(db_name, curr_db.c_str()) == 0) {
+    struct StringBuilder query = {0};
+    if (StringBuilderInit(&query, /*initial_size*/ 256)) {
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    const char* stmt =
+        "SELECT nspname FROM pg_catalog.pg_namespace WHERE "
+        "nspname !~ '^pg_' AND nspname <> 'information_schema'";
+
+    if (StringBuilderAppend(&query, "%s", stmt)) {
+      StringBuilderReset(&query);
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    if (db_schema != NULL) {
+      // PQescapeLiteral returns a complete single-quoted SQL literal with any
+      // embedded quotes escaped. PQescapeIdentifier must not be used here: it
+      // produces a double-quoted identifier (so '"name"' never matches) and
+      // leaves single quotes unescaped, which permits SQL injection.
+      char* schema_literal = PQescapeLiteral(conn, db_schema, strlen(db_schema));
+      if (schema_literal == NULL) {
+        SetError(error, "%s%s", "Failed to escape schema: ", PQerrorMessage(conn));
+        StringBuilderReset(&query);
+        return ADBC_STATUS_INVALID_ARGUMENT;
+      }
+
+      int res = StringBuilderAppend(&query, "%s%s", " AND nspname = ", schema_literal);
+      PQfreemem(schema_literal);
+      if (res) {
+        StringBuilderReset(&query);
+        return ADBC_STATUS_INTERNAL;
+      }
+    }
+
+    auto result_helper = PqResultHelper{conn, query.buffer};
+    StringBuilderReset(&query);
+
+    if (result_helper.Status() != PGRES_TUPLES_OK) {
+      SetError(error, "%s%s", "Failed to list schemas: ", PQerrorMessage(conn));
+      return ADBC_STATUS_INTERNAL;
+    }
+
+    for (PqResultRow row : result_helper) {
+      const char* schema_name = row[0].data;
+      CHECK_NA(INTERNAL,
+               ArrowArrayAppendString(db_schema_names, ArrowCharView(schema_name)),
+               error);
+      if (depth != ADBC_OBJECT_DEPTH_DB_SCHEMAS) {
+        // Listing tables (and columns) is not implemented yet.
+        return ADBC_STATUS_NOT_IMPLEMENTED;
+      }
+      CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_list, 1), error);
+      CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_items), error);
+    }
+  }
+
+  CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schemas_list), error);
+
+  return ADBC_STATUS_OK;
+}
+
 AdbcStatusCode PostgresConnectionGetObjectsImpl(
     PGconn* conn, int depth, const char* catalog, const char* db_schema,
     const char* table_name, const char** table_types, const char* column_name,
@@ -230,33 +309,57 @@ AdbcStatusCode PostgresConnectionGetObjectsImpl(
   struct ArrowArray* catalog_name_col = array->children[0];
   struct ArrowArray* catalog_db_schemas_col = array->children[1];
 
-  // TODO: support proper filters
-  if (!catalog) {
-    struct StringBuilder query = {0};
-    if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;
+  struct ArrowArray* catalog_db_schemas_items = catalog_db_schemas_col->children[0];
+  struct ArrowArray* db_schema_name_col = catalog_db_schemas_items->children[0];
+  struct ArrowArray* db_schema_tables_col = catalog_db_schemas_items->children[1];
+
+  struct StringBuilder query = {0};
+  if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL;
+
+  if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) {
+    StringBuilderReset(&query);
+    return ADBC_STATUS_INTERNAL;
+  }
+
+  if (catalog != NULL) {
+    // PQescapeLiteral returns a complete single-quoted SQL literal with any
+    // embedded quotes escaped. PQescapeIdentifier must not be used here: it
+    // produces a double-quoted identifier (so '"name"' never matches) and
+    // leaves single quotes unescaped, which permits SQL injection.
+    char* catalog_literal = PQescapeLiteral(conn, catalog, strlen(catalog));
+    if (catalog_literal == NULL) {
+      SetError(error, "%s%s", "Failed to escape catalog: ", PQerrorMessage(conn));
+      StringBuilderReset(&query);
+      return ADBC_STATUS_INVALID_ARGUMENT;
+    }
 
-    if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) {
+    int res = StringBuilderAppend(&query, "%s%s", " WHERE datname = ", catalog_literal);
+    PQfreemem(catalog_literal);
+    if (res) {
+      StringBuilderReset(&query);
       return ADBC_STATUS_INTERNAL;
     }
+  }
 
-    PqResultHelper result_helper = PqResultHelper{conn, query.buffer};
-    StringBuilderReset(&query);
+  PqResultHelper result_helper = PqResultHelper{conn, query.buffer};
+  StringBuilderReset(&query);
 
-    if (result_helper.Status() == PGRES_TUPLES_OK) {
-      for (PqResultRow row : result_helper) {
-        const char* db_name = row[0].data;
-        CHECK_NA(INTERNAL,
-                 ArrowArrayAppendString(catalog_name_col, ArrowCharView(db_name)), error);
-        if (depth == ADBC_OBJECT_DEPTH_CATALOGS) {
-          CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col, 1), error);
-        } else {
-          return ADBC_STATUS_NOT_IMPLEMENTED;
-        }
-        CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
+  if (result_helper.Status() == PGRES_TUPLES_OK) {
+    for (PqResultRow row : result_helper) {
+      const char* db_name = row[0].data;
+      CHECK_NA(INTERNAL, ArrowArrayAppendString(catalog_name_col, ArrowCharView(db_name)),
+               error);
+      if (depth == ADBC_OBJECT_DEPTH_CATALOGS) {
+        CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col, 1), error);
+      } else {
+        RAISE_ADBC(PostgresConnectionGetSchemasImpl(conn, depth, db_name, db_schema,
+                                                    catalog_db_schemas_col, error));
       }
-    } else {
-      return ADBC_STATUS_NOT_IMPLEMENTED;
+      CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
     }
+  } else {
+    SetError(error, "%s%s", "Failed to list catalogs: ", PQerrorMessage(conn));
+    return ADBC_STATUS_INTERNAL;
   }
 
   CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, &na_error), &na_error,
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 281b83b..c3e8dc7 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -86,7 +86,6 @@ class PostgresConnectionTest : public ::testing::Test,
   void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
   void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
 
-  void TestMetadataGetObjectsDbSchemas() { GTEST_SKIP() << "Not yet implemented"; }
   void TestMetadataGetObjectsTables() { GTEST_SKIP() << "Not yet implemented"; }
   void TestMetadataGetObjectsTablesTypes() { GTEST_SKIP() << "Not yet implemented"; }
   void TestMetadataGetObjectsColumns() { GTEST_SKIP() << "Not yet implemented"; }
@@ -204,6 +203,61 @@ TEST_F(PostgresConnectionTest, GetObjectsGetCatalogs) {
   EXPECT_TRUE(seen_tempalte1_db) << "template1 database does not exist";
 }
 
+// Checks GetObjects at ADBC_OBJECT_DEPTH_DB_SCHEMAS. The test assumes it is
+// connected to the "postgres" database: that catalog must list a "public"
+// schema, and every other catalog must report an empty schema list.
+TEST_F(PostgresConnectionTest, GetObjectsGetDbSchemas) {
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+  if (!quirks()->supports_get_objects()) {
+    GTEST_SKIP();
+  }
+
+  adbc_validation::StreamReader reader;
+  ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_DB_SCHEMAS, nullptr,
+                                       nullptr, nullptr, nullptr, nullptr,
+                                       &reader.stream.value, &error),
+              IsOkStatus(&error));
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NE(nullptr, reader.array->release);
+  ASSERT_GT(reader.array->length, 0);
+
+  // Column 0 holds catalog names; column 1 holds, per catalog, a list of structs
+  // whose first field is the schema name.
+  struct ArrowArrayView* catalog_names = reader.array_view->children[0];
+  struct ArrowArrayView* schema_lists = reader.array_view->children[1];
+  struct ArrowArrayView* schema_names = schema_lists->children[0]->children[0];
+
+  bool found_public = false;
+  do {
+    for (int64_t row = 0; row < reader.array->length; row++) {
+      ArrowStringView catalog = ArrowArrayViewGetStringUnsafe(catalog_names, row);
+      const std::string catalog_str(catalog.data, catalog.size_bytes);
+      const int64_t first_schema = ArrowArrayViewListChildOffset(schema_lists, row);
+      const int64_t end_schema = ArrowArrayViewListChildOffset(schema_lists, row + 1);
+
+      if (catalog_str != "postgres") {
+        // Schemas are only listed for the connected database.
+        ASSERT_EQ(first_schema, end_schema);
+        continue;
+      }
+
+      ASSERT_FALSE(ArrowArrayViewIsNull(schema_lists, row));
+      for (int64_t i = first_schema; i < end_schema; i++) {
+        ArrowStringView name = ArrowArrayViewGetStringUnsafe(schema_names, i);
+        if (std::string(name.data, name.size_bytes) == "public") {
+          found_public = true;
+        }
+      }
+    }
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+  } while (reader.array->release);
+
+  ASSERT_TRUE(found_public) << "public schema does not exist";
+}
+
 TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) {
   if (!quirks()->supports_bulk_ingest()) {
     GTEST_SKIP();