You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/16 21:14:51 UTC

[arrow-adbc] branch main updated: feat(c/driver/postgresql): Implement GetObjects with table_types argument (#799)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 48222978 feat(c/driver/postgresql): Implement GetObjects with table_types argument (#799)
48222978 is described below

commit 4822297812711577e76af410e22bebef580a45e8
Author: William Ayd <wi...@icloud.com>
AuthorDate: Fri Jun 16 14:14:46 2023 -0700

    feat(c/driver/postgresql): Implement GetObjects with table_types argument (#799)
---
 c/driver/postgresql/connection.cc      | 75 +++++++++++++++++++---------
 c/driver/postgresql/postgresql_test.cc | 89 +++++++++++++++++++++++++++++++++-
 c/validation/adbc_validation.h         |  7 +++
 3 files changed, 146 insertions(+), 25 deletions(-)

diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc
index 22451218..611cd513 100644
--- a/c/driver/postgresql/connection.cc
+++ b/c/driver/postgresql/connection.cc
@@ -23,6 +23,7 @@
 #include <memory>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -39,6 +40,10 @@ static const uint32_t kSupportedInfoCodes[] = {
     ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION,
 };
 
+static const std::unordered_map<std::string, std::string> kPgTableTypes = {
+    {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"},
+    {"foreign_table", "f"}, {"partitioned_table", "p"}};  // iteration order unspecified
+
 struct PqRecord {
   const char* data;
   const int len;
@@ -356,7 +361,7 @@ class PqGetObjectsHelper {
       return ADBC_STATUS_INTERNAL;
     }
 
-    if (table_name_ != NULL) {
+    if (table_name_ != nullptr) {
       if (StringBuilderAppend(&query, "%s", " AND c.relname LIKE $2")) {
         StringBuilderReset(&query);
         return ADBC_STATUS_INTERNAL;
@@ -365,6 +370,45 @@ class PqGetObjectsHelper {
       params.push_back(std::string(table_name_));
     }
 
+    if (table_types_ != nullptr) {
+      std::vector<std::string> table_type_filter;
+      const char** table_types = table_types_;
+      while (*table_types != NULL) {
+        auto table_type_str = std::string(*table_types);
+        if (auto search = kPgTableTypes.find(table_type_str);
+            search != kPgTableTypes.end()) {
+          table_type_filter.push_back(search->second);
+        }
+        table_types++;
+      }
+
+      if (!table_type_filter.empty()) {
+        std::ostringstream oss;
+        bool first = true;
+        oss << "(";
+        for (const auto& str : table_type_filter) {
+          if (!first) {
+            oss << ", ";
+          }
+          oss << "'" << str << "'";
+          first = false;
+        }
+        oss << ")";
+
+        if (StringBuilderAppend(&query, "%s%s", " AND c.relkind IN ",
+                                oss.str().c_str())) {
+          StringBuilderReset(&query);
+          return ADBC_STATUS_INTERNAL;
+        }
+      } else {
+        // no matching table type means no records should come back
+        if (StringBuilderAppend(&query, "%s", " AND false")) {
+          StringBuilderReset(&query);
+          return ADBC_STATUS_INTERNAL;
+        }
+      }
+    }
+
     auto result_helper = PqResultHelper{conn_, query.buffer, params, error_};
     StringBuilderReset(&query);
 
@@ -889,28 +933,13 @@ AdbcStatusCode PostgresConnectionGetTableTypesImpl(struct ArrowSchema* schema,
   CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), NULL), error);
   CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error);
 
-  CHECK_NA(INTERNAL, ArrowArrayAppendString(array->children[0], ArrowCharView("table")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
-  CHECK_NA(INTERNAL,
-           ArrowArrayAppendString(array->children[0], ArrowCharView("toast_table")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
-  CHECK_NA(INTERNAL, ArrowArrayAppendString(array->children[0], ArrowCharView("view")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
-  CHECK_NA(INTERNAL,
-           ArrowArrayAppendString(array->children[0], ArrowCharView("materialized_view")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
-  CHECK_NA(INTERNAL,
-           ArrowArrayAppendString(array->children[0], ArrowCharView("foreign_table")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
-  CHECK_NA(INTERNAL,
-           ArrowArrayAppendString(array->children[0], ArrowCharView("partitioned_table")),
-           error);
-  CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
+  for (auto const& table_type : kPgTableTypes) {
+    CHECK_NA(INTERNAL,
+             ArrowArrayAppendString(array->children[0],
+                                    ArrowCharView(table_type.first.c_str())),
+             error);
+    CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error);
+  }
 
   CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error);
 
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index eed03fa9..429d59dd 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -63,6 +63,24 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
     return status;
   }
 
+  AdbcStatusCode DropView(struct AdbcConnection* connection, const std::string& name,
+                          struct AdbcError* error) const override {
+    // Drops the view if present so each test starts from a clean state.
+    struct AdbcStatement stmt;
+    std::memset(&stmt, 0, sizeof(stmt));
+    AdbcStatusCode rc = AdbcStatementNew(connection, &stmt, error);
+    if (rc != ADBC_STATUS_OK) return rc;
+    const std::string sql = "DROP VIEW IF EXISTS " + name;
+    rc = AdbcStatementSetSqlQuery(&stmt, sql.c_str(), error);
+    if (rc == ADBC_STATUS_OK) {
+      rc = AdbcStatementExecuteQuery(&stmt, nullptr, nullptr, error);
+    }
+    // The statement exists here, so always release it. rc carries the query
+    // outcome; a failure to release is deliberately ignored.
+    std::ignore = AdbcStatementRelease(&stmt, error);
+    return rc;
+  }
+
   std::string BindParameter(int index) const override {
     return "$" + std::to_string(index + 1);
   }
@@ -97,8 +115,6 @@ class PostgresConnectionTest : public ::testing::Test,
   void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); }
   void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
 
-  void TestMetadataGetObjectsTablesTypes() { GTEST_SKIP() << "Not yet implemented"; }
-
  protected:
   PostgresQuirks quirks_;
 };
@@ -421,6 +437,75 @@ TEST_F(PostgresConnectionTest, GetObjectsGetAllFindsForeignKey) {
   }
 }
 
+// Filtering GetObjects to "view" must return the view but not its base table.
+TEST_F(PostgresConnectionTest, GetObjectsTableTypesFilter) {
+  ASSERT_THAT(AdbcConnectionNew(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), IsOkStatus(&error));
+
+  if (!quirks()->supports_get_objects()) {
+    GTEST_SKIP();
+  }
+
+  // Reset state. The view selects from the table, so it is dropped first.
+  ASSERT_THAT(quirks()->DropView(&connection, "adbc_table_types_view_test", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(quirks()->DropTable(&connection, "adbc_table_types_table_test", &error),
+              IsOkStatus(&error));
+
+  {
+    adbc_validation::Handle<struct AdbcStatement> statement;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+
+    ASSERT_THAT(
+        AdbcStatementSetSqlQuery(
+            &statement.value,
+            "CREATE TABLE adbc_table_types_table_test (id1 INT, id2 INT)", &error),
+        IsOkStatus(&error));
+
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
+                IsOkStatus(&error));
+  }
+
+  {
+    adbc_validation::Handle<struct AdbcStatement> statement;
+    ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value,
+                                         "CREATE VIEW adbc_table_types_view_test AS ( "
+                                         "SELECT * FROM adbc_table_types_table_test)",
+                                         &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
+                IsOkStatus(&error));
+  }
+
+  adbc_validation::StreamReader reader;
+  // GetObjects expects a nullptr-terminated array of table type names.
+  std::vector<const char*> table_types = {"view", nullptr};
+  ASSERT_THAT(AdbcConnectionGetObjects(&connection, ADBC_OBJECT_DEPTH_ALL, nullptr,
+                                       nullptr, nullptr, table_types.data(), nullptr,
+                                       &reader.stream.value, &error),
+              IsOkStatus(&error));
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_NO_FATAL_FAILURE(reader.Next());
+  ASSERT_NE(nullptr, reader.array->release);
+  ASSERT_GT(reader.array->length, 0);
+
+  auto get_objects_data = adbc_validation::GetObjectsReader{&reader.array_view.value};
+  ASSERT_NE(*get_objects_data, nullptr)
+      << "could not initialize the AdbcGetObjectsData object";
+
+  // NOTE(review): assumes the test database (catalog) is named "postgres".
+  struct AdbcGetObjectsTable* table = AdbcGetObjectsDataGetTableByName(
+      *get_objects_data, "postgres", "public", "adbc_table_types_table_test");
+  ASSERT_EQ(table, nullptr) << "unexpected table adbc_table_types_table_test found";
+
+  struct AdbcGetObjectsTable* view = AdbcGetObjectsDataGetTableByName(
+      *get_objects_data, "postgres", "public", "adbc_table_types_view_test");
+  ASSERT_NE(view, nullptr) << "did not find view adbc_table_types_view_test";
+}
+
 TEST_F(PostgresConnectionTest, MetadataGetTableSchemaInjection) {
   if (!quirks()->supports_bulk_ingest()) {
     GTEST_SKIP();
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index ffff035e..4e4251bf 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -47,6 +47,13 @@ class DriverQuirks {
     return ADBC_STATUS_OK;
   }
 
+  /// \brief Drop the named view. Used by tests to reset state; this default
+  /// returns ADBC_STATUS_NOT_IMPLEMENTED, so drivers that support views override it.
+  virtual AdbcStatusCode DropView(struct AdbcConnection*, const std::string& /*name*/,
+                                  struct AdbcError*) const {
+    return ADBC_STATUS_NOT_IMPLEMENTED;
+  }
+
   virtual AdbcStatusCode EnsureSampleTable(struct AdbcConnection* connection,
                                            const std::string& name,
                                            struct AdbcError* error) const;