You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/07/28 15:40:36 UTC

[arrow-adbc] branch main updated: feat(c/driver/postgresql): Interval support (#908)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a5afa3e feat(c/driver/postgresql): Interval support (#908)
0a5afa3e is described below

commit 0a5afa3e8020e72ed5e5823bd7ebf03a3c7ba24e
Author: William Ayd <wi...@icloud.com>
AuthorDate: Fri Jul 28 08:40:30 2023 -0700

    feat(c/driver/postgresql): Interval support (#908)
---
 c/driver/flightsql/sqlite_flightsql_test.cc  |  3 +
 c/driver/postgresql/postgres_copy_reader.h   | 49 +++++++++++++++-
 c/driver/postgresql/postgres_type.h          |  5 ++
 c/driver/postgresql/statement.cc             | 24 ++++++++
 c/driver/sqlite/sqlite_test.cc               |  3 +
 c/driver_manager/adbc_driver_manager_test.cc |  3 +
 c/validation/adbc_validation.cc              | 83 ++++++++++++++++++++++++++++
 c/validation/adbc_validation.h               |  2 +
 c/validation/adbc_validation_util.h          | 11 ++++
 c/vendor/nanoarrow/nanoarrow.h               | 73 ++++++++++++++++++++++++
 10 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/c/driver/flightsql/sqlite_flightsql_test.cc b/c/driver/flightsql/sqlite_flightsql_test.cc
index fdb2a795..b61b47bc 100644
--- a/c/driver/flightsql/sqlite_flightsql_test.cc
+++ b/c/driver/flightsql/sqlite_flightsql_test.cc
@@ -229,6 +229,9 @@ class SqliteFlightSqlStatementTest : public ::testing::Test,
   void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }
 
   void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; }
+  void TestSqlIngestInterval() {  // Skips the test: Interval ingest is not implemented.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
 
  protected:
   SqliteFlightSqlQuirks quirks_;
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 7d3844f8..dc827833 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -19,6 +19,7 @@
 
 #include <algorithm>
 #include <cerrno>
+#include <cinttypes>
 #include <cstdint>
 #include <memory>
 #include <string>
@@ -212,7 +213,43 @@ class PostgresCopyNetworkEndianFieldReader : public PostgresCopyFieldReader {
   }
 };
 
-// Converts COPY resulting from the Postgres NUMERIC type into a string.
+// Reads one Postgres binary COPY INTERVAL field into a month/day/nano interval.
+class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
+ public:
+  ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
+                      ArrowError* error) override {
+    if (field_size_bytes <= 0) {
+      return ArrowArrayAppendNull(array, 1);
+    }
+    // A non-null value must be exactly 16 bytes: int64 time, int32 days, int32 months.
+    if (field_size_bytes != 16) {
+      ArrowErrorSet(error, "Expected field with %d bytes but found field with %d bytes",
+                    16, static_cast<int>(field_size_bytes));  // NOLINT(runtime/int)
+      return EINVAL;
+    }
+
+    // postgres stores time as usec, arrow stores as ns
+    const int64_t time_usec = ReadUnsafe<int64_t>(data);
+    // Reject values whose nanosecond equivalent would not fit in an int64_t.
+    if ((time_usec > INT64_MAX / 1000) || (time_usec < INT64_MIN / 1000)) {
+      ArrowErrorSet(error, "[libpq] Interval with time value %" PRId64
+                           " usec would overflow when converting to nanoseconds",
+                    time_usec);
+      return EINVAL;
+    }
+
+    const int64_t time = time_usec * 1000;
+    const int32_t days = ReadUnsafe<int32_t>(data);
+    const int32_t months = ReadUnsafe<int32_t>(data);
+    // The fields are appended in the order months, days, nanoseconds.
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &months, sizeof(int32_t)));
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &days, sizeof(int32_t)));
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &time, sizeof(int64_t)));
+    return AppendValid(array);
+  }
+};
+
+// Converts COPY resulting from the Postgres NUMERIC type into a string.
 // Rewritten based on the Postgres implementation of NUMERIC cast to string in
 // src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source,
 // DEC_DIGITS is always 4 and DBASE is always 10000).
