You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/11/09 15:56:12 UTC
(arrow-adbc) branch main updated: feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9d6a2451 feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)
9d6a2451 is described below
commit 9d6a24514cd9a9df04568a5d7b2d92edb9c08130
Author: Dewey Dunnington <de...@voltrondata.com>
AuthorDate: Thu Nov 9 11:56:06 2023 -0400
feat(c/driver/postgresql): Accept bulk ingest of dictionary-encoded strings/binary (#1275)
This PR adds the ability for the Postgres driver to ingest
dictionary-encoded arrays. This comes up in R because factors are
relatively common and, for performance reasons, are encoded by default
in Arrow as dictionary-encoded strings.
Reprex in R:
``` r
library(adbcdrivermanager)
con <- adbcpostgresql::adbcpostgresql() |>
adbc_database_init(uri = Sys.getenv("ADBC_POSTGRESQL_TEST_URI")) |>
adbc_connection_init()
df <- data.frame(x = letters, y = factor(letters))
write_adbc(df, con, "some_table")
#> Error in adbc_statement_execute_query(stmt): [libpq] Failed to create table: ERROR: relation "some_table" already exists
#>
#> Query was: CREATE TABLE "public" . "some_table" ("x" TEXT, "y" TEXT)
read_adbc(con, "SELECT * from some_table") |>
as.data.frame() |>
str()
#> 'data.frame': 26 obs. of 2 variables:
#> $ x: chr "a" "b" "c" "d" ...
#> $ y: chr "a" "b" "c" "d" ...
```
<sup>Created on 2023-11-09 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
There is probably an opportunity to consolidate some of the code that
currently lives in `BindStream` into `PostgresType` and/or
`PostgresTypeResolver`. I'm happy to poke away at that at some point,
but in the meantime it didn't seem too onerous to add dictionary
support here.
---
c/driver/postgresql/postgres_copy_reader.h | 116 ++++++++++++++++++-----------
c/driver/postgresql/postgres_type.h | 3 +
c/driver/postgresql/postgres_type_test.cc | 9 +++
c/driver/postgresql/postgresql_test.cc | 1 -
c/driver/postgresql/statement.cc | 55 +++++++++++++-
5 files changed, 138 insertions(+), 46 deletions(-)
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 09436351..66a31418 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -1231,13 +1231,13 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
switch (TU) {
case NANOARROW_TIME_UNIT_SECOND:
if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
- raw_value >= kMinSafeSecondsToMicros)) {
+ raw_value >= kMinSafeSecondsToMicros)) {
value = raw_value * 1000000;
}
break;
case NANOARROW_TIME_UNIT_MILLI:
if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
- raw_value >= kMinSafeMillisToMicros)) {
+ raw_value >= kMinSafeMillisToMicros)) {
value = raw_value * 1000;
}
break;
@@ -1251,11 +1251,8 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
if (!overflow_safe) {
ArrowErrorSet(
- error,
- "Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
- index,
- raw_value,
- TU);
+ error, "Row %" PRId64 " duration value %" PRId64 " with unit %d would overflow",
+ index, raw_value, TU);
return ADBC_STATUS_INVALID_ARGUMENT;
}
@@ -1273,8 +1270,7 @@ class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
public:
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
- struct ArrowBufferView buffer_view =
- ArrowArrayViewGetBytesUnsafe(array_view_, index);
+ struct ArrowBufferView buffer_view = ArrowArrayViewGetBytesUnsafe(array_view_, index);
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, buffer_view.size_bytes, error));
NANOARROW_RETURN_NOT_OK(
ArrowBufferAppend(buffer, buffer_view.data.as_uint8, buffer_view.size_bytes));
@@ -1283,6 +1279,26 @@ class PostgresCopyBinaryFieldWriter : public PostgresCopyFieldWriter {
}
};
+// Writes one element of a dictionary-encoded column whose dictionary values are
+// BINARY, STRING, LARGE_BINARY or LARGE_STRING (the only value types for which
+// MakeCopyFieldWriter() constructs this writer). The integer index at `index` is
+// looked up in the dictionary and the referenced bytes are written as an int32
+// length prefix followed by the raw bytes; a null dictionary value is written as a
+// field length of -1, which the PostgreSQL binary COPY format uses for NULL.
+// NOTE(review): nulls in the index array itself are presumably handled by the
+// parent (tuple) writer before Write() is called -- confirm against that writer.
+class PostgresCopyBinaryDictFieldWriter : public PostgresCopyFieldWriter {
+ public:
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
+    const int64_t dict_index = ArrowArrayViewGetIntUnsafe(array_view_, index);
+
+    // The indices come from caller-supplied data and nothing in this writer has
+    // checked them against the dictionary length; without this guard a malformed
+    // array would make the lookups below read out of bounds.
+    if (dict_index < 0 || dict_index >= array_view_->dictionary->length) {
+      ArrowErrorSet(error,
+                    "Row %" PRId64 " dictionary index %" PRId64
+                    " is out of range for dictionary of length %" PRId64,
+                    index, dict_index, array_view_->dictionary->length);
+      return EINVAL;
+    }
+
+    if (ArrowArrayViewIsNull(array_view_->dictionary, dict_index)) {
+      // A field length of -1 marks the field as NULL.
+      constexpr int32_t field_size_bytes = -1;
+      NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
+    } else {
+      const struct ArrowBufferView buffer_view =
+          ArrowArrayViewGetBytesUnsafe(array_view_->dictionary, dict_index);
+      NANOARROW_RETURN_NOT_OK(
+          WriteChecked<int32_t>(buffer, buffer_view.size_bytes, error));
+      NANOARROW_RETURN_NOT_OK(
+          ArrowBufferAppend(buffer, buffer_view.data.as_uint8, buffer_view.size_bytes));
+    }
+
+    return ADBC_STATUS_OK;
+  }
+};
+
template <enum ArrowTimeUnit TU>
class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
public:
@@ -1297,13 +1313,13 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
switch (TU) {
case NANOARROW_TIME_UNIT_SECOND:
if ((overflow_safe = raw_value <= kMaxSafeSecondsToMicros &&
- raw_value >= kMinSafeSecondsToMicros)) {
+ raw_value >= kMinSafeSecondsToMicros)) {
value = raw_value * 1000000;
}
break;
case NANOARROW_TIME_UNIT_MILLI:
if ((overflow_safe = raw_value <= kMaxSafeMillisToMicros &&
- raw_value >= kMinSafeMillisToMicros)) {
+ raw_value >= kMinSafeMillisToMicros)) {
value = raw_value * 1000;
}
break;
@@ -1316,12 +1332,10 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
}
if (!overflow_safe) {
- ArrowErrorSet(
- error,
- "Row %" PRId64 " timestamp value %" PRId64 " with unit %d would overflow",
- index,
- raw_value,
- TU);
+ ArrowErrorSet(error,
+ "Row %" PRId64 " timestamp value %" PRId64
+ " with unit %d would overflow",
+ index, raw_value, TU);
return ADBC_STATUS_INVALID_ARGUMENT;
}
@@ -1334,9 +1348,12 @@ class PostgresCopyTimestampFieldWriter : public PostgresCopyFieldWriter {
}
};
-static inline ArrowErrorCode MakeCopyFieldWriter(
- const struct ArrowSchemaView& schema_view, PostgresCopyFieldWriter** out,
- ArrowError* error) {
+static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
+ PostgresCopyFieldWriter** out,
+ ArrowError* error) {
+ struct ArrowSchemaView schema_view;
+ NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));
+
switch (schema_view.type) {
case NANOARROW_TYPE_BOOL:
*out = new PostgresCopyBooleanFieldWriter();
@@ -1368,21 +1385,21 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
*out = new PostgresCopyBinaryFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_TIMESTAMP: {
- switch (schema_view.time_unit) {
- case NANOARROW_TIME_UNIT_NANO:
- *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
- break;
- case NANOARROW_TIME_UNIT_MILLI:
- *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
- break;
- case NANOARROW_TIME_UNIT_MICRO:
- *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
- break;
- case NANOARROW_TIME_UNIT_SECOND:
- *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
- break;
- }
- return NANOARROW_OK;
+ switch (schema_view.time_unit) {
+ case NANOARROW_TIME_UNIT_NANO:
+ *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_NANO>();
+ break;
+ case NANOARROW_TIME_UNIT_MILLI:
+ *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MILLI>();
+ break;
+ case NANOARROW_TIME_UNIT_MICRO:
+ *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_MICRO>();
+ break;
+ case NANOARROW_TIME_UNIT_SECOND:
+ *out = new PostgresCopyTimestampFieldWriter<NANOARROW_TIME_UNIT_SECOND>();
+ break;
+ }
+ return NANOARROW_OK;
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
*out = new PostgresCopyIntervalFieldWriter();
@@ -1405,10 +1422,27 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
}
return NANOARROW_OK;
}
+ case NANOARROW_TYPE_DICTIONARY: {
+ struct ArrowSchemaView value_view;
+ NANOARROW_RETURN_NOT_OK(
+ ArrowSchemaViewInit(&value_view, schema->dictionary, error));
+ switch (value_view.type) {
+ case NANOARROW_TYPE_BINARY:
+ case NANOARROW_TYPE_STRING:
+ case NANOARROW_TYPE_LARGE_BINARY:
+ case NANOARROW_TYPE_LARGE_STRING:
+ *out = new PostgresCopyBinaryDictFieldWriter();
+ return NANOARROW_OK;
+ default:
+ break;
+ }
+ }
default:
- ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
- return EINVAL;
+ break;
}
+
+ ArrowErrorSet(error, "COPY Writer not implemented for type %d", schema_view.type);
+ return EINVAL;
}
class PostgresCopyStreamWriter {
@@ -1450,13 +1484,9 @@ class PostgresCopyStreamWriter {
}
for (int64_t i = 0; i < schema_->n_children; i++) {
- struct ArrowSchemaView schema_view;
- if (ArrowSchemaViewInit(&schema_view, schema_->children[i], error) !=
- NANOARROW_OK) {
- return ADBC_STATUS_INTERNAL;
- }
PostgresCopyFieldWriter* child_writer = nullptr;
- NANOARROW_RETURN_NOT_OK(MakeCopyFieldWriter(schema_view, &child_writer, error));
+ NANOARROW_RETURN_NOT_OK(
+ MakeCopyFieldWriter(schema_->children[i], &child_writer, error));
root_writer_.AppendChild(std::unique_ptr<PostgresCopyFieldWriter>(child_writer));
}
diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h
index 1dfcbe2b..9d5f2e64 100644
--- a/c/driver/postgresql/postgres_type.h
+++ b/c/driver/postgresql/postgres_type.h
@@ -521,6 +521,9 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
PostgresType::FromSchema(resolver, schema->children[0], &child, error));
return resolver.FindArray(child.oid(), out, error);
}
+ case NANOARROW_TYPE_DICTIONARY:
+ // Dictionary arrays always resolve to the dictionary type when binding or ingesting
+ return PostgresType::FromSchema(resolver, schema->dictionary, out, error);
default:
ArrowErrorSet(error, "Can't map Arrow type '%s' to Postgres type",
diff --git a/c/driver/postgresql/postgres_type_test.cc b/c/driver/postgresql/postgres_type_test.cc
index 02e8da93..faf9eb07 100644
--- a/c/driver/postgresql/postgres_type_test.cc
+++ b/c/driver/postgresql/postgres_type_test.cc
@@ -279,6 +279,15 @@ TEST(PostgresTypeTest, PostgresTypeFromSchema) {
EXPECT_EQ(type.child(0).type_id(), PostgresTypeId::kBool);
schema.reset();
+ ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INT64), NANOARROW_OK);
+ ASSERT_EQ(ArrowSchemaAllocateDictionary(schema.get()), NANOARROW_OK);
+ ASSERT_EQ(ArrowSchemaInitFromType(schema->dictionary, NANOARROW_TYPE_STRING),
+ NANOARROW_OK);
+ EXPECT_EQ(PostgresType::FromSchema(resolver, schema.get(), &type, nullptr),
+ NANOARROW_OK);
+ EXPECT_EQ(type.type_id(), PostgresTypeId::kText);
+ schema.reset();
+
ArrowError error;
ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO),
NANOARROW_OK);
diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc
index 343c729d..e1f95a49 100644
--- a/c/driver/postgresql/postgresql_test.cc
+++ b/c/driver/postgresql/postgresql_test.cc
@@ -820,7 +820,6 @@ class PostgresStatementTest : public ::testing::Test,
void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
- void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }
void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index eac7eded..6f1bb9bd 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -210,6 +210,33 @@ struct BindStream {
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
+ case ArrowType::NANOARROW_TYPE_DICTIONARY: {
+ struct ArrowSchemaView value_view;
+ CHECK_NA(INTERNAL,
+ ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary,
+ nullptr),
+ error);
+ switch (value_view.type) {
+ case NANOARROW_TYPE_BINARY:
+ case NANOARROW_TYPE_LARGE_BINARY:
+ type_id = PostgresTypeId::kBytea;
+ param_lengths[i] = 0;
+ break;
+ case NANOARROW_TYPE_STRING:
+ case NANOARROW_TYPE_LARGE_STRING:
+ type_id = PostgresTypeId::kText;
+ param_lengths[i] = 0;
+ break;
+ default:
+ SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
+ static_cast<uint64_t>(i + 1), " ('",
+ bind_schema->children[i]->name,
+ "') has unsupported dictionary value parameter type ",
+ ArrowTypeString(value_view.type));
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+ break;
+ }
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
@@ -567,8 +594,8 @@ struct BindStream {
}
ArrowBuffer buffer = writer.WriteBuffer();
- if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data),
- buffer.size_bytes) <= 0) {
+ if (PQputCopyData(conn, reinterpret_cast<char*>(buffer.data), buffer.size_bytes) <=
+ 0) {
SetError(error, "Error writing tuple field data: %s", PQerrorMessage(conn));
return ADBC_STATUS_IO;
}
@@ -1029,6 +1056,30 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
create += " INTERVAL";
break;
+ case ArrowType::NANOARROW_TYPE_DICTIONARY: {
+ struct ArrowSchemaView value_view;
+ CHECK_NA(INTERNAL,
+ ArrowSchemaViewInit(&value_view, source_schema.children[i]->dictionary,
+ nullptr),
+ error);
+ switch (value_view.type) {
+ case NANOARROW_TYPE_BINARY:
+ case NANOARROW_TYPE_LARGE_BINARY:
+ create += " BYTEA";
+ break;
+ case NANOARROW_TYPE_STRING:
+ case NANOARROW_TYPE_LARGE_STRING:
+ create += " TEXT";
+ break;
+ default:
+ SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
+ static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
+ "') has unsupported dictionary value type for ingestion ",
+ ArrowTypeString(value_view.type));
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+ }
+ break;
+ }
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,