You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by em...@apache.org on 2020/09/25 08:01:40 UTC

[arrow] branch decimal256 updated: Archery C++ round trip working. Java disabled. Fix c-bridge (#8268)

This is an automated email from the ASF dual-hosted git repository.

emkornfield pushed a commit to branch decimal256
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/decimal256 by this push:
     new 50c956b  Archery C++ round trip working.  Java disabled.  Fix c-bridge (#8268)
50c956b is described below

commit 50c956bce0f38566e79ff2b7318f3da3b837d917
Author: emkornfield <em...@gmail.com>
AuthorDate: Fri Sep 25 01:00:58 2020 -0700

    Archery C++ round trip working.  Java disabled.  Fix c-bridge (#8268)
    
    The Archery lint issue still needs to be fixed; I'll do that in a follow-up.
---
 cpp/src/arrow/c/bridge.cc                  | 22 ++++++++++++++++++----
 cpp/src/arrow/c/bridge_test.cc             |  2 ++
 cpp/src/arrow/ipc/metadata_internal.cc     | 28 +++++++++++++++++++++-------
 dev/archery/archery/integration/datagen.py | 19 ++++++++++++++-----
 4 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 1585b50..b5af364 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -304,8 +304,15 @@ struct SchemaExporter {
   }
 
   Status Visit(const DecimalType& type) {
-    return SetFormat("d:" + std::to_string(type.precision()) + "," +
-                     std::to_string(type.scale()));
+    if (type.bit_width() == 128) {
+      // 128 is the default bit-width
+      return SetFormat("d:" + std::to_string(type.precision()) + "," +
+                       std::to_string(type.scale()));
+    } else {
+      return SetFormat("d:" + std::to_string(type.precision()) + "," +
+                       std::to_string(type.scale()) + "," +
+                       std::to_string(type.bit_width()));
+    }
   }
 
   Status Visit(const BinaryType& type) { return SetFormat("z"); }
@@ -972,13 +979,20 @@ struct SchemaImporter {
   Status ProcessDecimal() {
     RETURN_NOT_OK(f_parser_.CheckNext(':'));
     ARROW_ASSIGN_OR_RAISE(auto prec_scale, f_parser_.ParseInts(f_parser_.Rest()));
-    if (prec_scale.size() != 2) {
+    // 3 elements indicates bit width was communicated as well.
+    if (prec_scale.size() != 2 && prec_scale.size() != 3) {
       return f_parser_.Invalid();
     }
     if (prec_scale[0] <= 0 || prec_scale[1] <= 0) {
       return f_parser_.Invalid();
     }
-    type_ = decimal(prec_scale[0], prec_scale[1]);
+    if (prec_scale.size() == 2 || prec_scale[2] == 128) {
+      type_ = decimal(prec_scale[0], prec_scale[1]);
+    } else if (prec_scale[2] == 256) {
+      type_ = decimal256(prec_scale[0], prec_scale[1]);
+    } else {
+      return f_parser_.Invalid();
+    }
     return Status::OK();
   }
 
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index 6695d6e..ecb5655 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -277,6 +277,7 @@ TEST_F(TestSchemaExport, Primitive) {
   TestPrimitive(large_utf8(), "U");
 
   TestPrimitive(decimal(16, 4), "d:16,4");
+  TestPrimitive(decimal256(16, 4), "d:16,4,256");
 }
 
 TEST_F(TestSchemaExport, Temporal) {
@@ -736,6 +737,7 @@ TEST_F(TestArrayExport, Primitive) {
   TestPrimitive(large_utf8(), R"(["foo", "bar", null])");
 
   TestPrimitive(decimal(16, 4), R"(["1234.5670", null])");
+  TestPrimitive(decimal256(16, 4), R"(["1234.5670", null])");
 }
 
 TEST_F(TestArrayExport, PrimitiveSliced) {
diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc
index fe43149..cb26a15 100644
--- a/cpp/src/arrow/ipc/metadata_internal.cc
+++ b/cpp/src/arrow/ipc/metadata_internal.cc
@@ -236,7 +236,8 @@ static inline TimeUnit::type FromFlatbufferUnit(flatbuf::TimeUnit unit) {
   return TimeUnit::SECOND;
 }
 
-constexpr int32_t kDecimalBitWidth = 128;
+constexpr int32_t kDecimalBitWidth128 = 128;
+constexpr int32_t kDecimalBitWidth256 = 256;
 
 Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
                                   const std::vector<std::shared_ptr<Field>>& children,
@@ -273,10 +274,13 @@ Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
       return Status::OK();
     case flatbuf::Type::Decimal: {
       auto dec_type = static_cast<const flatbuf::Decimal*>(type_data);
-      if (dec_type->bitWidth() != kDecimalBitWidth) {
-        return Status::Invalid("Library only supports 128-bit decimal values");
+      if (dec_type->bitWidth() == kDecimalBitWidth128) {
+        return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
+      } else if (dec_type->bitWidth() == kDecimalBitWidth256) {
+        return Decimal256Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
+      } else {
+        return Status::Invalid("Library only supports 128-bit or 256-bit decimal values");
       }
-      return Decimal128Type::Make(dec_type->precision(), dec_type->scale()).Value(out);
     }
     case flatbuf::Type::Date: {
       auto date_type = static_cast<const flatbuf::Date*>(type_data);
@@ -594,11 +598,21 @@ class FieldToFlatbufferVisitor {
     return Status::OK();
   }
 
-  Status Visit(const DecimalType& type) {
+  Status Visit(const Decimal128Type& type) {
     const auto& dec_type = checked_cast<const Decimal128Type&>(type);
     fb_type_ = flatbuf::Type::Decimal;
-    type_offset_ =
-        flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale()).Union();
+    type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(),
+                                          /*bitWidth=*/128)
+                       .Union();
+    return Status::OK();
+  }
+
+  Status Visit(const Decimal256Type& type) {
+    const auto& dec_type = checked_cast<const Decimal256Type&>(type);
+    fb_type_ = flatbuf::Type::Decimal;
+    type_offset_ = flatbuf::CreateDecimal(fbb_, dec_type.precision(), dec_type.scale(),
+                                          /*bitWith=*/256)
+                       .Union();
     return Status::OK();
   }
 
diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py
index 69f463b..b740198 100644
--- a/dev/archery/archery/integration/datagen.py
+++ b/dev/archery/archery/integration/datagen.py
@@ -400,14 +400,15 @@ class FloatingPointField(PrimitiveField):
 
 DECIMAL_PRECISION_TO_VALUE = {
     key: (1 << (8 * i - 1)) - 1 for i, key in enumerate(
-        [1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36],
+        [1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36,
+         38, 40, 42, 44, 50, 60, 70],
         start=1,
     )
 }
 
 
 def decimal_range_from_precision(precision):
-    assert 1 <= precision <= 38
+    assert 1 <= precision <= 76
     try:
         max_value = DECIMAL_PRECISION_TO_VALUE[precision]
     except KeyError:
@@ -417,7 +418,7 @@ def decimal_range_from_precision(precision):
 
 
 class DecimalField(PrimitiveField):
-    def __init__(self, name, precision, scale, bit_width=128, *,
+    def __init__(self, name, precision, scale, bit_width, *,
                  nullable=True, metadata=None):
         super().__init__(name, nullable=True,
                          metadata=metadata)
@@ -434,6 +435,7 @@ class DecimalField(PrimitiveField):
             ('name', 'decimal'),
             ('precision', self.precision),
             ('scale', self.scale),
+            ('bitWidth', self.bit_width),
         ])
 
     def generate_column(self, size, name=None):
@@ -448,7 +450,7 @@ class DecimalField(PrimitiveField):
 
 class DecimalColumn(PrimitiveColumn):
 
-    def __init__(self, name, count, is_valid, values, bit_width=128):
+    def __init__(self, name, count, is_valid, values, bit_width):
         super().__init__(name, count, is_valid, values)
         self.bit_width = bit_width
 
@@ -1274,8 +1276,13 @@ def generate_null_trivial_case(batch_sizes):
 
 def generate_decimal_case():
     fields = [
-        DecimalField(name='f{}'.format(i), precision=precision, scale=2)
+        DecimalField(name='f{}'.format(i), precision=precision, scale=2,
+            bit_width=128)
         for i, precision in enumerate(range(3, 39))
+    ] + [
+        DecimalField(name='f{}'.format(i), precision=precision, scale=5,
+            bit_width=256)
+        for i, precision in enumerate(range(37, 70))
     ]
 
     possible_batch_sizes = 7, 10
@@ -1516,6 +1523,8 @@ def get_generated_json_files(tempdir=None):
         generate_decimal_case()
         .skip_category('Go')  # TODO(ARROW-7948): Decimal + Go
         .skip_category('Rust'),
+        .skip_category('Java'),
+
 
         generate_datetime_case()
         .skip_category('Rust'),