@@ -836,6 +873,16 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,
         default:
           return ErrorCantConvert(error, pg_type, schema_view);
       }
+    case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+      switch (pg_type.type_id()) {
+        case PostgresTypeId::kInterval: {
+          *out = new PostgresCopyIntervalFieldReader();
+          return NANOARROW_OK;
+        }
+        default:
+          return ErrorCantConvert(error, pg_type, schema_view);
+      }
+
     default:
       return ErrorCantConvert(error, pg_type, schema_view);
   }
diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h
index 7b4197bf..1dfcbe2b 100644
--- a/c/driver/postgresql/postgres_type.h
+++ b/c/driver/postgresql/postgres_type.h
@@ -258,6 +258,11 @@ class PostgresType {
                                        NANOARROW_TIME_UNIT_MICRO, /*timezone=*/"UTC"));
         break;
 
+      case PostgresTypeId::kInterval:
+        NANOARROW_RETURN_NOT_OK(
+            ArrowSchemaSetType(schema, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO));
+        break;
+
       // ---- Nested --------------------
       case PostgresTypeId::kRecord:
         NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children()));
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 4cae15b6..dd3fd82c 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -223,6 +223,10 @@ struct BindStream {
           type_id = PostgresTypeId::kTimestamp;
           param_lengths[i] = 8;
           break;
+        case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+          type_id = PostgresTypeId::kInterval;
+          param_lengths[i] = 16;
+          break;
         default:
           SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
                    static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
@@ -426,6 +430,23 @@ struct BindStream {
               std::memcpy(param_values[col], &value, sizeof(int64_t));
               break;
             }
+            case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
+              const auto buf =
+                  array_view->children[col]->buffer_views[1].data.as_uint8 + row * 16;
+              const int32_t raw_months = *(int32_t*)buf;
+              const int32_t raw_days = *(int32_t*)(buf + 4);
+              const int64_t raw_ns = *(int64_t*)(buf + 8);
+
+              const uint32_t months = ToNetworkInt32(raw_months);
+              const uint32_t days = ToNetworkInt32(raw_days);
+              const uint64_t ms = ToNetworkInt64(raw_ns / 1000);
+
+              std::memcpy(param_values[col], &ms, sizeof(uint64_t));
+              std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t));
+              std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t),
+                          &months, sizeof(uint32_t));
+              break;
+            }
             default:
               SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
                        bind_schema->children[col]->name,
@@ -787,6 +808,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
           create += " TIMESTAMP";
         }
         break;
+      case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+        create += " INTERVAL";
+        break;
       default:
         SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
                  static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index 3ab21fe3..00332066 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -195,6 +195,9 @@ class SqliteStatementTest : public ::testing::Test,
   }
 
   void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
