You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2019/06/24 18:13:29 UTC

[arrow] branch master updated: ARROW-4885: [C++/Python] Enable Decimal parsing in CSV

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 4be0ba4  ARROW-4885: [C++/Python] Enable Decimal parsing in CSV
4be0ba4 is described below

commit 4be0ba495992cebcb5137927284e56c47041a1f7
Author: Micah Kornfield <em...@gmail.com>
AuthorDate: Mon Jun 24 20:13:20 2019 +0200

    ARROW-4885: [C++/Python] Enable Decimal parsing in CSV
    
    - Create a new Decimal128 converter in the CSV module (this copies some code for eliminating whitespace; please let me know if it should be factored out).
    - Add a Python unit test.
    - I filed ARROW-5699 to track two performance enhancements:
       *  Add an UnsafeAppend and use it on the Decimal128Builder.
       *  Avoid multiple string copies in Decimal128::FromString.
    
    Author: Micah Kornfield <em...@gmail.com>
    
    Closes #4660 from emkornfield/decimal_csv and squashes the following commits:
    
    9933cd94f <Micah Kornfield> fix style
    b54013079 <Micah Kornfield> make format
    5ef3c6cf1 <Micah Kornfield> address review feedback
    86e03f973 <Micah Kornfield> Decimal128 CSV parsing
---
 cpp/src/arrow/csv/converter-test.cc | 50 ++++++++++++++++++++--
 cpp/src/arrow/csv/converter.cc      | 84 ++++++++++++++++++++++++++++++-------
 python/pyarrow/tests/test_csv.py    | 11 +++--
 3 files changed, 122 insertions(+), 23 deletions(-)

