Posted to commits@arrow.apache.org by we...@apache.org on 2019/02/27 18:29:09 UTC

[arrow] branch master updated: ARROW-2392: [C++] Check schema compatibility when writing a RecordBatch

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a084b7  ARROW-2392: [C++] Check schema compatibility when writing a RecordBatch
4a084b7 is described below

commit 4a084b79f9ab5c1f73658a5e5ff3581f5b875c42
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Wed Feb 27 12:28:50 2019 -0600

    ARROW-2392: [C++] Check schema compatibility when writing a RecordBatch
    
    Author: Antoine Pitrou <an...@python.org>
    
    Closes #3762 from pitrou/ARROW-2392-record-batch-writer-invalid-schema and squashes the following commits:
    
    eb6b85ae <Antoine Pitrou> ARROW-2392:  Check schema compatibility when writing a RecordBatch
---
 cpp/src/arrow/ipc/read-write-test.cc     | 263 +++++++++++++++++--------------
 cpp/src/arrow/ipc/writer.cc              |   4 +
 cpp/src/arrow/util/key_value_metadata.cc |   5 +
 cpp/src/arrow/util/key_value_metadata.h  |   7 +
 4 files changed, 158 insertions(+), 121 deletions(-)

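For context, here is how the new check surfaces to a caller of the IPC writer API. This is a minimal sketch, not part of the patch; the helper name WriteBatches and the parameter names are illustrative only.

    #include "arrow/api.h"
    #include "arrow/io/interfaces.h"
    #include "arrow/ipc/writer.h"

    // Sketch only: after this change, writing a batch whose schema does not
    // match the schema the writer was opened with fails up front with
    // Status::Invalid instead of producing a malformed stream or file.
    arrow::Status WriteBatches(const std::shared_ptr<arrow::Schema>& schema,
                               const std::shared_ptr<arrow::RecordBatch>& batch,
                               arrow::io::OutputStream* sink) {
      std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
      RETURN_NOT_OK(arrow::ipc::RecordBatchStreamWriter::Open(sink, schema, &writer));
      // If batch->schema() differs from `schema` (ignoring key-value metadata),
      // this call now returns Status::Invalid.
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
      return writer->Close();
    }
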
diff --git a/cpp/src/arrow/ipc/read-write-test.cc b/cpp/src/arrow/ipc/read-write-test.cc
index 867fd0e..6f4da28 100644
--- a/cpp/src/arrow/ipc/read-write-test.cc
+++ b/cpp/src/arrow/ipc/read-write-test.cc
@@ -566,121 +566,183 @@ TEST_F(RecursionLimits, StressLimit) {
 }
 #endif  // !defined(_WIN32) || defined(NDEBUG)
 
-class TestFileFormat : public ::testing::TestWithParam<MakeRecordBatch*> {
- public:
-  void SetUp() {
-    pool_ = default_memory_pool();
-    ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_));
+struct FileWriterHelper {
+  Status Init(const std::shared_ptr<Schema>& schema) {
+    num_batches_written_ = 0;
+
+    RETURN_NOT_OK(AllocateResizableBuffer(0, &buffer_));
     sink_.reset(new io::BufferOutputStream(buffer_));
-  }
-  void TearDown() {}
 
-  Status RoundTripHelper(const BatchVector& in_batches, BatchVector* out_batches) {
-    // Write the file
-    std::shared_ptr<RecordBatchWriter> writer;
-    RETURN_NOT_OK(
-        RecordBatchFileWriter::Open(sink_.get(), in_batches[0]->schema(), &writer));
+    return RecordBatchFileWriter::Open(sink_.get(), schema, &writer_);
+  }
 
-    const int num_batches = static_cast<int>(in_batches.size());
+  Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
+    RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
+    num_batches_written_++;
+    return Status::OK();
+  }
 
-    for (const auto& batch : in_batches) {
-      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
-    }
-    RETURN_NOT_OK(writer->Close());
+  Status Finish() {
+    RETURN_NOT_OK(writer_->Close());
     RETURN_NOT_OK(sink_->Close());
-
     // Current offset into stream is the end of the file
-    int64_t footer_offset;
-    RETURN_NOT_OK(sink_->Tell(&footer_offset));
+    return sink_->Tell(&footer_offset_);
+  }
 
-    // Open the file
+  Status ReadBatches(BatchVector* out_batches) {
     auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
     std::shared_ptr<RecordBatchFileReader> reader;
-    RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset, &reader));
+    RETURN_NOT_OK(RecordBatchFileReader::Open(buf_reader.get(), footer_offset_, &reader));
 
-    EXPECT_EQ(num_batches, reader->num_record_batches());
-    for (int i = 0; i < num_batches; ++i) {
+    EXPECT_EQ(num_batches_written_, reader->num_record_batches());
+    for (int i = 0; i < num_batches_written_; ++i) {
       std::shared_ptr<RecordBatch> chunk;
       RETURN_NOT_OK(reader->ReadRecordBatch(i, &chunk));
-      out_batches->emplace_back(chunk);
+      out_batches->push_back(chunk);
     }
 
     return Status::OK();
   }
 
- protected:
-  MemoryPool* pool_;
-
-  std::unique_ptr<io::BufferOutputStream> sink_;
   std::shared_ptr<ResizableBuffer> buffer_;
+  std::unique_ptr<io::BufferOutputStream> sink_;
+  std::shared_ptr<RecordBatchWriter> writer_;
+  int num_batches_written_;
+  int64_t footer_offset_;
 };
 
-TEST_P(TestFileFormat, RoundTrip) {
-  std::shared_ptr<RecordBatch> batch1;
-  std::shared_ptr<RecordBatch> batch2;
-  ASSERT_OK((*GetParam())(&batch1));  // NOLINT clang-tidy gtest issue
-  ASSERT_OK((*GetParam())(&batch2));  // NOLINT clang-tidy gtest issue
+struct StreamWriterHelper {
+  Status Init(const std::shared_ptr<Schema>& schema) {
+    RETURN_NOT_OK(AllocateResizableBuffer(0, &buffer_));
+    sink_.reset(new io::BufferOutputStream(buffer_));
 
-  BatchVector in_batches = {batch1, batch2};
-  BatchVector out_batches;
+    return RecordBatchStreamWriter::Open(sink_.get(), schema, &writer_);
+  }
 
-  ASSERT_OK(RoundTripHelper(in_batches, &out_batches));
+  Status WriteBatch(const std::shared_ptr<RecordBatch>& batch) {
+    RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
+    return Status::OK();
+  }
 
-  // Compare batches
-  for (size_t i = 0; i < in_batches.size(); ++i) {
-    CompareBatch(*in_batches[i], *out_batches[i]);
+  Status Finish() {
+    RETURN_NOT_OK(writer_->Close());
+    return sink_->Close();
   }
-}
 
-class TestStreamFormat : public ::testing::TestWithParam<MakeRecordBatch*> {
- public:
-  void SetUp() {
-    pool_ = default_memory_pool();
-    ASSERT_OK(AllocateResizableBuffer(pool_, 0, &buffer_));
-    sink_.reset(new io::BufferOutputStream(buffer_));
+  Status ReadBatches(BatchVector* out_batches) {
+    auto buf_reader = std::make_shared<io::BufferReader>(buffer_);
+    std::shared_ptr<RecordBatchReader> reader;
+    RETURN_NOT_OK(RecordBatchStreamReader::Open(buf_reader, &reader));
+    return reader->ReadAll(out_batches);
   }
-  void TearDown() {}
 
-  Status RoundTripHelper(const BatchVector& batches, BatchVector* out_batches) {
-    // Write the file
-    std::shared_ptr<RecordBatchWriter> writer;
-    RETURN_NOT_OK(
-        RecordBatchStreamWriter::Open(sink_.get(), batches[0]->schema(), &writer));
+  std::shared_ptr<ResizableBuffer> buffer_;
+  std::unique_ptr<io::BufferOutputStream> sink_;
+  std::shared_ptr<RecordBatchWriter> writer_;
+};
+
+// Parameterized mixin with tests for RecordBatchStreamWriter / RecordBatchFileWriter
+
+template <class WriterHelperType>
+class ReaderWriterMixin {
+ public:
+  using WriterHelper = WriterHelperType;
+
+  // Check simple RecordBatch roundtripping
+  template <typename Param>
+  void TestRoundTrip(Param&& param) {
+    std::shared_ptr<RecordBatch> batch1;
+    std::shared_ptr<RecordBatch> batch2;
+    ASSERT_OK(param(&batch1));  // NOLINT clang-tidy gtest issue
+    ASSERT_OK(param(&batch2));  // NOLINT clang-tidy gtest issue
 
-    for (const auto& batch : batches) {
-      RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
+    BatchVector in_batches = {batch1, batch2};
+    BatchVector out_batches;
+
+    ASSERT_OK(RoundTripHelper(in_batches, &out_batches));
+    ASSERT_EQ(out_batches.size(), in_batches.size());
+
+    // Compare batches
+    for (size_t i = 0; i < in_batches.size(); ++i) {
+      CompareBatch(*in_batches[i], *out_batches[i]);
     }
-    RETURN_NOT_OK(writer->Close());
-    RETURN_NOT_OK(sink_->Close());
+  }
 
-    // Open the file
-    io::BufferReader buf_reader(buffer_);
+  void TestDictionaryRoundtrip() {
+    std::shared_ptr<RecordBatch> batch;
+    ASSERT_OK(MakeDictionary(&batch));
 
-    std::shared_ptr<RecordBatchReader> reader;
-    RETURN_NOT_OK(RecordBatchStreamReader::Open(&buf_reader, &reader));
-    return reader->ReadAll(out_batches);
+    BatchVector out_batches;
+    ASSERT_OK(RoundTripHelper({batch}, &out_batches));
+    ASSERT_EQ(out_batches.size(), 1);
+
+    CheckBatchDictionaries(*out_batches[0]);
   }
 
- protected:
-  MemoryPool* pool_;
+  void TestWriteDifferentSchema() {
+    // Test writing batches with a different schema than the RecordBatchWriter
+    // was initialized with.
+    std::shared_ptr<RecordBatch> batch_ints, batch_bools;
+    ASSERT_OK(MakeIntRecordBatch(&batch_ints));
+    ASSERT_OK(MakeBooleanBatch(&batch_bools));
+
+    std::shared_ptr<Schema> schema = batch_bools->schema();
+    ASSERT_FALSE(schema->HasMetadata());
+    schema = schema->AddMetadata(key_value_metadata({"some_key"}, {"some_value"}));
+
+    WriterHelper writer_helper;
+    ASSERT_OK(writer_helper.Init(schema));
+    // Writing a record batch with a different schema
+    ASSERT_RAISES(Invalid, writer_helper.WriteBatch(batch_ints));
+    // Writing a record batch with the same schema (except metadata)
+    ASSERT_OK(writer_helper.WriteBatch(batch_bools));
+    ASSERT_OK(writer_helper.Finish());
+
+    // The single successful batch can be read again
+    BatchVector out_batches;
+    ASSERT_OK(writer_helper.ReadBatches(&out_batches));
+    ASSERT_EQ(out_batches.size(), 1);
+    CompareBatch(*out_batches[0], *batch_bools, false /* compare_metadata */);
+    // Metadata from the RecordBatchWriter initialization schema was kept
+    ASSERT_TRUE(out_batches[0]->schema()->Equals(*schema));
+  }
 
-  std::unique_ptr<io::BufferOutputStream> sink_;
-  std::shared_ptr<ResizableBuffer> buffer_;
-};
+ private:
+  Status RoundTripHelper(const BatchVector& in_batches, BatchVector* out_batches) {
+    WriterHelper writer_helper;
+    RETURN_NOT_OK(writer_helper.Init(in_batches[0]->schema()));
+    for (const auto& batch : in_batches) {
+      RETURN_NOT_OK(writer_helper.WriteBatch(batch));
+    }
+    RETURN_NOT_OK(writer_helper.Finish());
+    return writer_helper.ReadBatches(out_batches);
+  }
 
-TEST_P(TestStreamFormat, RoundTrip) {
-  std::shared_ptr<RecordBatch> batch;
-  ASSERT_OK((*GetParam())(&batch));  // NOLINT clang-tidy gtest issue
+  void CheckBatchDictionaries(const RecordBatch& batch) {
+    // Check that dictionaries that should be the same are the same
+    auto schema = batch.schema();
 
-  BatchVector out_batches;
+    const auto& t0 = checked_cast<const DictionaryType&>(*schema->field(0)->type());
+    const auto& t1 = checked_cast<const DictionaryType&>(*schema->field(1)->type());
 
-  ASSERT_OK(RoundTripHelper({batch, batch, batch}, &out_batches));
+    ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get());
 
-  // Compare batches. Same
-  for (size_t i = 0; i < out_batches.size(); ++i) {
-    CompareBatch(*batch, *out_batches[i]);
+    // Same dictionary used for list values
+    const auto& t3 = checked_cast<const ListType&>(*schema->field(3)->type());
+    const auto& t3_value = checked_cast<const DictionaryType&>(*t3.value_type());
+    ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get());
   }
-}
+};
+
+class TestFileFormat : public ReaderWriterMixin<FileWriterHelper>,
+                       public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+class TestStreamFormat : public ReaderWriterMixin<StreamWriterHelper>,
+                         public ::testing::TestWithParam<MakeRecordBatch*> {};
+
+TEST_P(TestFileFormat, RoundTrip) { TestRoundTrip(*GetParam()); }
+
+TEST_P(TestStreamFormat, RoundTrip) { TestRoundTrip(*GetParam()); }
 
 INSTANTIATE_TEST_CASE_P(GenericIpcRoundTripTests, TestIpcRoundTrip, BATCH_CASES());
 INSTANTIATE_TEST_CASE_P(FileRoundTripTests, TestFileFormat, BATCH_CASES());
@@ -719,54 +781,13 @@ TEST_F(TestIpcRoundTrip, LargeRecordBatch) {
 }
 #endif
 
-void CheckBatchDictionaries(const RecordBatch& batch) {
-  // Check that dictionaries that should be the same are the same
-  auto schema = batch.schema();
-
-  const auto& t0 = checked_cast<const DictionaryType&>(*schema->field(0)->type());
-  const auto& t1 = checked_cast<const DictionaryType&>(*schema->field(1)->type());
-
-  ASSERT_EQ(t0.dictionary().get(), t1.dictionary().get());
-
-  // Same dictionary used for list values
-  const auto& t3 = checked_cast<const ListType&>(*schema->field(3)->type());
-  const auto& t3_value = checked_cast<const DictionaryType&>(*t3.value_type());
-  ASSERT_EQ(t0.dictionary().get(), t3_value.dictionary().get());
-}
-
-TEST_F(TestStreamFormat, DictionaryRoundTrip) {
-  std::shared_ptr<RecordBatch> batch;
-  ASSERT_OK(MakeDictionary(&batch));
-
-  BatchVector out_batches;
-  ASSERT_OK(RoundTripHelper({batch}, &out_batches));
-
-  CheckBatchDictionaries(*out_batches[0]);
-}
+TEST_F(TestStreamFormat, DictionaryRoundTrip) { TestDictionaryRoundtrip(); }
 
-TEST_F(TestStreamFormat, WriteTable) {
-  std::shared_ptr<RecordBatch> b1, b2, b3;
-  ASSERT_OK(MakeIntRecordBatch(&b1));
-  ASSERT_OK(MakeIntRecordBatch(&b2));
-  ASSERT_OK(MakeIntRecordBatch(&b3));
+TEST_F(TestFileFormat, DictionaryRoundTrip) { TestDictionaryRoundtrip(); }
 
-  BatchVector out_batches;
-  ASSERT_OK(RoundTripHelper({b1, b2, b3}, &out_batches));
-
-  ASSERT_TRUE(b1->Equals(*out_batches[0]));
-  ASSERT_TRUE(b2->Equals(*out_batches[1]));
-  ASSERT_TRUE(b3->Equals(*out_batches[2]));
-}
+TEST_F(TestStreamFormat, DifferentSchema) { TestWriteDifferentSchema(); }
 
-TEST_F(TestFileFormat, DictionaryRoundTrip) {
-  std::shared_ptr<RecordBatch> batch;
-  ASSERT_OK(MakeDictionary(&batch));
-
-  BatchVector out_batches;
-  ASSERT_OK(RoundTripHelper({batch}, &out_batches));
-
-  CheckBatchDictionaries(*out_batches[0]);
-}
+TEST_F(TestFileFormat, DifferentSchema) { TestWriteDifferentSchema(); }
 
 class TestTensorRoundTrip : public ::testing::Test, public IpcTestFixture {
  public:
diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index c5a585a..ba99390 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -995,6 +995,10 @@ class RecordBatchStreamWriter::RecordBatchStreamWriterImpl : public StreamBookKe
 
   Status WriteRecordBatch(const RecordBatch& batch, bool allow_64bit) {
     // Push an empty FileBlock. Can be written in the footer later
+    if (!batch.schema()->Equals(*schema_, false /* check_metadata */)) {
+      return Status::Invalid("Tried to write record batch with different schema");
+    }
+
     record_batches_.push_back({0, 0, 0});
     return WriteRecordBatch(batch, allow_64bit,
                             &record_batches_[record_batches_.size() - 1]);
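
The schema comparison above passes check_metadata = false, so batches whose schemas differ only in key-value metadata are still accepted. A quick sketch of that distinction (not part of the patch; the function name and the field/key names are illustrative):

    #include "arrow/api.h"

    void SchemaMetadataExample() {
      auto s1 = arrow::schema({arrow::field("f0", arrow::int32())});
      auto s2 = s1->AddMetadata(arrow::key_value_metadata({"k"}, {"v"}));
      // s1->Equals(*s2) is false because the metadata differs, but
      // s1->Equals(*s2, false /* check_metadata */) is true, so a writer
      // opened with s2 still accepts record batches carrying s1.
    }
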
diff --git a/cpp/src/arrow/util/key_value_metadata.cc b/cpp/src/arrow/util/key_value_metadata.cc
index 73e3129..0b22403 100644
--- a/cpp/src/arrow/util/key_value_metadata.cc
+++ b/cpp/src/arrow/util/key_value_metadata.cc
@@ -140,4 +140,9 @@ std::shared_ptr<KeyValueMetadata> key_value_metadata(
   return std::make_shared<KeyValueMetadata>(pairs);
 }
 
+std::shared_ptr<KeyValueMetadata> key_value_metadata(
+    const std::vector<std::string>& keys, const std::vector<std::string>& values) {
+  return std::make_shared<KeyValueMetadata>(keys, values);
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/key_value_metadata.h b/cpp/src/arrow/util/key_value_metadata.h
index 5759a85..2820c98 100644
--- a/cpp/src/arrow/util/key_value_metadata.h
+++ b/cpp/src/arrow/util/key_value_metadata.h
@@ -69,6 +69,13 @@ class ARROW_EXPORT KeyValueMetadata {
 std::shared_ptr<KeyValueMetadata> ARROW_EXPORT
 key_value_metadata(const std::unordered_map<std::string, std::string>& pairs);
 
+/// \brief Create a KeyValueMetadata instance
+///
+/// \param keys sequence of metadata keys
+/// \param values sequence of corresponding metadata values
+std::shared_ptr<KeyValueMetadata> ARROW_EXPORT key_value_metadata(
+    const std::vector<std::string>& keys, const std::vector<std::string>& values);
+
 }  // namespace arrow
 
 #endif  //  ARROW_UTIL_KEY_VALUE_METADATA_H
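
For reference, the new key_value_metadata() overload declared above takes parallel key and value vectors, as the added test does. A usage sketch (not part of the patch; the function name is illustrative):

    #include <memory>
    #include "arrow/util/key_value_metadata.h"

    std::shared_ptr<arrow::KeyValueMetadata> MakeExampleMetadata() {
      // Keys and values are given as two parallel vectors.
      return arrow::key_value_metadata({"some_key", "other_key"},
                                       {"some_value", "other_value"});
    }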