+  void TestSqlIngestInterval() {  // Skips the test: Interval ingest is not implemented.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
 
  protected:
   void ValidateIngestedTemporalData(struct ArrowArrayView* values,
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 4475bd19..149da7c6 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -232,6 +232,9 @@ class SqliteStatementTest : public ::testing::Test,
   void TestSqlIngestTimestampTz() {
     GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)";
   }
+  void TestSqlIngestInterval() {  // Skips the test: Interval ingest is not implemented.
+    GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
+  }
 
  protected:
   SqliteQuirks quirks_;
diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc
index e104b6cd..54f3981c 100644
--- a/c/validation/adbc_validation.cc
+++ b/c/validation/adbc_validation.cc
@@ -1193,6 +1193,89 @@ void StatementTest::TestSqlIngestTimestampTz() {
       TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
 }
 
+void StatementTest::TestSqlIngestInterval() {
+  if (!quirks()->supports_bulk_ingest()) {
+    GTEST_SKIP();
+  }
+  // Round-trip check: bulk-ingest a nullable interval column, then query it back.
+  ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
+              IsOkStatus(&error));
+
+  Handle<struct ArrowSchema> schema;
+  Handle<struct ArrowArray> array;
+  struct ArrowError na_error;
+  const enum ArrowType type = NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
+  // values are months, days, ns
+  struct ArrowInterval neg_interval;
+  struct ArrowInterval zero_interval;
+  struct ArrowInterval pos_interval;
+
+  ArrowIntervalInit(&neg_interval, type);
+  ArrowIntervalInit(&zero_interval, type);
+  ArrowIntervalInit(&pos_interval, type);
+
+  neg_interval.months = -5;
+  neg_interval.days = -5;
+  neg_interval.ns = -42000;
+
+  pos_interval.months = 5;
+  pos_interval.days = 5;
+  pos_interval.ns = 42000;
+  // Test data: one null followed by a negative, a zero and a positive interval.
+  const std::vector<std::optional<ArrowInterval*>> values = {
+      std::nullopt, &neg_interval, &zero_interval, &pos_interval};
+
+  ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno());
+
+  ASSERT_THAT(MakeBatch<ArrowInterval*>(&schema.value, &array.value, &na_error, values),
+              IsOkErrno());
+  // Bind the batch and ingest it into the "bulk_ingest" table.
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
+                                     "bulk_ingest", &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
+              IsOkStatus(&error));
+
+  int64_t rows_affected = 0;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(rows_affected,
+              ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));
+  // Read the rows back, ordered so that the NULL row comes first.
+  ASSERT_THAT(AdbcStatementSetSqlQuery(
+                  &statement,
+                  "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error),
+              IsOkStatus(&error));
+  {
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(reader.rows_affected,
+                ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
+    ASSERT_NO_FATAL_FAILURE(
+        CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));
+
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_NE(nullptr, reader.array->release);
+    ASSERT_EQ(values.size(), reader.array->length);
+    ASSERT_EQ(1, reader.array->n_children);
+    // Values are compared only when the driver returns the same Arrow type.
+    if (round_trip_type == type) {
+      ASSERT_NO_FATAL_FAILURE(
+          CompareArray<ArrowInterval*>(reader.array_view->children[0], values));
+    }
+    // Expect no further batches after the first one.
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    ASSERT_EQ(nullptr, reader.array->release);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
 void StatementTest::TestSqlIngestTableEscaping() {
   std::string name = "create_table_escaping";
 
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 61208d13..dc5d69c2 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -232,6 +232,7 @@ class StatementTest {
   // Temporal
   void TestSqlIngestTimestamp();
   void TestSqlIngestTimestampTz();
+  void TestSqlIngestInterval();
 
   // ---- End Type-specific tests ----------------
 
@@ -302,6 +303,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); }                           \
   TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); }                     \
   TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); }                 \
+  TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); }                       \
   TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); }             \
   TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); }                           \
   TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); }                           \
diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h
index fec5e758..5845676b 100644
--- a/c/validation/adbc_validation_util.h
+++ b/c/validation/adbc_validation_util.h
@@ -264,6 +264,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
         if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) {
           return errno_res;
         }
