You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/11/09 15:56:12 UTC

(arrow-adbc) branch main updated: feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d6a2451 feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)
9d6a2451 is described below

commit 9d6a24514cd9a9df04568a5d7b2d92edb9c08130
Author: Dewey Dunnington <de...@voltrondata.com>
AuthorDate: Thu Nov 9 11:56:06 2023 -0400

    feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)
    
    This PR adds the ability for the Postgres driver to ingest
    dictionary-encoded arrays. This shows up in R because factors are
    relatively common and are encoded by default in Arrow as
    dictionary-encoded strings for performance reasons.
    
    Reprex in R:
    
    ``` r
    library(adbcdrivermanager)
    
    con <- adbcpostgresql::adbcpostgresql() |>
      adbc_database_init(uri = Sys.getenv("ADBC_POSTGRESQL_TEST_URI")) |>
      adbc_connection_init()
    
    df <- data.frame(x = letters, y = factor(letters))
    write_adbc(df, con, "some_table")
    #> Error in adbc_statement_execute_query(stmt): [libpq] Failed to create table: ERROR:  relation "some_table" already exists
    #>
    #> Query was: CREATE TABLE "public" . "some_table" ("x" TEXT, "y" TEXT)
    read_adbc(con, "SELECT * from some_table") |>
      as.data.frame() |>
      str()
    #> 'data.frame':    26 obs. of  2 variables:
    #>  $ x: chr  "a" "b" "c" "d" ...
    #>  $ y: chr  "a" "b" "c" "d" ...
    ```
    
    <sup>Created on 2023-11-09 with [reprex
    v2.0.2](https://reprex.tidyverse.org)</sup>
    
    There is probably some opportunity to consolidate some of the code that
    currently lives in the `BindStream` into the `PostgresType` and/or
    `PostgresTypeResolver`. I'm happy to poke away at that at some point,
    but in the meantime it seemed like it wasn't too onerous to tack on
    dictionary support here.
---
 c/driver/postgresql/postgres_copy_reader.h | 116 ++++++++++++++++++-----------
 c/driver/postgresql/postgres_type.h        |   3 +
 c/driver/postgresql/postgres_type_test.cc  |   9 +++
 c/driver/postgresql/postgresql_test.cc     |   1 -
 c/driver/postgresql/statement.cc           |  55 +++++++++++++-
 5 files changed, 138 insertions(+), 46 deletions(-)

diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 09436351..66a31418 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -1231,13 +1231,13 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
     switch (TU) {
       case NANOARROW_TIME_UNIT_SECOND:
         if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
-             raw_value >= kMinSafeSecondsToMicros)) {
+                             raw_value >= kMinSafeSecondsToMicros)) {
           value = raw_value * 1000000;
         }
         break;
       case NANOARROW_TIME_UNIT_MILLI:
         if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
-             raw_value >= kMinSafeMillisToMicros)) {
+                             raw_value >= kMinSafeMillisToMicros)) {
           value = raw_value * 1000;
         }
         break;
@@ -1251,11 +1251,8 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
 
     if (!overflow_safe) {
       ArrowErrorSet(
-          error,
-          "Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
-          index,
-          raw_value,
-          TU);
+          error, "Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
+          index, raw_value, TU);
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
 
@@ -1273,8 +1270,7 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
 class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
  public:
   ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
-    struct ArrowBufferView buffer_view =
-        ArrowArrayViewGetBytesUnsafe(array_view_, index);
+    struct ArrowBufferView buffer_view = ArrowArrayViewGetBytesUnsafe(array_view_, index);
     NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, buffer_view.size_bytes, error));
     NANOARROW_RETURN_NOT_OK(
         ArrowBufferAppend(buffer, buffer_view.data.as_uint8, buffer_view.size_bytes));
@@ -1283,6 +1279,26 @@ class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
   }
 };
 
+class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter {
+ public:  // COPY field writer for dictionary-encoded string/binary columns
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
+    int64_t dict_index = ArrowArrayViewGetIntUnsafe(array_view_, index);  // row -> dictionary slot
+    if (ArrowArrayViewIsNull(array_view_->dictionary, dict_index)) {  // null dictionary value
+      constexpr int32_t field_size_bytes = -1;  // length -1 = NULL marker in COPY binary format
+      NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
+    } else {
+      struct ArrowBufferView buffer_view =
+          ArrowArrayViewGetBytesUnsafe(array_view_->dictionary, dict_index);
+      NANOARROW_RETURN_NOT_OK(
+          WriteChecked<int32_t>(buffer, buffer_view.size_bytes, error));  // NOTE(review): int64 size written as int32 length - confirm large values are rejected
+      NANOARROW_RETURN_NOT_OK(
+          ArrowBufferAppend(buffer, buffer_view.data.as_uint8, buffer_view.size_bytes));
+    }
+    // NOTE(review): only dictionary-level nulls are checked here; presumably the caller writes NULL for null index slots before calling Write() - confirm.
+    return ADBC_STATUS_OK;
+  }
+};
+
 template <enum ArrowTimeUnit TU>
 class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
  public:
@@ -1297,13 +1313,13 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
     switch (TU) {
       case NANOARROW_TIME_UNIT_SECOND:
         if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
-             raw_value >= kMinSafeSecondsToMicros)) {
+                             raw_value >= kMinSafeSecondsToMicros)) {
           value = raw_value * 1000000;
         }
         break;
       case NANOARROW_TIME_UNIT_MILLI:
         if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
-             raw_value >= kMinSafeMillisToMicros)) {
+                             raw_value >= kMinSafeMillisToMicros)) {
           value = raw_value * 1000;
         }
         break;
@@ -1316,12 +1332,10 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
     }
 
     if (!overflow_safe) {
-      ArrowErrorSet(
-          error,
-          "Row %" PRId64 " timestamp value %" PRId64 " with unit %d would overflow",
-          index,
-          raw_value,
-          TU);
+      ArrowErrorSet(error,
+                    "Row %" PRId64 " timestamp value %" PRId64
+                    " with unit %d would overflow",
+                    index, raw_value, TU);
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
 
@@ -1334,9 +1348,12 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
   }
 };
 
-static inline ArrowErrorCode MakeCopyFieldWriter(
-    const struct ArrowSchemaView& schema_view, PostgresCopyFieldWriter** out,
-    ArrowError* error) {
+static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
+                                                 PostgresCopyFieldWriter** out,
+                                                 ArrowError* error) {
+  struct ArrowSchemaView schema_view;
+  NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));
+
   switch (schema_view.type) {
     case NANOARROW_TYPE_BOOL:
       *out = new PostgresCopyBooleanFieldWriter();
@@ -1368,21 +1385,21 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
       *out = new PostgresCopyBinaryFieldWriter();
       return NANOARROW_OK;
     case NANOARROW_TYPE_TIMESTAMP: {
-        switch (schema_view.time_unit) {
-          case NANOARROW_TIME_UNIT_NANO:
-            *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
-            break;
-          case NANOARROW_TIME_UNIT_MILLI:
-            *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
-            break;
-          case NANOARROW_TIME_UNIT_MICRO:
-            *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
-            break;
-          case NANOARROW_TIME_UNIT_SECOND:
-            *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
-            break;
-        }
-        return NANOARROW_OK;
+      switch (schema_view.time_unit) {
+        case NANOARROW_TIME_UNIT_NANO:
+          *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
+          break;
+        case NANOARROW_TIME_UNIT_MILLI:
+          *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
+          break;
+        case NANOARROW_TIME_UNIT_MICRO:
+          *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
+          break;
+        case NANOARROW_TIME_UNIT_SECOND:
+          *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
+          break;
+      }
+      return NANOARROW_OK;
     }
     case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
       *out = new PostgresCopyIntervalFieldWriter();
@@ -1405,10 +1422,27 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
       }
       return NANOARROW_OK;
     }
+    case NANOARROW_TYPE_DICTIONARY: {
+      struct ArrowSchemaView value_view;
+      NANOARROW_RETURN_NOT_OK(
+          ArrowSchemaViewInit(&value_view, schema->dictionary, error));
+      switch (value_view.type) {
+        case NANOARROW_TYPE_BINARY:
+        case NANOARROW_TYPE_STRING:
+        case NANOARROW_TYPE_LARGE_BINARY:
+        case NANOARROW_TYPE_LARGE_STRING:
+          *out = new PostgresCopyBinaryDictFieldWriter();
+          return NANOARROW_OK;
+        default:
+          break;
+      }
+    }
     default:
-      ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
-      return EINVAL;
+      break;
   }
+
+  ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
+  return EINVAL;
 }
 
 class PostgresCopyStreamWriter {
@@ -1450,13 +1484,9 @@ class PostgresCopyStreamWriter {
     }
 
     for (int64_t i = 0; i < schema_->n_children; i++) {
-      struct ArrowSchemaView schema_view;
-      if (ArrowSchemaViewInit(&schema_view, schema_->children[i], error) !=
-          NANOARROW_OK) {
-        return ADBC_STATUS_INTERNAL;
-      }
       PostgresCopyFieldWriter* child_writer = nullptr;
-      NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_view, &child_writer, error));
+      NANOARROW_RETURN_NOT_OK(
+          MakeCopyFieldWriter(schema_->children[i], &child_writer, error));
       root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
     }
 
diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h
index 1dfcbe2b..9d5f2e64 100644
--- a/c/driver/postgresql/postgres_type.h
+++ b/c/driver/postgresql/postgres_type.h
@@ -521,6 +521,9 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
           PostgresType::FromSchema(resolver, schema->children[0], &child, error));
       return resolver.FindArray(child.oid(), out, error);
     }
+    case NANOARROW_TYPE_DICTIONARY:
+      // Dictionary arrays always resolve to the dictionary type when binding or ingesting
+      return PostgresType::FromSchema(resolver, schema->dictionary, out, error);
 
     default:
       ArrowErrorSet(error, "Can't map Arrow type '%s' to Postgres type",
diff --git a/c/driver/postgresql/postgres_type_test.cc b/c/driver/postgresql/postgres_type_test.cc
index 02e8da93..faf9eb07 100644
--- a/c/driver/postgresql/postgres_type_test.cc
+++ b/c/driver/postgresql/postgres_type_test.cc
@@ -279,6 +279,15 @@ TEST(PostgresTypeTest, PostgresTypeFromSchema) {
   EXPECT_EQ(type.child(0).type_id(), PostgresTypeId::kBool);
   schema.reset();
 
+  ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INT64), NANOARROW_OK);
+  ASSERT_EQ(ArrowSchemaAllocateDictionary(schema.get()), NANOARROW_OK);
+  ASSERT_EQ(ArrowSchemaInitFromType(schema->dictionary, NANOARROW_TYPE_STRING),
+            NANOARROW_OK);
+  EXPECT_EQ(PostgresType::FromSchema(resolver, schema.get(), &type, nullptr),
+            NANOARROW_OK);
+  EXPECT_EQ(type.type_id(), PostgresTypeId::kText);
+  schema.reset();
+
   ArrowError error;
   ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO),
             NANOARROW_OK);
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 343c729d..e1f95a49 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -820,7 +820,6 @@ class PostgresStatementTest : public ::testing::Test,
   void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
   void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
-  void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }
 
   void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
   void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index eac7eded..6f1bb9bd 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -210,6 +210,33 @@ struct BindStream {
           type_id = PostgresTypeId::kInterval;
           param_lengths[i] = 16;
           break;
+        case ArrowType::NANOARROW_TYPE_DICTIONARY: {
+          struct ArrowSchemaView value_view;
+          CHECK_NA(INTERNAL,
+                   ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary,
+                                       nullptr),
+                   error);
+          switch (value_view.type) {
+            case NANOARROW_TYPE_BINARY:
+            case NANOARROW_TYPE_LARGE_BINARY:
+              type_id = PostgresTypeId::kBytea;
+              param_lengths[i] = 0;
+              break;
+            case NANOARROW_TYPE_STRING:
+            case NANOARROW_TYPE_LARGE_STRING:
+              type_id = PostgresTypeId::kText;
+              param_lengths[i] = 0;
+              break;
+            default:
+              SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
+                       static_cast<uint64_t>(i + 1), " ('",
+                       bind_schema->children[i]->name,
+                       "') has unsupported dictionary value parameter type ",
+                       ArrowTypeString(value_view.type));
+              return ADBC_STATUS_NOT_IMPLEMENTED;
+          }
+          break;
+        }
         default:
           SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
                    static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
@@ -567,8 +594,8 @@ struct BindStream {
       }
 
       ArrowBuffer buffer = writer.WriteBuffer();
-      if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data),
-                        buffer.size_bytes) <= 0) {
+      if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data), buffer.size_bytes) <=
+          0) {
         SetError(error, "Error writing tuple field data: %s", PQerrorMessage(conn));
         return ADBC_STATUS_IO;
       }
@@ -1029,6 +1056,30 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
       case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
         create += " INTERVAL";
         break;
+      case ArrowType::NANOARROW_TYPE_DICTIONARY: {
+        struct ArrowSchemaView value_view;
+        CHECK_NA(INTERNAL,
+                 ArrowSchemaViewInit(&value_view, source_schema.children[i]->dictionary,
+                                     nullptr),
+                 error);
+        switch (value_view.type) {
+          case NANOARROW_TYPE_BINARY:
+          case NANOARROW_TYPE_LARGE_BINARY:
+            create += " BYTEA";
+            break;
+          case NANOARROW_TYPE_STRING:
+          case NANOARROW_TYPE_LARGE_STRING:
+            create += " TEXT";
+            break;
+          default:
+            SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
+                     static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
+                     "') has unsupported dictionary value type for ingestion ",
+                     ArrowTypeString(value_view.type));
+            return ADBC_STATUS_NOT_IMPLEMENTED;
+        }
+        break;
+      }
       default:
         SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
                  static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,