diff --git a/cpp/src/arrow/csv/converter-test.cc b/cpp/src/arrow/csv/converter-test.cc
index 105131f..531f40b 100644
--- a/cpp/src/arrow/csv/converter-test.cc
+++ b/cpp/src/arrow/csv/converter-test.cc
@@ -30,6 +30,7 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/type.h"
+#include "arrow/util/decimal.h"
 
 namespace arrow {
 namespace csv {
@@ -312,10 +313,51 @@ TEST(TimestampConversion, CustomNulls) {
                                            {{true}, {false}, {false}}, options);
 }
 
-TEST(DecimalConversion, NotImplemented) {
-  std::shared_ptr<Converter> converter;
-  ASSERT_RAISES(NotImplemented,
-                Converter::Make(decimal(12, 3), ConvertOptions::Defaults(), &converter));
+Decimal128 Dec128(util::string_view value) {  // test helper: decimal literal -> value
+  Decimal128 dec;
+  int32_t scale = 0;  // out-param for FromString; not used afterwards
+  int32_t precision = 0;  // out-param for FromString; not used afterwards
+  DCHECK_OK(Decimal128::FromString(value, &dec, &precision, &scale));
+  return dec;
+}
+
+TEST(DecimalConversion, Basics) {  // input is rows, expected values are per column
+  AssertConversion<Decimal128Type, Decimal128>(
+      decimal(23, 2), {"12,34.5\n", "36.37,-1e5\n"},
+      {{Dec128("12.00"), Dec128("36.37")}, {Dec128("34.50"), Dec128("-100000.00")}});
+}
+
+TEST(DecimalConversion, Nulls) {  // empty cell -> null, "0." stays valid
+  AssertConversion<Decimal128Type, Decimal128>(
+      decimal(14, 3), {"1.5,0.\n", ",-1e3\n"},
+      {{Dec128("1.500"), Decimal128()}, {Decimal128(), Dec128("-1000.000")}},
+      {{true, false}, {true, true}});  // validity flags, per column like the values
+
+  AssertConversionAllNulls<Decimal128Type, Decimal128>(decimal(14, 2));
+}
+
+TEST(DecimalConversion, CustomNulls) {
+  auto options = ConvertOptions::Defaults();
+  options.null_values = {"xxx", "zzz"};  // both spellings must convert to null
+
+  AssertConversion<Decimal128Type, Decimal128>(
+      decimal(14, 3), {"1.5,xxx\n", "zzz,-1e3\n"},
+      {{Dec128("1.500"), Decimal128()}, {Decimal128(), Dec128("-1000.000")}},
+      {{true, false}, {false, true}}, options);  // false marks the "zzz"/"xxx" cells
+}
+
+TEST(DecimalConversion, Whitespace) {  // padded cells parse like unpadded ones
+  AssertConversion<Decimal128Type, Decimal128>(
+      decimal(5, 1), {" 12.00,34.5\n", " 0 ,-1e2 \n"},
+      {{Dec128("12.0"), Decimal128()}, {Dec128("34.5"), Dec128("-100.0")}});
+}
+
+TEST(DecimalConversion, OverflowFails) {  // inputs that must be rejected
+  AssertConversionError(decimal(5, 0), {"1e6,0\n"}, {0});  // 1e6 has 7 digits
+
+  AssertConversionError(decimal(5, 1), {"123.22\n"}, {0});  // 2 fractional digits
+  AssertConversionError(decimal(5, 1), {"12345.6\n"}, {0});  // 6 digits in total
+  AssertConversionError(decimal(5, 1), {"1.61\n"}, {0});  // 2 fractional digits
 }
 
 }  // namespace csv
diff --git a/cpp/src/arrow/csv/converter.cc b/cpp/src/arrow/csv/converter.cc
index 6304c55..336a3bb 100644
--- a/cpp/src/arrow/csv/converter.cc
+++ b/cpp/src/arrow/csv/converter.cc
@@ -29,6 +29,7 @@
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/decimal.h"
 #include "arrow/util/parsing.h"  // IWYU pragma: keep
 #include "arrow/util/trie.h"
 #include "arrow/util/utf8.h"
@@ -56,6 +57,28 @@ inline bool IsWhitespace(uint8_t c) {
   return c == ' ' || c == '\t';
 }
 
+// Shrinks the (*data_inout, *size_inout) span in place so that it excludes leading
+// and trailing whitespace; a whitespace-only span ends up with size 0.
+inline void TrimWhiteSpace(const uint8_t** data_inout, uint32_t* size_inout) {
+  const uint8_t*& data = *data_inout;
+  uint32_t& size = *size_inout;
+  // Trailing whitespace: only `size` shrinks, `data` is left untouched
+  if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[size - 1]))) {
+    const uint8_t* p = data + size - 1;
+    while (size > 0 && IsWhitespace(*p)) {  // size is tested before *p is read
+      --size;
+      --p;
+    }
+  }
+  // Leading whitespace: advance `data` and shrink `size` in lockstep
+  if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[0]))) {
+    while (size > 0 && IsWhitespace(*data)) {
+      --size;
+      ++data;
+    }
+  }
+}
+
 Status InitializeTrie(const std::vector<std::string>& inputs, Trie* trie) {
   TrieBuilder builder;
   for (const auto& s : inputs) {
@@ -280,22 +303,7 @@ Status NumericConverter<T>::Convert(const BlockParser& parser, int32_t col_index
       return Status::OK();
     }
     if (!std::is_same<BooleanType, T>::value) {
-      // Skip trailing whitespace
-      if (ARROW_PREDICT_TRUE(size > 0) &&
-          ARROW_PREDICT_FALSE(IsWhitespace(data[size - 1]))) {
-        const uint8_t* p = data + size - 1;
-        while (size > 0 && IsWhitespace(*p)) {
-          --size;
-          --p;
-        }
-      }
-      // Skip leading whitespace
-      if (ARROW_PREDICT_TRUE(size > 0) && ARROW_PREDICT_FALSE(IsWhitespace(data[0]))) {
-        while (size > 0 && IsWhitespace(*data)) {
-          --size;
-          ++data;
-        }
-      }
+      TrimWhiteSpace(&data, &size);
     }
     if (ARROW_PREDICT_FALSE(
             !converter(reinterpret_cast<const char*>(data), size, &value))) {
@@ -346,6 +354,49 @@ class TimestampConverter : public ConcreteConverter {
   }
 };
 
+/////////////////////////////////////////////////////////////////////////
+// Concrete Converter for Decimals (cells are parsed, then rescaled to type_'s scale)
+
+class DecimalConverter : public ConcreteConverter {
+ public:
+  using ConcreteConverter::ConcreteConverter;
+
+  Status Convert(const BlockParser& parser, int32_t col_index,
+                 std::shared_ptr<Array>* out) override {
+    Decimal128Builder builder(type_, pool_);
+
+    auto visit = [&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
+      if (IsNull(data, size, quoted)) {  // checked on the untrimmed cell
+        builder.UnsafeAppendNull();  // presumably covered by the Resize() below
+        return Status::OK();
+      }
+      TrimWhiteSpace(&data, &size);
+      Decimal128 decimal;
+      int32_t precision, scale;  // both are outputs of FromString for this cell
+      util::string_view view(reinterpret_cast<const char*>(data), size);
+      RETURN_NOT_OK(Decimal128::FromString(view, &decimal, &precision, &scale));
+      DecimalType& type = *internal::checked_cast<DecimalType*>(type_.get());
+      if (precision > type.precision()) {
+        return Status::Invalid("Error converting ", view, " to ", type_->ToString(),
+                               " precision not supported by type.");
+      }
+      if (scale != type.scale()) {  // NOTE(review): no precision re-check after rescale
+        Decimal128 scaled;
+        RETURN_NOT_OK(decimal.Rescale(scale, type.scale(), &scaled));
+        RETURN_NOT_OK(builder.Append(scaled.ToBytes()));
+      } else {
+        RETURN_NOT_OK(builder.Append(decimal.ToBytes()));
+      }
+      return Status::OK();
+    };
+    RETURN_NOT_OK(builder.Resize(parser.num_rows()));  // one slot per parsed row
+    RETURN_NOT_OK(parser.VisitColumn(col_index, visit));
+    RETURN_NOT_OK(builder.Finish(out));
+
+    return Status::OK();
+  }
+};
+
 }  // namespace
 
 /////////////////////////////////////////////////////////////////////////
@@ -381,6 +432,7 @@ Status Converter::Make(const std::shared_ptr<DataType>& type,
     CONVERTER_CASE(Type::TIMESTAMP, TimestampConverter)
     CONVERTER_CASE(Type::BINARY, (VarSizeBinaryConverter<BinaryType, false>))
     CONVERTER_CASE(Type::FIXED_SIZE_BINARY, FixedSizeBinaryConverter)
+    CONVERTER_CASE(Type::DECIMAL, DecimalConverter)
 
     case Type::STRING:
       if (options.check_utf8) {
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index e3c1c37..df4d0a5 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -17,6 +17,7 @@
 
 import bz2
 from datetime import datetime
+from decimal import Decimal
 import gzip
 import io
 import itertools
@@ -342,18 +343,21 @@ class BaseTestCSVRead:
         opts = ConvertOptions(column_types={'b': 'float32',
                                             'c': 'string',
                                             'd': 'boolean',
+                                            'e': pa.decimal128(11, 2),
                                             'zz': 'null'})
-        rows = b"a,b,c,d\n1,2,3,true\n4,-5,6,false\n"
+        rows = b"a,b,c,d,e\n1,2,3,true,1.0\n4,-5,6,false,0\n"
         table = self.read_bytes(rows, convert_options=opts)
         schema = pa.schema([('a', pa.int64()),
                             ('b', pa.float32()),
                             ('c', pa.string()),
-                            ('d', pa.bool_())])
+                            ('d', pa.bool_()),
+                            ('e', pa.decimal128(11, 2))])
         expected = {
             'a': [1, 4],
             'b': [2.0, -5.0],
             'c': ["3", "6"],
             'd': [True, False],
+            'e': [Decimal("1.00"), Decimal("0.00")]
             }
         assert table.schema == schema
         assert table.to_pydict() == expected
@@ -362,12 +366,13 @@ class BaseTestCSVRead:
             column_types=pa.schema([('b', pa.float32()),
                                     ('c', pa.string()),
                                     ('d', pa.bool_()),
+                                    ('e', pa.decimal128(11, 2)),
                                     ('zz', pa.bool_())]))
         table = self.read_bytes(rows, convert_options=opts)
         assert table.schema == schema
         assert table.to_pydict() == expected
         # One of the columns in column_types fails converting
-        rows = b"a,b,c,d\n1,XXX,3,true\n4,-5,6,false\n"
+        rows = b"a,b,c,d,e\n1,XXX,3,true,5\n4,-5,6,false,7\n"
         with pytest.raises(pa.ArrowInvalid) as exc:
             self.read_bytes(rows, convert_options=opts)
         err = str(exc.value)