+      } else if constexpr (std::is_same<T, ArrowInterval*>::value) {
+        if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) {
+          return errno_res;
+        }
       } else {
         static_assert(!sizeof(T), "Not yet implemented");
         return ENOTSUP;
@@ -375,6 +379,13 @@ void CompareArray(struct ArrowArrayView* array,
         struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
         std::string str(view.data, view.size_bytes);
         ASSERT_EQ(*v, str);
+      } else if constexpr (std::is_same<T, ArrowInterval*>::value) {
+        ASSERT_NE(array->buffer_views[1].data.data, nullptr);
+        const auto buf = array->buffer_views[1].data.as_uint8;
+        const auto record = buf + i * 16;
+        ASSERT_EQ(memcmp(record, &(*v)->months, 4), 0);
+        ASSERT_EQ(memcmp(record + 4, &(*v)->days, 4), 0);
+        ASSERT_EQ(memcmp(record + 8, &(*v)->ns, 8), 0);
       } else {
         static_assert(!sizeof(T), "Not yet implemented");
       }
diff --git a/c/vendor/nanoarrow/nanoarrow.h b/c/vendor/nanoarrow/nanoarrow.h
index 85353df4..dccb842c 100644
--- a/c/vendor/nanoarrow/nanoarrow.h
+++ b/c/vendor/nanoarrow/nanoarrow.h
@@ -688,6 +688,29 @@ struct ArrowArrayPrivateData {
   int8_t union_type_id_is_child_index;
 };
 
+/// \brief A representation of an interval.
+/// \ingroup nanoarrow-utils
+struct ArrowInterval {
+  /// \brief The interval type (a NANOARROW_TYPE_INTERVAL_* value) of this value
+  enum ArrowType type;
+  /// \brief The number of months (used by the MONTHS and MONTH_DAY_NANO types)
+  int32_t months;
+  /// \brief The number of days (used by the DAY_TIME and MONTH_DAY_NANO types)
+  int32_t days;
+  /// \brief The number of ms (used by the DAY_TIME type only)
+  int32_t ms;
+  /// \brief The number of ns (used by the MONTH_DAY_NANO type only)
+  int64_t ns;
+};
+
+/// \brief Zero all fields of an interval and then set its interval type
+/// \ingroup nanoarrow-utils
+static inline void ArrowIntervalInit(struct ArrowInterval* interval,
+                                     enum ArrowType type) {
+  memset(interval, 0, sizeof(struct ArrowInterval));
+  interval->type = type;
+}
+
 /// \brief A representation of a fixed-precision decimal number
 /// \ingroup nanoarrow-utils
 ///
@@ -1649,6 +1672,13 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array,
 static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
                                                     struct ArrowStringView value);
 
+/// \brief Append an interval to an array
+///
+/// Returns EINVAL if array is not an interval array whose storage type matches
+/// value->type; buffer append errors are propagated. NANOARROW_OK otherwise.
+static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
+                                                      struct ArrowInterval* value);
+
 /// \brief Append a decimal value to an array
 ///
 /// Returns NANOARROW_OK if array is a decimal array with the appropriate
@@ -2891,6 +2921,49 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
   }
 }
 
+static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
+                                                      struct ArrowInterval* value) {
+  struct ArrowArrayPrivateData* private_data =
+      (struct ArrowArrayPrivateData*)array->private_data;
+  // Values go into the data buffer (buffer 1) using the layout of the array's storage
+  // type; value->type must equal that storage type or EINVAL is returned.
+  struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1);
+  switch (private_data->storage_type) {
+    case NANOARROW_TYPE_INTERVAL_MONTHS: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_MONTHS) {
+        return EINVAL;
+      }
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
+      break;
+    }
+    case NANOARROW_TYPE_INTERVAL_DAY_TIME: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_DAY_TIME) {
+        return EINVAL;
+      }
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->ms));
+      break;
+    }
+    case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
+      if (value->type != NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO) {
+        return EINVAL;
+      }
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt64(data_buffer, value->ns));
+      break;
+    }
+    default:
+      return EINVAL;
+  }
+  // If a validity bitmap was already allocated, mark the new element as valid.
+  if (private_data->bitmap.buffer.data != NULL) {
+    NANOARROW_RETURN_NOT_OK(ArrowBitmapAppend(ArrowArrayValidityBitmap(array), 1, 1));
+  }
+  array->length++;
+  return NANOARROW_OK;
+}
+
 static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array,
                                                      struct ArrowDecimal* value) {
   struct ArrowArrayPrivateData* private_data =