Posted to commits@arrow.apache.org by bk...@apache.org on 2022/05/16 15:32:35 UTC

[arrow] branch master updated: ARROW-16525: [C++] Tee node not properly marking node finished

This is an automated email from the ASF dual-hosted git repository.

bkietz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d040fb568c ARROW-16525: [C++] Tee node not properly marking node finished
d040fb568c is described below

commit d040fb568c8f7567e2d7c4e535f461743832003f
Author: Weston Pace <we...@gmail.com>
AuthorDate: Mon May 16 11:32:24 2022 -0400

    ARROW-16525: [C++] Tee node not properly marking node finished
    
    Closes #13117 from westonpace/feature/ARROW-16525--tee-node-not-marking-finished
    
    Lead-authored-by: Weston Pace <we...@gmail.com>
    Signed-off-by: Benjamin Kietzman <be...@gmail.com>
---
 cpp/src/arrow/compute/exec/exec_plan.h    |   2 +-
 cpp/src/arrow/compute/exec/sink_node.cc   |   2 +-
 cpp/src/arrow/compute/exec/source_node.cc |   2 +-
 cpp/src/arrow/compute/exec/test_util.cc   |   6 ++
 cpp/src/arrow/compute/exec/test_util.h    |   3 +
 cpp/src/arrow/compute/exec/util.cc        |   2 +-
 cpp/src/arrow/dataset/file_base.cc        |  10 +++
 cpp/src/arrow/dataset/file_test.cc        | 107 ++++++++++++++++++++++++++++++
 8 files changed, 130 insertions(+), 4 deletions(-)
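
In brief: TeeNode (in cpp/src/arrow/dataset/file_base.cc below) derives from
compute::MapNode, but MapNode::Finish was not virtual, so TeeNode could not
defer marking itself finished until its dataset writer had flushed, and a plan
could report completion while file writes were still in flight. The patch makes
Finish virtual, overrides it in TeeNode, and adds round-trip tests. A minimal
sketch of the override pattern (MyAsyncWriter and MyTeeLikeNode are
illustrative stand-ins, not Arrow APIs; arrow namespaces assumed):

    // Hypothetical writer with an asynchronous Finish(), standing in for
    // the dataset writer that TeeNode owns.
    class MyAsyncWriter {
     public:
      Future<> Finish();
    };

    class MyTeeLikeNode : public compute::MapNode {
      // finish_st carries any upstream error; resolve the node's finished_
      // future only once our own asynchronous writer has flushed as well.
      void Finish(Status finish_st) override {
        writer_->Finish().AddCallback([this, finish_st](const Status& st) {
          finished_.MarkFinished(finish_st & st);
        });
      }
      std::shared_ptr<MyAsyncWriter> writer_;
    };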

diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index c20dc0d048..be2f23ad24 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -316,7 +316,7 @@ class ARROW_EXPORT MapNode : public ExecNode {
  protected:
   void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);
 
-  void Finish(Status finish_st = Status::OK());
+  virtual void Finish(Status finish_st = Status::OK());
 
  protected:
   // Counter for the number of batches received
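
Making MapNode::Finish virtual is the crux of the fix: the base class invokes
Finish() on itself once its input is exhausted, and without virtual dispatch a
derived node's version is never reached. A distilled example of the failure
mode (toy types, not Arrow code):

    struct Base {
      void Finish() { /* marks finished immediately */ }
      void Done() { Finish(); }  // statically bound: always Base::Finish
    };
    struct Tee : Base {
      void Finish() { /* flush writer first; never runs via Done() */ }
    };

    int main() {
      Tee node;
      node.Done();  // calls Base::Finish; Tee::Finish is silently skipped
    }
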
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 56573f61d7..bd6c3b79b8 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -363,7 +363,7 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
   }
 
  protected:
-  virtual void Finish(const Status& finish_st) {
+  void Finish(const Status& finish_st) {
     consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
       // Prefer the plan error over the consumer error
       Status final_status = finish_st & st;
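
The expression finish_st & st above uses arrow::Status's operator&, which
returns the left-hand status when it is an error and otherwise the right-hand
one; that is what "prefer the plan error over the consumer error" refers to.
Illustratively (statuses chosen arbitrarily):

    arrow::Status plan_st = arrow::Status::Invalid("plan failed");
    arrow::Status consumer_st = arrow::Status::IOError("flush failed");
    arrow::Status a = plan_st & consumer_st;              // Invalid("plan failed")
    arrow::Status b = arrow::Status::OK() & consumer_st;  // IOError("flush failed")
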
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 7e72497186..ec2b91050d 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -231,7 +231,7 @@ struct TableSourceNode : public SourceNode {
   static arrow::Status ValidateTableSourceNodeInput(const std::shared_ptr<Table> table,
                                                     const int64_t batch_size) {
     if (table == nullptr) {
-      return Status::Invalid("TableSourceNode node requires table which is not null");
+      return Status::Invalid("TableSourceNode requires table which is not null");
     }
 
     if (batch_size <= 0) {
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 41eb401ced..3f5d094774 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -165,6 +165,12 @@ ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
   return batch;
 }
 
+Future<> StartAndFinish(ExecPlan* plan) {
+  RETURN_NOT_OK(plan->Validate());
+  RETURN_NOT_OK(plan->StartProducing());
+  return plan->finished();
+}
+
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
   RETURN_NOT_OK(plan->Validate());
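
StartAndFinish mirrors the existing StartAndCollect helper, for plans that end
in a consuming node (such as a dataset write) and therefore have no sink
generator to drain. Typical use, as in the new test further below:

    ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make());
    // ... declare a pipeline ending in a "write" node ...
    ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));
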
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index 9347d1343f..9cb615ac45 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -82,6 +82,9 @@ struct BatchesWithSchema {
   }
 };
 
+ARROW_TESTING_EXPORT
+Future<> StartAndFinish(ExecPlan* plan);
+
 ARROW_TESTING_EXPORT
 Future<std::vector<ExecBatch>> StartAndCollect(
     ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index ef56e6128a..f6ac70ad45 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -287,7 +287,7 @@ namespace compute {
 Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
                               int expected_num_inputs, const char* kind_name) {
   if (static_cast<int>(inputs.size()) != expected_num_inputs) {
-    return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+    return Status::Invalid(kind_name, " requires ", expected_num_inputs,
                            " inputs but got ", inputs.size());
   }
 
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 822fc71462..1027781057 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -460,6 +460,16 @@ class TeeNode : public compute::MapNode {
 
   const char* kind_name() const override { return "TeeNode"; }
 
+  void Finish(Status finish_st) override {
+    dataset_writer_->Finish().AddCallback([this, finish_st](const Status& dw_status) {
+      // Need to wait for the task group to complete regardless of dw_status
+      task_group_.End().AddCallback(
+          [this, dw_status, finish_st](const Status& tg_status) {
+            finished_.MarkFinished(dw_status & finish_st & tg_status);
+          });
+    });
+  }
+
   Result<compute::ExecBatch> DoTee(const compute::ExecBatch& batch) {
     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
                           batch.ToRecordBatch(output_schema()));
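
TeeNode::Finish chains two asynchronous completions, the dataset writer's
flush and then the node's task group, and combines all three statuses before
resolving finished_, so the plan's finished() future cannot resolve while
writes are still in flight. The same chaining pattern with stand-in futures
(a sketch, not code from this patch):

    // Stand-ins for dataset_writer_->Finish() and task_group_.End().
    arrow::Future<> writer_done = arrow::Future<>::MakeFinished(arrow::Status::OK());
    arrow::Future<> tasks_done = arrow::Future<>::MakeFinished(arrow::Status::OK());
    arrow::Future<> finished = arrow::Future<>::Make();
    writer_done.AddCallback([=](const arrow::Status& w_st) mutable {
      tasks_done.AddCallback([=](const arrow::Status& t_st) mutable {
        finished.MarkFinished(w_st & t_st);  // first error wins
      });
    });
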
diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc
index 226c23ef5e..4dfc6bc584 100644
--- a/cpp/src/arrow/dataset/file_test.cc
+++ b/cpp/src/arrow/dataset/file_test.cc
@@ -18,14 +18,17 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <vector>
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 #include "arrow/array/array_primitive.h"
+#include "arrow/compute/exec/test_util.h"
 #include "arrow/dataset/api.h"
 #include "arrow/dataset/partition.h"
+#include "arrow/dataset/plan.h"
 #include "arrow/dataset/test_util.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/filesystem/test_util.h"
@@ -34,6 +37,8 @@
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/io_util.h"
 
+namespace cp = arrow::compute;
+
 namespace arrow {
 
 using internal::TemporaryDir;
@@ -342,5 +347,107 @@ TEST_F(TestFileSystemDataset, WriteProjected) {
     }
   }
 }
+
+class FileSystemWriteTest : public testing::TestWithParam<std::tuple<bool, bool>> {
+  using PlanFactory = std::function<std::vector<cp::Declaration>(
+      const FileSystemDatasetWriteOptions&,
+      std::function<Future<util::optional<cp::ExecBatch>>()>*)>;
+
+ protected:
+  bool IsParallel() { return std::get<0>(GetParam()); }
+  bool IsSlow() { return std::get<1>(GetParam()); }
+
+  FileSystemWriteTest() { dataset::internal::Initialize(); }
+
+  void TestDatasetWriteRoundTrip(PlanFactory plan_factory, bool has_output) {
+    // Runs in-memory data through the plan and then scans out the written
+    // data to ensure it matches the source data
+    auto format = std::make_shared<IpcFileFormat>();
+    auto fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);
+    FileSystemDatasetWriteOptions write_options;
+    write_options.file_write_options = format->DefaultWriteOptions();
+    write_options.filesystem = fs;
+    write_options.base_dir = "root";
+    write_options.partitioning = std::make_shared<HivePartitioning>(schema({}));
+    write_options.basename_template = "{i}.feather";
+    const std::string kExpectedFilename = "root/0.feather";
+
+    cp::BatchesWithSchema source_data;
+    source_data.batches = {
+        cp::ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+        cp::ExecBatchFromJSON({int32(), boolean()},
+                              "[[5, null], [6, false], [7, false]]")};
+    source_data.schema = schema({field("i32", int32()), field("bool", boolean())});
+
+    AsyncGenerator<util::optional<cp::ExecBatch>> sink_gen;
+
+    ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make());
+    auto source_decl = cp::Declaration::Sequence(
+        {{"source", cp::SourceNodeOptions{source_data.schema,
+                                          source_data.gen(IsParallel(), IsSlow())}}});
+    auto declarations = plan_factory(write_options, &sink_gen);
+    declarations.insert(declarations.begin(), std::move(source_decl));
+    ASSERT_OK(cp::Declaration::Sequence(std::move(declarations)).AddToPlan(plan.get()));
+
+    if (has_output) {
+      ASSERT_FINISHES_OK_AND_ASSIGN(auto out_batches,
+                                    cp::StartAndCollect(plan.get(), sink_gen));
+      cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, out_batches);
+    } else {
+      ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));
+    }
+
+    // Read written dataset and make sure it matches
+    ASSERT_OK_AND_ASSIGN(auto dataset_factory, FileSystemDatasetFactory::Make(
+                                                   fs, {kExpectedFilename}, format, {}));
+    ASSERT_OK_AND_ASSIGN(auto written_dataset, dataset_factory->Finish(FinishOptions{}));
+    AssertSchemaEqual(*source_data.schema, *written_dataset->schema());
+
+    ASSERT_OK_AND_ASSIGN(plan, cp::ExecPlan::Make());
+    ASSERT_OK_AND_ASSIGN(auto scanner_builder, written_dataset->NewScan());
+    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());
+    ASSERT_OK(cp::Declaration::Sequence(
+                  {
+                      {"scan", ScanNodeOptions{written_dataset, scanner->options()}},
+                      {"sink", cp::SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto written_batches,
+                                  cp::StartAndCollect(plan.get(), sink_gen));
+    cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, written_batches);
+  }
+};
+
+TEST_P(FileSystemWriteTest, Write) {
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{{"write", WriteNodeOptions{write_options}}};
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/false);
+}
+
+TEST_P(FileSystemWriteTest, TeeWrite) {
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{
+            {"tee", WriteNodeOptions{write_options}},
+            {"sink", cp::SinkNodeOptions{sink_gen}},
+        };
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/true);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    FileSystemWrite, FileSystemWriteTest,
+    testing::Combine(testing::Values(false, true), testing::Values(false, true)),
+    [](const testing::TestParamInfo<FileSystemWriteTest::ParamType>& info) {
+      std::string parallel_desc = std::get<0>(info.param) ? "parallel" : "serial";
+      std::string speed_desc = std::get<1>(info.param) ? "slow" : "fast";
+      return parallel_desc + "_" + speed_desc;
+    });
+
 }  // namespace dataset
 }  // namespace arrow
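
The INSTANTIATE_TEST_SUITE_P above expands each test into four variants
(serial_fast, serial_slow, parallel_fast, parallel_slow), so the write and tee
paths are exercised under both parallel and delayed batch delivery. To run
only the parallel variants, for example (the test binary name here is
illustrative):

    ./arrow-dataset-file-test --gtest_filter='FileSystemWrite/*parallel*'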