You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/07/28 15:40:36 UTC
[arrow-adbc] branch main updated: feat(c/driver/postgresql): Interval support (#908)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 0a5afa3e feat(c/driver/postgresql): Interval support (#908)
0a5afa3e is described below
commit 0a5afa3e8020e72ed5e5823bd7ebf03a3c7ba24e
Author: William Ayd <wi...@icloud.com>
AuthorDate: Fri Jul 28 08:40:30 2023 -0700
feat(c/driver/postgresql): Interval support (#908)
---
c/driver/flightsql/sqlite_flightsql_test.cc | 3 +
c/driver/postgresql/postgres_copy_reader.h | 49 +++++++++++++++-
c/driver/postgresql/postgres_type.h | 5 ++
c/driver/postgresql/statement.cc | 24 ++++++++
c/driver/sqlite/sqlite_test.cc | 3 +
c/driver_manager/adbc_driver_manager_test.cc | 3 +
c/validation/adbc_validation.cc | 83 ++++++++++++++++++++++++++++
c/validation/adbc_validation.h | 2 +
c/validation/adbc_validation_util.h | 11 ++++
c/vendor/nanoarrow/nanoarrow.h | 73 ++++++++++++++++++++++++
10 files changed, 255 insertions(+), 1 deletion(-)
diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index fdb2a795..b61b47bc 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -229,6 +229,9 @@ class SqliteFlightSqlStatementTest : public ::testing::Test,
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; }
+  void TestSqlIngestInterval() {
+    // Takes the place of the generic StatementTest::TestSqlIngestInterval() for this
+    // fixture: INTERVAL ingestion is not implemented here, so report a skip.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
protected:
SqliteFlightSqlQuirks quirks_;
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 7d3844f8..dc827833 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -19,6 +19,7 @@
#include <algorithm>
#include <cerrno>
+#include <cinttypes>
#include <cstdint>
#include <memory>
#include <string>
@@ -212,7 +213,43 @@ class PostgresCopyNetworkEndianFieldReader : public PostgresCopyFieldReader {
}
};
-// Converts COPY resulting from the Postgres NUMERIC type into a string.
+// Reader for Postgres INTERVAL values.
+//
+// An INTERVAL field in COPY binary output is 16 bytes: an int64 time component in
+// microseconds, then an int32 day count, then an int32 month count. Each value is
+// appended to an Arrow month_day_nano interval as (months, days, nanoseconds).
+class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
+ public:
+  ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
+                      ArrowError* error) override {
+    // This reader treats any non-positive field size as a NULL value.
+    if (field_size_bytes <= 0) {
+      return ArrowArrayAppendNull(array, 1);
+    }
+
+    if (field_size_bytes != 16) {
+      ArrowErrorSet(error, "Expected field with %d bytes but found field with %d bytes",
+                    16,
+                    static_cast<int>(field_size_bytes));  // NOLINT(runtime/int)
+      return EINVAL;
+    }
+
+    // Consume the whole 16-byte field before validating it so `data` always ends up
+    // past this field (ReadUnsafe presumably also handles network byte order).
+    const int64_t time_usec = ReadUnsafe<int64_t>(data);
+    const int32_t days = ReadUnsafe<int32_t>(data);
+    const int32_t months = ReadUnsafe<int32_t>(data);
+
+    // postgres stores time as usec, arrow stores as ns: reject values whose
+    // conversion would overflow int64.
+    if ((time_usec > INT64_MAX / 1000) || (time_usec < INT64_MIN / 1000)) {
+      // The PRId64 conversion needs its matching argument; omitting it is UB.
+      ArrowErrorSet(error,
+                    "[libpq] Interval with time value %" PRId64
+                    " usec would overflow when converting to nanoseconds",
+                    time_usec);
+      return EINVAL;
+    }
+
+    const int64_t time_ns = time_usec * 1000;
+
+    // Arrow month_day_nano layout: int32 months, int32 days, int64 nanoseconds.
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &months, sizeof(int32_t)));
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &days, sizeof(int32_t)));
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &time_ns, sizeof(int64_t)));
+    return AppendValid(array);
+  }
+};
+
+// Converts COPY resulting from the Postgres NUMERIC type into a string.
// Rewritten based on the Postgres implementation of NUMERIC cast to string in
// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source,
// DEC_DIGITS is always 4 and DBASE is always 10000).
@@ -836,6 +873,16 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
+ case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+ switch (pg_type.type_id()) {
+ case PostgresTypeId::kInterval: {
+ *out = new PostgresCopyIntervalFieldReader();
+ return NANOARROW_OK;
+ }
+ default:
+ return ErrorCantConvert(error, pg_type, schema_view);
+ }
+
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h
index 7b4197bf..1dfcbe2b 100644
--- a/c/driver/postgresql/postgres_type.h
+++ b/c/driver/postgresql/postgres_type.h
@@ -258,6 +258,11 @@ class PostgresType {
NANOARROW_TIME_UNIT_MICRO, /*timezone=*/"UTC"));
break;
+ case PostgresTypeId::kInterval:
+ NANOARROW_RETURN_NOT_OK(
+ ArrowSchemaSetType(schema, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO));
+ break;
+
// ---- Nested --------------------
case PostgresTypeId::kRecord:
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children()));
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 4cae15b6..dd3fd82c 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -223,6 +223,10 @@ struct BindStream {
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
break;
+ case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+ type_id = PostgresTypeId::kInterval;
+ param_lengths[i] = 16;
+ break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
@@ -426,6 +430,23 @@ struct BindStream {
std::memcpy(param_values[col], &value, sizeof(int64_t));
break;
}
+ case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
+ const auto buf =
+ array_view->children[col]->buffer_views[1].data.as_uint8 + row * 16;
+ const int32_t raw_months = *(int32_t*)buf;
+ const int32_t raw_days = *(int32_t*)(buf + 4);
+ const int64_t raw_ns = *(int64_t*)(buf + 8);
+
+ const uint32_t months = ToNetworkInt32(raw_months);
+ const uint32_t days = ToNetworkInt32(raw_days);
+ const uint64_t ms = ToNetworkInt64(raw_ns / 1000);
+
+ std::memcpy(param_values[col], &ms, sizeof(uint64_t));
+ std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t));
+ std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t),
+ &months, sizeof(uint32_t));
+ break;
+ }
default:
SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
bind_schema->children[col]->name,
@@ -787,6 +808,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " TIMESTAMP";
}
break;
+ case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+ create += " INTERVAL";
+ break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 3ab21fe3..00332066 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -195,6 +195,9 @@ class SqliteStatementTest : public ::testing::Test,
}
void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
+  void TestSqlIngestInterval() {
+    // Takes the place of the generic StatementTest::TestSqlIngestInterval() for this
+    // fixture: INTERVAL ingestion is not implemented here, so report a skip.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
protected:
void ValidateIngestedTemporalData(struct ArrowArrayView* values,
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 4475bd19..149da7c6 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -232,6 +232,9 @@ class SqliteStatementTest : public ::testing::Test,
void TestSqlIngestTimestampTz() {
GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)";
}
+  void TestSqlIngestInterval() {
+    // Takes the place of the generic StatementTest::TestSqlIngestInterval() for this
+    // fixture: INTERVAL ingestion is not implemented here, so report a skip.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
protected:
SqliteQuirks quirks_;
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index e104b6cd..54f3981c 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1193,6 +1193,89 @@ void StatementTest::TestSqlIngestTimestampTz() {
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
}
+// Bulk-ingests a single INTERVAL (month_day_nano) column holding a NULL, a negative,
+// a zero, and a positive interval. It then selects the column back ordered by value
+// and checks the schema and, when the type round-trips unchanged, the values.
+void StatementTest::TestSqlIngestInterval() {
+  if (!quirks()->supports_bulk_ingest()) {
+    GTEST_SKIP();
+  }
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
+              IsOkStatus(&error));
+
+  Handle<struct ArrowSchema> schema;
+  Handle<struct ArrowArray> array;
+  struct ArrowError na_error;
+  const enum ArrowType type = NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
+  // values are months, days, ns (the month_day_nano storage order)
+  struct ArrowInterval neg_interval;
+  struct ArrowInterval zero_interval;
+  struct ArrowInterval pos_interval;
+
+  ArrowIntervalInit(&neg_interval, type);
+  ArrowIntervalInit(&zero_interval, type);
+  ArrowIntervalInit(&pos_interval, type);
+
+  // ns is kept a multiple of 1000, presumably so that backends with microsecond
+  // resolution can round-trip it exactly.
+  neg_interval.months = -5;
+  neg_interval.days = -5;
+  neg_interval.ns = -42000;
+
+  pos_interval.months = 5;
+  pos_interval.days = 5;
+  pos_interval.ns = 42000;
+
+  // NULL first, then ascending, matching the ORDER BY ... NULLS FIRST used below.
+  const std::vector<std::optional<ArrowInterval*>> values = {
+      std::nullopt, &neg_interval, &zero_interval, &pos_interval};
+
+  ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno());
+
+  ASSERT_THAT(MakeBatch<ArrowInterval*>(&schema.value, &array.value, &na_error, values),
+              IsOkErrno());
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                     "bulk_ingest", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
+              IsOkStatus(&error));
+
+  int64_t rows_affected = 0;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(rows_affected,
+              ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));
+
+  ASSERT_THAT(AdbcStatementSetSqlQuery(
+                  &statement,
+                  "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error),
+              IsOkStatus(&error));
+  {
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(reader.rows_affected,
+                ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
+    ASSERT_NO_FATAL_FAILURE(
+        CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));
+
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_NE(nullptr, reader.array->release);
+    ASSERT_EQ(values.size(), reader.array->length);
+    ASSERT_EQ(1, reader.array->n_children);
+
+    // Values are only comparable when the driver reports back the same Arrow type.
+    if (round_trip_type == type) {
+      ASSERT_NO_FATAL_FAILURE(
+          CompareArray<ArrowInterval*>(reader.array_view->children[0], values));
+    }
+
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(nullptr, reader.array->release);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 61208d13..dc5d69c2 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -232,6 +232,7 @@ class StatementTest {
// Temporal
void TestSqlIngestTimestamp();
void TestSqlIngestTimestampTz();
+ /// Bulk-ingest an INTERVAL (month_day_nano) column and read it back.
+ void TestSqlIngestInterval();
// ---- End Type-specific tests ----------------
@@ -302,6 +303,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
+ TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \
diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h
index fec5e758..5845676b 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -264,6 +264,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) {
return errno_res;
}
+ } else if constexpr (std::is_same<T, ArrowInterval*>::value) {
+ if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) {
+ return errno_res;
+ }
} else {
static_assert(!sizeof(T), "Not yet implemented");
return ENOTSUP;
@@ -375,6 +379,13 @@ void CompareArray(struct ArrowArrayView* array,
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
std::string str(view.data, view.size_bytes);
ASSERT_EQ(*v, str);
+ } else if constexpr (std::is_same<T, ArrowInterval*>::value) {
+ ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+ const auto buf = array->buffer_views[1].data.as_uint8;
+ const auto record = buf + i * 16;
+ ASSERT_EQ(memcmp(record, &(*v)->months, 4), 0);
+ ASSERT_EQ(memcmp(record + 4, &(*v)->days, 4), 0);
+ ASSERT_EQ(memcmp(record + 8, &(*v)->ns, 8), 0);
} else {
static_assert(!sizeof(T), "Not yet implemented");
}
diff --git a/c/vendor/nanoarrow/nanoarrow.h b/c/vendor/nanoarrow/nanoarrow.h
index 85353df4..dccb842c 100644
--- a/c/vendor/nanoarrow/nanoarrow.h
+++ b/c/vendor/nanoarrow/nanoarrow.h
@@ -688,6 +688,29 @@ struct ArrowArrayPrivateData {
int8_t union_type_id_is_child_index;
};
+/// \brief A representation of an interval.
+/// \ingroup nanoarrow-utils
+///
+/// Only some fields are meaningful for a given type. ArrowArrayAppendInterval()
+/// reads months for NANOARROW_TYPE_INTERVAL_MONTHS; days and ms for
+/// NANOARROW_TYPE_INTERVAL_DAY_TIME; and months, days and ns for
+/// NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO. ArrowIntervalInit() zeroes every field.
+struct ArrowInterval {
+ /// \brief The type of interval being used (compared against the array's storage
+ /// type when appending)
+ enum ArrowType type;
+ /// \brief The number of months represented by the interval
+ int32_t months;
+ /// \brief The number of days represented by the interval
+ int32_t days;
+ /// \brief The number of milliseconds represented by the interval
+ int32_t ms;
+ /// \brief The number of nanoseconds represented by the interval
+ int64_t ns;
+};
+
+/// \brief Zero initialize an Interval with a given interval type
+/// \ingroup nanoarrow-utils
+///
+/// Every field is set to zero and interval->type is set to type. The type is what
+/// ArrowArrayAppendInterval() compares against the array's storage type.
+static inline void ArrowIntervalInit(struct ArrowInterval* interval,
+                                     enum ArrowType type) {
+  memset(interval, 0, sizeof(struct ArrowInterval));
+  interval->type = type;
+}
+
/// \brief A representation of a fixed-precision decimal number
/// \ingroup nanoarrow-utils
///
@@ -1649,6 +1672,13 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array,
static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
struct ArrowStringView value);
+/// \brief Append an interval to an array
+///
+/// Returns NANOARROW_OK on success. Returns EINVAL if array is not an interval
+/// array or if value->type does not match the array's storage type. Otherwise
+/// returns the error code if appending to the underlying data buffer fails.
+static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
+                                                      struct ArrowInterval* value);
+
/// \brief Append a decimal value to an array
///
/// Returns NANOARROW_OK if array is a decimal array with the appropriate
@@ -2891,6 +2921,49 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
}
}
+/// Appends value to a MONTHS, DAY_TIME, or MONTH_DAY_NANO interval array, writing
+/// only the fields that the array's storage type stores.
+static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
+                                                      struct ArrowInterval* value) {
+  struct ArrowArrayPrivateData* private_data =
+      (struct ArrowArrayPrivateData*)array->private_data;
+
+  struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1);
+
+  switch (private_data->storage_type) {
+    case NANOARROW_TYPE_INTERVAL_MONTHS: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_MONTHS) {
+        return EINVAL;
+      }
+
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
+      break;
+    }
+    case NANOARROW_TYPE_INTERVAL_DAY_TIME: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_DAY_TIME) {
+        return EINVAL;
+      }
+
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->ms));
+      break;
+    }
+    case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO) {
+        return EINVAL;
+      }
+
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt64(data_buffer, value->ns));
+      break;
+    }
+    default:
+      return EINVAL;
+  }
+
+  // If the array already has a validity bitmap, mark this slot as valid. Without
+  // this the bitmap falls behind array->length and the validity of this slot is
+  // left undefined.
+  if (private_data->bitmap.buffer.data != NULL) {
+    NANOARROW_RETURN_NOT_OK(ArrowBitmapAppend(ArrowArrayValidityBitmap(array), 1, 1));
+  }
+
+  array->length++;
+  return NANOARROW_OK;
+}
+
static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array,
struct ArrowDecimal* value) {
struct ArrowArrayPrivateData* private_data =