Posted to commits@arrow.apache.org by ko...@apache.org on 2024/01/06 07:57:51 UTC

(arrow) branch main updated: GH-39163: [C++] Add missing data copy in StreamDecoder::Consume(data) (#39164)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ab7a18fdc GH-39163: [C++] Add missing data copy in StreamDecoder::Consume(data) (#39164)
6ab7a18fdc is described below

commit 6ab7a18fdc4cc3a48c1f40da3b2fedd58f5bfc23
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Sat Jan 6 16:57:44 2024 +0900

    GH-39163: [C++] Add missing data copy in StreamDecoder::Consume(data) (#39164)
    
    ### Rationale for this change
    
    We need to copy the data for a metadata message because it may be used by subsequent `Consume(data)` calls, and we can't assume that the given `data` is still valid when those later calls happen.
    
    We also need to copy buffered `data` for the same reason: it is used by subsequent `Consume(data)` calls.
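
    A minimal sketch of the failure mode, for illustration only (`FeedStream` and the 4 KiB chunk size are not part of this patch):

        #include <array>
        #include <cstdint>
        #include <istream>

        #include "arrow/ipc/reader.h"
        #include "arrow/status.h"

        // Feed an IPC stream to the decoder through a stack buffer that is
        // overwritten on every iteration, so each chunk is only valid for the
        // duration of the Consume() call that receives it.
        arrow::Status FeedStream(arrow::ipc::StreamDecoder* decoder, std::istream& input) {
          std::array<uint8_t, 4096> chunk;
          while (input) {
            input.read(reinterpret_cast<char*>(chunk.data()), chunk.size());
            const auto n = static_cast<int64_t>(input.gcount());
            if (n == 0) break;
            // Without the copies added by this patch, the decoder could keep a
            // non-owning arrow::Buffer that still pointed into `chunk` after
            // Consume() returned; the next read() would then overwrite bytes
            // the decoder had not finished decoding.
            ARROW_RETURN_NOT_OK(decoder->Consume(chunk.data(), n));
          }
          return arrow::Status::OK();
        }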
    
    ### What changes are included in this PR?
    
    * Add missing copies.
    * Clean up the existing buffer copy code.
    * Change tests to use ephemeral `data` to detect this case.
    * Add `copy_record_batch` option to `CollectListener` to copy decoded record batches.
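
    For context, a minimal end-to-end use of `Consume(data)` with the stock `CollectListener` might look like the sketch below (illustrative only, not part of this patch). The decoder now copies metadata and any buffered leftover bytes internally, but decoded record batches still reference the caller's memory, so either keep the input alive while the batches are in use or copy them in the listener, as the test's `CopyCollectListener` does.

        #include <memory>

        #include "arrow/buffer.h"
        #include "arrow/ipc/options.h"
        #include "arrow/ipc/reader.h"
        #include "arrow/status.h"

        // Decode a complete IPC stream held in `ipc_stream` and report how many
        // record batches it contained. The collected batches reference
        // `ipc_stream`, which stays alive for the whole function.
        arrow::Status DecodeStream(std::shared_ptr<arrow::Buffer> ipc_stream, int* num_batches) {
          auto listener = std::make_shared<arrow::ipc::CollectListener>();
          arrow::ipc::StreamDecoder decoder(listener, arrow::ipc::IpcReadOptions::Defaults());
          ARROW_RETURN_NOT_OK(decoder.Consume(ipc_stream->data(), ipc_stream->size()));
          *num_batches = static_cast<int>(listener->record_batches().size());
          return arrow::Status::OK();
        }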
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Yes.
    
    * Closes: #39163
    
    Authored-by: Sutou Kouhei <ko...@clear-code.com>
    Signed-off-by: Sutou Kouhei <ko...@clear-code.com>
---
 cpp/src/arrow/ipc/message.cc         | 37 +++++++++++++++++++++--------
 cpp/src/arrow/ipc/read_write_test.cc | 46 ++++++++++++++++++++++++++++++++----
 2 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc
index 36754518d2..fbcd6f139b 100644
--- a/cpp/src/arrow/ipc/message.cc
+++ b/cpp/src/arrow/ipc/message.cc
@@ -626,10 +626,24 @@ class MessageDecoder::MessageDecoderImpl {
             RETURN_NOT_OK(ConsumeMetadataLengthData(data, next_required_size_));
             break;
           case State::METADATA: {
-            auto buffer = std::make_shared<Buffer>(data, next_required_size_);
+            // We need to copy metadata because it's used in
+            // ConsumeBody(). ConsumeBody() may be called from another
+            // ConsumeData(). We can't assume that the given data for
+            // the current ConsumeData() call is still valid in the
+            // next ConsumeData() call. So we need to copy metadata
+            // here.
+            ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> buffer,
+                                  AllocateBuffer(next_required_size_, pool_));
+            memcpy(buffer->mutable_data(), data, next_required_size_);
             RETURN_NOT_OK(ConsumeMetadataBuffer(buffer));
           } break;
           case State::BODY: {
+            // We don't need to copy the given data for body because
+            // we can assume that a decoded record batch should be
+            // valid only in a listener_->OnMessageDecoded() call. If
+            // the passed message is needed to be valid after the
+            // call, it's a listener_'s responsibility. The listener_
+            // may copy the data for it.
             auto buffer = std::make_shared<Buffer>(data, next_required_size_);
             RETURN_NOT_OK(ConsumeBodyBuffer(buffer));
           } break;
@@ -645,7 +659,12 @@ class MessageDecoder::MessageDecoderImpl {
       return Status::OK();
     }
 
-    chunks_.push_back(std::make_shared<Buffer>(data, size));
+    // We need to copy unused data because the given data for the
+    // current ConsumeData() call may be invalid in the next
+    // ConsumeData() call.
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> chunk, AllocateBuffer(size, pool_));
+    memcpy(chunk->mutable_data(), data, size);
+    chunks_.push_back(std::move(chunk));
     buffered_size_ += size;
     return ConsumeChunks();
   }
@@ -830,8 +849,7 @@ class MessageDecoder::MessageDecoderImpl {
       }
       buffered_size_ -= next_required_size_;
     } else {
-      ARROW_ASSIGN_OR_RAISE(auto metadata, AllocateBuffer(next_required_size_, pool_));
-      metadata_ = std::shared_ptr<Buffer>(metadata.release());
+      ARROW_ASSIGN_OR_RAISE(metadata_, AllocateBuffer(next_required_size_, pool_));
       RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, metadata_->mutable_data()));
     }
     return ConsumeMetadata();
@@ -846,9 +864,8 @@ class MessageDecoder::MessageDecoderImpl {
     next_required_size_ = skip_body_ ? 0 : body_length;
     RETURN_NOT_OK(listener_->OnBody());
     if (next_required_size_ == 0) {
-      ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_));
-      std::shared_ptr<Buffer> shared_body(body.release());
-      return ConsumeBody(&shared_body);
+      auto body = std::make_shared<Buffer>(nullptr, 0);
+      return ConsumeBody(&body);
     } else {
       return Status::OK();
     }
@@ -872,10 +889,10 @@ class MessageDecoder::MessageDecoderImpl {
       buffered_size_ -= used_size;
       return Status::OK();
     } else {
-      ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(next_required_size_, pool_));
+      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> body,
+                            AllocateBuffer(next_required_size_, pool_));
       RETURN_NOT_OK(ConsumeDataChunks(next_required_size_, body->mutable_data()));
-      std::shared_ptr<Buffer> shared_body(body.release());
-      return ConsumeBody(&shared_body);
+      return ConsumeBody(&body);
     }
   }
 
diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc
index 17c4c5636d..bd2c2b716d 100644
--- a/cpp/src/arrow/ipc/read_write_test.cc
+++ b/cpp/src/arrow/ipc/read_write_test.cc
@@ -1330,11 +1330,44 @@ struct StreamWriterHelper {
   std::shared_ptr<RecordBatchWriter> writer_;
 };
 
+class CopyCollectListener : public CollectListener {
+ public:
+  CopyCollectListener() : CollectListener() {}
+
+  Status OnRecordBatchWithMetadataDecoded(
+      RecordBatchWithMetadata record_batch_with_metadata) override {
+    auto& record_batch = record_batch_with_metadata.batch;
+    for (auto column_data : record_batch->column_data()) {
+      ARROW_RETURN_NOT_OK(CopyArrayData(column_data));
+    }
+    return CollectListener::OnRecordBatchWithMetadataDecoded(record_batch_with_metadata);
+  }
+
+ private:
+  Status CopyArrayData(std::shared_ptr<ArrayData> data) {
+    auto& buffers = data->buffers;
+    for (size_t i = 0; i < buffers.size(); ++i) {
+      auto& buffer = buffers[i];
+      if (!buffer) {
+        continue;
+      }
+      ARROW_ASSIGN_OR_RAISE(buffers[i], Buffer::Copy(buffer, buffer->memory_manager()));
+    }
+    for (auto child_data : data->child_data) {
+      ARROW_RETURN_NOT_OK(CopyArrayData(child_data));
+    }
+    if (data->dictionary) {
+      ARROW_RETURN_NOT_OK(CopyArrayData(data->dictionary));
+    }
+    return Status::OK();
+  }
+};
+
 struct StreamDecoderWriterHelper : public StreamWriterHelper {
   Status ReadBatches(const IpcReadOptions& options, RecordBatchVector* out_batches,
                      ReadStats* out_stats = nullptr,
                      MetadataVector* out_metadata_list = nullptr) override {
-    auto listener = std::make_shared<CollectListener>();
+    auto listener = std::make_shared<CopyCollectListener>();
     StreamDecoder decoder(listener, options);
     RETURN_NOT_OK(DoConsume(&decoder));
     *out_batches = listener->record_batches();
@@ -1358,7 +1391,10 @@ struct StreamDecoderWriterHelper : public StreamWriterHelper {
 
 struct StreamDecoderDataWriterHelper : public StreamDecoderWriterHelper {
   Status DoConsume(StreamDecoder* decoder) override {
-    return decoder->Consume(buffer_->data(), buffer_->size());
+    // This data is valid only in this function.
+    ARROW_ASSIGN_OR_RAISE(auto temporary_buffer,
+                          Buffer::Copy(buffer_, arrow::default_cpu_memory_manager()));
+    return decoder->Consume(temporary_buffer->data(), temporary_buffer->size());
   }
 };
 
@@ -1369,7 +1405,9 @@ struct StreamDecoderBufferWriterHelper : public StreamDecoderWriterHelper {
 struct StreamDecoderSmallChunksWriterHelper : public StreamDecoderWriterHelper {
   Status DoConsume(StreamDecoder* decoder) override {
     for (int64_t offset = 0; offset < buffer_->size() - 1; ++offset) {
-      RETURN_NOT_OK(decoder->Consume(buffer_->data() + offset, 1));
+      // This data is valid only in this block.
+      ARROW_ASSIGN_OR_RAISE(auto temporary_buffer, buffer_->CopySlice(offset, 1));
+      RETURN_NOT_OK(decoder->Consume(temporary_buffer->data(), temporary_buffer->size()));
     }
     return Status::OK();
   }
@@ -2172,7 +2210,6 @@ TEST(TestRecordBatchStreamReader, MalformedInput) {
   ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader));
 }
 
-namespace {
 class EndlessCollectListener : public CollectListener {
  public:
   EndlessCollectListener() : CollectListener(), decoder_(nullptr) {}
@@ -2184,7 +2221,6 @@ class EndlessCollectListener : public CollectListener {
  private:
   StreamDecoder* decoder_;
 };
-};  // namespace
 
 TEST(TestStreamDecoder, Reset) {
   auto listener = std::make_shared<EndlessCollectListener>();