You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by bk...@apache.org on 2022/05/16 15:32:35 UTC
[arrow] branch master updated: ARROW-16525: [C++] Tee node not properly marking node finished
This is an automated email from the ASF dual-hosted git repository.
bkietz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d040fb568c ARROW-16525: [C++] Tee node not properly marking node finished
d040fb568c is described below
commit d040fb568c8f7567e2d7c4e535f461743832003f
Author: Weston Pace <we...@gmail.com>
AuthorDate: Mon May 16 11:32:24 2022 -0400
ARROW-16525: [C++] Tee node not properly marking node finished
Closes #13117 from westonpace/feature/ARROW-16525--tee-node-not-marking-finished
Lead-authored-by: Weston Pace <we...@gmail.com>
Co-authored-by: %(trailers:key=Co-authored-by,valueonly)
Signed-off-by: Benjamin Kietzman <be...@gmail.com>
---
cpp/src/arrow/compute/exec/exec_plan.h | 2 +-
cpp/src/arrow/compute/exec/sink_node.cc | 2 +-
cpp/src/arrow/compute/exec/source_node.cc | 2 +-
cpp/src/arrow/compute/exec/test_util.cc | 6 ++
cpp/src/arrow/compute/exec/test_util.h | 3 +
cpp/src/arrow/compute/exec/util.cc | 2 +-
cpp/src/arrow/dataset/file_base.cc | 10 +++
cpp/src/arrow/dataset/file_test.cc | 107 ++++++++++++++++++++++++++++++
8 files changed, 130 insertions(+), 4 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index c20dc0d048..be2f23ad24 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -316,7 +316,7 @@ class ARROW_EXPORT MapNode : public ExecNode {
protected:
void SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn, ExecBatch batch);
- void Finish(Status finish_st = Status::OK());
+ virtual void Finish(Status finish_st = Status::OK());
protected:
// Counter for the number of batches received
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 56573f61d7..bd6c3b79b8 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -363,7 +363,7 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
}
protected:
- virtual void Finish(const Status& finish_st) {
+ void Finish(const Status& finish_st) {
consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
// Prefer the plan error over the consumer error
Status final_status = finish_st & st;
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 7e72497186..ec2b91050d 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -231,7 +231,7 @@ struct TableSourceNode : public SourceNode {
static arrow::Status ValidateTableSourceNodeInput(const std::shared_ptr<Table> table,
const int64_t batch_size) {
if (table == nullptr) {
- return Status::Invalid("TableSourceNode node requires table which is not null");
+ return Status::Invalid("TableSourceNode requires table which is not null");
}
if (batch_size <= 0) {
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index 41eb401ced..3f5d094774 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -165,6 +165,12 @@ ExecBatch ExecBatchFromJSON(const std::vector<ValueDescr>& descrs,
return batch;
}
+Future<> StartAndFinish(ExecPlan* plan) {  // Test helper: validate and start `plan`, return its completion future
+  RETURN_NOT_OK(plan->Validate());  // early-return on a non-OK status (per the macro name); plan is not started
+  RETURN_NOT_OK(plan->StartProducing());
+  return plan->finished();  // the caller waits on this future
+}
+
Future<std::vector<ExecBatch>> StartAndCollect(
ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
RETURN_NOT_OK(plan->Validate());
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index 9347d1343f..9cb615ac45 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -82,6 +82,9 @@ struct BatchesWithSchema {
}
};
+ARROW_TESTING_EXPORT
+Future<> StartAndFinish(ExecPlan* plan);  // Validates and starts `plan`; returns plan->finished()
+
ARROW_TESTING_EXPORT
Future<std::vector<ExecBatch>> StartAndCollect(
ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index ef56e6128a..f6ac70ad45 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -287,7 +287,7 @@ namespace compute {
Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector<ExecNode*>& inputs,
int expected_num_inputs, const char* kind_name) {
if (static_cast<int>(inputs.size()) != expected_num_inputs) {
- return Status::Invalid(kind_name, " node requires ", expected_num_inputs,
+ return Status::Invalid(kind_name, " requires ", expected_num_inputs,
" inputs but got ", inputs.size());
}
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 822fc71462..1027781057 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -460,6 +460,16 @@ class TeeNode : public compute::MapNode {
const char* kind_name() const override { return "TeeNode"; }
+  void Finish(Status finish_st) override {  // overrides MapNode::Finish; defers MarkFinished until writer and tasks are done
+    dataset_writer_->Finish().AddCallback([this, finish_st](const Status& dw_status) {  // 1) finish the dataset writer
+      // Need to wait for the task group to complete regardless of dw_status
+      task_group_.End().AddCallback(  // 2) end the task group, chained after the writer callback
+          [this, dw_status, finish_st](const Status& tg_status) {
+            finished_.MarkFinished(dw_status & finish_st & tg_status);  // 3) combine all three; NOTE(review): presumably first non-OK wins - confirm
+          });
+    });
+  }
+
Result<compute::ExecBatch> DoTee(const compute::ExecBatch& batch) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> record_batch,
batch.ToRecordBatch(output_schema()));
diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc
index 226c23ef5e..4dfc6bc584 100644
--- a/cpp/src/arrow/dataset/file_test.cc
+++ b/cpp/src/arrow/dataset/file_test.cc
@@ -18,14 +18,17 @@
#include <cstdint>
#include <memory>
#include <string>
+#include <tuple>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "arrow/array/array_primitive.h"
+#include "arrow/compute/exec/test_util.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/partition.h"
+#include "arrow/dataset/plan.h"
#include "arrow/dataset/test_util.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/filesystem/test_util.h"
@@ -34,6 +37,8 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/util/io_util.h"
+namespace cp = arrow::compute;
+
namespace arrow {
using internal::TemporaryDir;
@@ -342,5 +347,107 @@ TEST_F(TestFileSystemDataset, WriteProjected) {
}
}
}
+
+class FileSystemWriteTest : public testing::TestWithParam<std::tuple<bool, bool>> {  // params: (parallel, slow)
+  using PlanFactory = std::function<std::vector<cp::Declaration>(  // returns the declarations placed after the source
+      const FileSystemDatasetWriteOptions&,
+      std::function<Future<util::optional<cp::ExecBatch>>()>*)>;  // generator slot a factory may hand to a sink node
+
+ protected:
+  bool IsParallel() { return std::get<0>(GetParam()); }  // first tuple element
+  bool IsSlow() { return std::get<1>(GetParam()); }  // second tuple element
+
+  FileSystemWriteTest() { dataset::internal::Initialize(); }  // NOTE(review): presumably registers the dataset exec nodes - confirm
+
+  void TestDatasetWriteRoundTrip(PlanFactory plan_factory, bool has_output) {  // has_output: the factory's plan ends in a sink
+    // Runs in-memory data through the plan and then scans out the written
+    // data to ensure it matches the source data
+    auto format = std::make_shared<IpcFileFormat>();
+    auto fs = std::make_shared<fs::internal::MockFileSystem>(fs::kNoTime);  // writes go to a mock filesystem
+    FileSystemDatasetWriteOptions write_options;
+    write_options.file_write_options = format->DefaultWriteOptions();
+    write_options.filesystem = fs;
+    write_options.base_dir = "root";
+    write_options.partitioning = std::make_shared<HivePartitioning>(schema({}));  // empty partitioning schema
+    write_options.basename_template = "{i}.feather";
+    const std::string kExpectedFilename = "root/0.feather";  // base_dir + template; test expects this single file to hold the data
+
+    cp::BatchesWithSchema source_data;
+    source_data.batches = {  // two batches: 2 rows and 3 rows
+        cp::ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"),
+        cp::ExecBatchFromJSON({int32(), boolean()},
+                              "[[5, null], [6, false], [7, false]]")};
+    source_data.schema = schema({field("i32", int32()), field("bool", boolean())});
+
+    AsyncGenerator<util::optional<cp::ExecBatch>> sink_gen;  // reused below for the read-back plan's sink
+
+    ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make());
+    auto source_decl = cp::Declaration::Sequence(
+        {{"source", cp::SourceNodeOptions{source_data.schema,
+                                          source_data.gen(IsParallel(), IsSlow())}}});
+    auto declarations = plan_factory(write_options, &sink_gen);
+    declarations.insert(declarations.begin(), std::move(source_decl));  // source goes first in the sequence
+    ASSERT_OK(cp::Declaration::Sequence(std::move(declarations)).AddToPlan(plan.get()));
+
+    if (has_output) {  // collect the sink output and compare it with the source batches
+      ASSERT_FINISHES_OK_AND_ASSIGN(auto out_batches,
+                                    cp::StartAndCollect(plan.get(), sink_gen));
+      cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, out_batches);
+    } else {  // nothing to collect: only wait for the plan's finished future
+      ASSERT_FINISHES_OK(cp::StartAndFinish(plan.get()));
+    }
+
+    // Read written dataset and make sure it matches
+    ASSERT_OK_AND_ASSIGN(auto dataset_factory, FileSystemDatasetFactory::Make(
+                                                   fs, {kExpectedFilename}, format, {}));
+    ASSERT_OK_AND_ASSIGN(auto written_dataset, dataset_factory->Finish(FinishOptions{}));
+    AssertSchemaEqual(*source_data.schema, *written_dataset->schema());
+
+    ASSERT_OK_AND_ASSIGN(plan, cp::ExecPlan::Make());  // fresh plan for the scan -> sink read-back
+    ASSERT_OK_AND_ASSIGN(auto scanner_builder, written_dataset->NewScan());
+    ASSERT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());  // built only to obtain scan options
+    ASSERT_OK(cp::Declaration::Sequence(
+                  {
+                      {"scan", ScanNodeOptions{written_dataset, scanner->options()}},
+                      {"sink", cp::SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto written_batches,
+                                  cp::StartAndCollect(plan.get(), sink_gen));
+    cp::AssertExecBatchesEqual(source_data.schema, source_data.batches, written_batches);  // round trip must reproduce the source
+  }
+};
+
+TEST_P(FileSystemWriteTest, Write) {  // round trip through a plan of source -> write
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {  // sink_gen is unused: no sink node is declared
+        return std::vector<cp::Declaration>{{"write", WriteNodeOptions{write_options}}};
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/false);  // no output batches to collect
+}
+
+TEST_P(FileSystemWriteTest, TeeWrite) {  // round trip through a plan of source -> tee -> sink
+  auto plan_factory =
+      [](const FileSystemDatasetWriteOptions& write_options,
+         std::function<Future<util::optional<cp::ExecBatch>>()>* sink_gen) {
+        return std::vector<cp::Declaration>{
+            {"tee", WriteNodeOptions{write_options}},  // tee takes the same options type as the write node
+            {"sink", cp::SinkNodeOptions{sink_gen}},  // sink output is collected through sink_gen
+        };
+      };
+  TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/true);  // also checks the batches that reach the sink
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    FileSystemWrite, FileSystemWriteTest,
+    testing::Combine(testing::Values(false, true), testing::Values(false, true)),  // all 4 (parallel, slow) combinations
+    [](const testing::TestParamInfo<FileSystemWriteTest::ParamType>& info) {  // builds the test-name suffix
+      std::string parallel_desc = std::get<0>(info.param) ? "parallel" : "serial";
+      std::string speed_desc = std::get<1>(info.param) ? "slow" : "fast";
+      return parallel_desc + "_" + speed_desc;  // e.g. "serial_fast", "parallel_slow"
+    });
+
} // namespace dataset
} // namespace arrow