You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2022/12/22 19:22:21 UTC
[arrow] branch master updated: ARROW-17837: [C++][Acero] Create ExecPlan-owned QueryContext that will store a plan's shared data structures (#14227)
This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 54c487bf27 ARROW-17837: [C++][Acero] Create ExecPlan-owned QueryContext that will store a plan's shared data structures (#14227)
54c487bf27 is described below
commit 54c487bf279a6129547e893d8b90a6813f1f12ab
Author: Sasha Krassovsky <kr...@gmail.com>
AuthorDate: Thu Dec 22 11:22:11 2022 -0800
ARROW-17837: [C++][Acero] Create ExecPlan-owned QueryContext that will store a plan's shared data structures (#14227)
Lead-authored-by: Sasha Krassovsky <kr...@gmail.com>
Co-authored-by: Weston Pace <we...@gmail.com>
Signed-off-by: Weston Pace <we...@gmail.com>
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/compute/exec.h | 6 -
cpp/src/arrow/compute/exec/aggregate_node.cc | 51 ++++---
cpp/src/arrow/compute/exec/asof_join_node.cc | 6 +-
cpp/src/arrow/compute/exec/exec_plan.cc | 103 +++-----------
cpp/src/arrow/compute/exec/exec_plan.h | 83 +----------
cpp/src/arrow/compute/exec/filter_node.cc | 11 +-
cpp/src/arrow/compute/exec/hash_join.cc | 31 +++--
cpp/src/arrow/compute/exec/hash_join.h | 3 +-
cpp/src/arrow/compute/exec/hash_join_benchmark.cc | 15 +-
cpp/src/arrow/compute/exec/hash_join_node.cc | 109 +++++++--------
cpp/src/arrow/compute/exec/project_node.cc | 10 +-
cpp/src/arrow/compute/exec/query_context.cc | 95 +++++++++++++
cpp/src/arrow/compute/exec/query_context.h | 161 ++++++++++++++++++++++
cpp/src/arrow/compute/exec/sink_node.cc | 19 +--
cpp/src/arrow/compute/exec/source_node.cc | 44 +++---
cpp/src/arrow/compute/exec/swiss_join.cc | 26 ++--
cpp/src/arrow/compute/exec/tpch_node.cc | 7 +-
cpp/src/arrow/compute/type_fwd.h | 2 +
cpp/src/arrow/dataset/file_base.cc | 12 +-
cpp/src/arrow/dataset/scan_node.cc | 20 +--
cpp/src/arrow/dataset/scanner.cc | 8 +-
cpp/src/arrow/util/io_util.cc | 31 +++++
cpp/src/arrow/util/io_util.h | 6 +
cpp/src/arrow/util/io_util_test.cc | 12 ++
cpp/src/arrow/util/rle_encoding.h | 31 +++--
cpp/src/arrow/util/type_fwd.h | 1 +
27 files changed, 550 insertions(+), 354 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index e1e409d0a7..78fff49d32 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -403,6 +403,7 @@ if(ARROW_COMPUTE)
compute/exec/partition_util.cc
compute/exec/options.cc
compute/exec/project_node.cc
+ compute/exec/query_context.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index a1f72b3e50..30f4113f6c 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -39,12 +39,6 @@
#include "arrow/util/visibility.h"
namespace arrow {
-namespace internal {
-
-class CpuInfo;
-
-} // namespace internal
-
namespace compute {
// It seems like 64K might be a good default chunksize to use for execution
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index cca266ad69..0b70577ae7 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -24,6 +24,7 @@
#include "arrow/compute/exec/aggregate.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
@@ -83,7 +84,7 @@ class ScalarAggregateNode : public ExecNode {
auto aggregates = aggregate_options.aggregates;
const auto& input_schema = *inputs[0]->output_schema();
- auto exec_ctx = plan->exec_context();
+ auto exec_ctx = plan->query_context()->exec_context();
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
@@ -113,7 +114,7 @@ class ScalarAggregateNode : public ExecNode {
}
KernelContext kernel_ctx{exec_ctx};
- states[i].resize(plan->max_concurrency());
+ states[i].resize(plan->query_context()->max_concurrency());
RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
KernelInitArgs{kernels[i],
{
@@ -150,7 +151,7 @@ class ScalarAggregateNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
- KernelContext batch_ctx{plan()->exec_context()};
+ KernelContext batch_ctx{plan()->query_context()->exec_context()};
batch_ctx.SetState(states_[i][thread_index].get());
ExecSpan single_column_batch{{batch.values[target_field_ids_[i]]}, batch.length};
@@ -168,7 +169,7 @@ class ScalarAggregateNode : public ExecNode {
{"batch.length", batch.length}});
DCHECK_EQ(input, inputs_[0]);
- auto thread_index = plan_->GetThreadIndex();
+ auto thread_index = plan_->query_context()->GetThreadIndex();
if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return;
@@ -245,7 +246,7 @@ class ScalarAggregateNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
- KernelContext ctx{plan()->exec_context()};
+ KernelContext ctx{plan()->query_context()->exec_context()};
ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
kernels_[i], &ctx, std::move(states_[i])));
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
@@ -267,20 +268,19 @@ class ScalarAggregateNode : public ExecNode {
class GroupByNode : public ExecNode {
public:
- GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema, ExecContext* ctx,
+ GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
- ctx_(ctx),
key_field_ids_(std::move(key_field_ids)),
agg_src_field_ids_(std::move(agg_src_field_ids)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}
Status Init() override {
- output_task_group_id_ = plan_->RegisterTaskGroup(
+ output_task_group_id_ = plan_->query_context()->RegisterTaskGroup(
[this](size_t, int64_t task_id) {
OutputNthBatch(task_id);
return Status::OK();
@@ -326,7 +326,7 @@ class GroupByNode : public ExecNode {
agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
}
- auto ctx = input->plan()->exec_context();
+ auto ctx = plan->query_context()->exec_context();
// Construct aggregates
ARROW_ASSIGN_OR_RAISE(auto agg_kernels,
@@ -354,7 +354,7 @@ class GroupByNode : public ExecNode {
}
return input->plan()->EmplaceNode<GroupByNode>(
- input, schema(std::move(output_fields)), ctx, std::move(key_field_ids),
+ input, schema(std::move(output_fields)), std::move(key_field_ids),
std::move(agg_src_field_ids), std::move(aggs), std::move(agg_kernels));
}
@@ -366,7 +366,7 @@ class GroupByNode : public ExecNode {
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
- size_t thread_index = plan_->GetThreadIndex();
+ size_t thread_index = plan_->query_context()->GetThreadIndex();
if (thread_index >= local_states_.size()) {
return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
local_states_.size(), ")");
@@ -393,7 +393,8 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
- KernelContext kernel_ctx{ctx_};
+ auto ctx = plan_->query_context()->exec_context();
+ KernelContext kernel_ctx{ctx};
kernel_ctx.SetState(state->agg_states[i].get());
ExecSpan agg_batch({batch[agg_src_field_ids_[i]], ExecValue(*id_batch.array())},
@@ -429,7 +430,9 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Merge"}});
- KernelContext batch_ctx{ctx_};
+
+ auto ctx = plan_->query_context()->exec_context();
+ KernelContext batch_ctx{ctx};
DCHECK(state0->agg_states[i]);
batch_ctx.SetState(state0->agg_states[i].get());
@@ -462,7 +465,7 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
- KernelContext batch_ctx{ctx_};
+ KernelContext batch_ctx{plan_->query_context()->exec_context()};
batch_ctx.SetState(state->agg_states[i].get());
RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i]));
state->agg_states[i].reset();
@@ -497,7 +500,8 @@ class GroupByNode : public ExecNode {
int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
- RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches));
+ RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_,
+ num_output_batches));
return Status::OK();
}
@@ -548,7 +552,7 @@ class GroupByNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});
- local_states_.resize(plan_->max_concurrency());
+ local_states_.resize(plan_->query_context()->max_concurrency());
return Status::OK();
}
@@ -593,7 +597,7 @@ class GroupByNode : public ExecNode {
};
ThreadLocalState* GetLocalState() {
- size_t thread_index = plan_->GetThreadIndex();
+ size_t thread_index = plan_->query_context()->GetThreadIndex();
return &local_states_[thread_index];
}
@@ -611,7 +615,8 @@ class GroupByNode : public ExecNode {
}
// Construct grouper
- ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_types, ctx_));
+ ARROW_ASSIGN_OR_RAISE(
+ state->grouper, Grouper::Make(key_types, plan_->query_context()->exec_context()));
// Build vector of aggregate source field data types
std::vector<TypeHolder> agg_src_types(agg_kernels_.size());
@@ -620,21 +625,23 @@ class GroupByNode : public ExecNode {
agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
}
- ARROW_ASSIGN_OR_RAISE(state->agg_states, internal::InitKernels(agg_kernels_, ctx_,
- aggs_, agg_src_types));
+ ARROW_ASSIGN_OR_RAISE(
+ state->agg_states,
+ internal::InitKernels(agg_kernels_, plan_->query_context()->exec_context(), aggs_,
+ agg_src_types));
return Status::OK();
}
int output_batch_size() const {
- int result = static_cast<int>(ctx_->exec_chunksize());
+ int result =
+ static_cast<int>(plan_->query_context()->exec_context()->exec_chunksize());
if (result < 0) {
result = 32 * 1024;
}
return result;
}
- ExecContext* ctx_;
int output_task_group_id_;
const std::vector<int> key_field_ids_;
diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc
index 22f8e5142d..979012e63e 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node.cc
@@ -29,6 +29,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/key_hash.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/schema_util.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/light_array.h"
@@ -801,7 +802,7 @@ class AsofJoinNode : public ExecNode {
if (dst.empty()) {
return NULLPTR;
} else {
- return dst.Materialize(plan()->exec_context()->memory_pool(), output_schema(),
+ return dst.Materialize(plan()->query_context()->memory_pool(), output_schema(),
state_);
}
}
@@ -861,7 +862,8 @@ class AsofJoinNode : public ExecNode {
Status Init() override {
auto inputs = this->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
- RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema()));
+ RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(),
+ output_schema()));
state_.push_back(std::make_unique<InputState>(
must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(),
indices_of_on_key_[i], indices_of_by_key_[i]));
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 22250d2f99..666ab1d8c0 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -25,9 +25,8 @@
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/task_util.h"
-#include "arrow/compute/exec/util.h"
-#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/datum.h"
#include "arrow/record_batch.h"
@@ -51,9 +50,9 @@ namespace compute {
namespace {
struct ExecPlanImpl : public ExecPlan {
- explicit ExecPlanImpl(ExecContext* exec_context,
+ explicit ExecPlanImpl(QueryOptions options, ExecContext* exec_context,
std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR)
- : ExecPlan(exec_context), metadata_(std::move(metadata)) {}
+ : metadata_(std::move(metadata)), query_context_(options, *exec_context) {}
~ExecPlanImpl() override {
if (started_ && !finished_.is_finished()) {
@@ -63,9 +62,6 @@ struct ExecPlanImpl : public ExecPlan {
}
}
- size_t GetThreadIndex() { return thread_indexer_(); }
- size_t max_concurrency() const { return thread_indexer_.Capacity(); }
-
ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
if (node->label().empty()) {
node->SetLabel(ToChars(auto_label_counter_++));
@@ -80,44 +76,6 @@ struct ExecPlanImpl : public ExecPlan {
return nodes_.back().get();
}
- Result<Future<>> BeginExternalTask() {
- Future<> completion_future = Future<>::Make();
- if (async_scheduler_->AddSimpleTask(
- [completion_future] { return completion_future; })) {
- return completion_future;
- }
- return Future<>{};
- }
-
- Status ScheduleTask(std::function<Status()> fn) {
- auto executor = exec_context_->executor();
- if (!executor) return fn();
- // Adds a task which submits fn to the executor and tracks its progress. If we're
- // aborted then the task is ignored and fn is not executed.
- async_scheduler_->AddSimpleTask(
- [executor, fn]() { return executor->Submit(std::move(fn)); });
- return Status::OK();
- }
-
- Status ScheduleTask(std::function<Status(size_t)> fn) {
- std::function<Status()> indexed_fn = [this, fn]() {
- size_t thread_index = GetThreadIndex();
- return fn(thread_index);
- };
- return ScheduleTask(std::move(indexed_fn));
- }
-
- int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
- std::function<Status(size_t)> on_finished) {
- return task_scheduler_->RegisterTaskGroup(std::move(task), std::move(on_finished));
- }
-
- Status StartTaskGroup(int task_group_id, int64_t num_tasks) {
- return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id, num_tasks);
- }
-
- util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_; }
-
Status Validate() const {
if (nodes_.empty()) {
return Status::Invalid("ExecPlan has no node");
@@ -132,6 +90,7 @@ struct ExecPlanImpl : public ExecPlan {
if (started_) {
return Status::Invalid("restarted ExecPlan");
}
+
started_ = true;
// We call StartProducing on each of the nodes. The source nodes should generally
@@ -142,7 +101,9 @@ struct ExecPlanImpl : public ExecPlan {
// call.
Future<> scheduler_finished =
util::AsyncTaskScheduler::Make([this](util::AsyncTaskScheduler* async_scheduler) {
- this->async_scheduler_ = async_scheduler;
+ QueryContext* ctx = query_context();
+ RETURN_NOT_OK(ctx->Init(ctx->max_concurrency(), async_scheduler));
+
START_COMPUTE_SPAN(span_, "ExecPlan", {{"plan", ToString()}});
#ifdef ARROW_WITH_OPENTELEMETRY
if (HasMetadata()) {
@@ -162,17 +123,17 @@ struct ExecPlanImpl : public ExecPlan {
async_scheduler->AddSimpleTask([&] { return n->finished(); });
}
- task_scheduler_->RegisterEnd();
+ ctx->scheduler()->RegisterEnd();
int num_threads = 1;
bool sync_execution = true;
- if (auto executor = exec_context()->executor()) {
+ if (auto executor = query_context()->exec_context()->executor()) {
num_threads = executor->GetCapacity();
sync_execution = false;
}
- RETURN_NOT_OK(task_scheduler_->StartScheduling(
+ RETURN_NOT_OK(ctx->scheduler()->StartScheduling(
0 /* thread_index */,
- [this](std::function<Status(size_t)> fn) -> Status {
- return this->ScheduleTask(std::move(fn));
+ [ctx](std::function<Status(size_t)> fn) -> Status {
+ return ctx->ScheduleTask(std::move(fn));
},
/*concurrent_tasks=*/2 * num_threads, sync_execution));
@@ -219,7 +180,7 @@ struct ExecPlanImpl : public ExecPlan {
DCHECK(started_) << "stopped an ExecPlan which never started";
EVENT(span_, "StopProducing");
stopped_ = true;
- task_scheduler_->Abort(
+ query_context()->scheduler()->Abort(
[this]() { StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end()); });
}
@@ -335,10 +296,7 @@ struct ExecPlanImpl : public ExecPlan {
uint32_t auto_label_counter_ = 0;
util::tracing::Span span_;
std::shared_ptr<const KeyValueMetadata> metadata_;
-
- ThreadIndexer thread_indexer_;
- util::AsyncTaskScheduler* async_scheduler_ = nullptr;
- std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
+ QueryContext query_context_;
};
ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast<ExecPlanImpl*>(ptr); }
@@ -359,9 +317,15 @@ std::optional<int> GetNodeIndex(const std::vector<ExecNode*>& nodes,
const uint32_t ExecPlan::kMaxBatchSize;
+Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(
+ QueryOptions opts, ExecContext* ctx,
+ std::shared_ptr<const KeyValueMetadata> metadata) {
+ return std::shared_ptr<ExecPlan>(new ExecPlanImpl{opts, ctx, std::move(metadata)});
+}
+
Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(
ExecContext* ctx, std::shared_ptr<const KeyValueMetadata> metadata) {
- return std::shared_ptr<ExecPlan>(new ExecPlanImpl{ctx, metadata});
+ return Make(/*opts=*/{}, ctx, std::move(metadata));
}
ExecNode* ExecPlan::AddNode(std::unique_ptr<ExecNode> node) {
@@ -374,30 +338,7 @@ const ExecPlan::NodeVector& ExecPlan::sources() const {
const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; }
-size_t ExecPlan::GetThreadIndex() { return ToDerived(this)->GetThreadIndex(); }
-size_t ExecPlan::max_concurrency() const { return ToDerived(this)->max_concurrency(); }
-
-Result<Future<>> ExecPlan::BeginExternalTask() {
- return ToDerived(this)->BeginExternalTask();
-}
-
-Status ExecPlan::ScheduleTask(std::function<Status()> fn) {
- return ToDerived(this)->ScheduleTask(std::move(fn));
-}
-Status ExecPlan::ScheduleTask(std::function<Status(size_t)> fn) {
- return ToDerived(this)->ScheduleTask(std::move(fn));
-}
-int ExecPlan::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
- std::function<Status(size_t)> on_finished) {
- return ToDerived(this)->RegisterTaskGroup(std::move(task), std::move(on_finished));
-}
-Status ExecPlan::StartTaskGroup(int task_group_id, int64_t num_tasks) {
- return ToDerived(this)->StartTaskGroup(task_group_id, num_tasks);
-}
-
-util::AsyncTaskScheduler* ExecPlan::async_scheduler() {
- return ToDerived(this)->async_scheduler();
-}
+QueryContext* ExecPlan::query_context() { return &ToDerived(this)->query_context_; }
Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 44cd1acf87..52265b4f28 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -26,6 +26,7 @@
#include <utility>
#include <vector>
+#include "arrow/compute/exec.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/type_fwd.h"
#include "arrow/util/future.h"
@@ -49,11 +50,15 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
virtual ~ExecPlan() = default;
- ExecContext* exec_context() const { return exec_context_; }
+ QueryContext* query_context();
/// Make an empty exec plan
static Result<std::shared_ptr<ExecPlan>> Make(
- ExecContext* = default_exec_context(),
+ QueryOptions options, ExecContext* exec_context = default_exec_context(),
+ std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
+
+ static Result<std::shared_ptr<ExecPlan>> Make(
+ ExecContext* exec_context = default_exec_context(),
std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR);
ExecNode* AddNode(std::unique_ptr<ExecNode> node);
@@ -66,62 +71,6 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
return out;
}
- /// \brief Returns the index of the current thread.
- size_t GetThreadIndex();
- /// \brief Returns the maximum number of threads that the plan could use.
- ///
- /// GetThreadIndex will always return something less than this, so it is safe to
- /// e.g. make an array of thread-locals off this.
- size_t max_concurrency() const;
-
- /// \brief Start an external task
- ///
- /// This should be avoided if possible. It is kept in for now for legacy
- /// purposes. This should be called before the external task is started. If
- /// a valid future is returned then it should be marked complete when the
- /// external task has finished.
- ///
- /// \return an invalid future if the plan has already ended, otherwise this
- /// returns a future that must be completed when the external task
- /// finishes.
- Result<Future<>> BeginExternalTask();
-
- /// \brief Add a single function as a task to the plan's task group.
- ///
- /// \param fn The task to run. Takes no arguments and returns a Status.
- Status ScheduleTask(std::function<Status()> fn);
-
- /// \brief Add a single function as a task to the plan's task group.
- ///
- /// \param fn The task to run. Takes the thread index and returns a Status.
- Status ScheduleTask(std::function<Status(size_t)> fn);
- // Register/Start TaskGroup is a way of performing a "Parallel For" pattern:
- // - The task function takes the thread index and the index of the task
- // - The on_finished function takes the thread index
- // Returns an integer ID that will be used to reference the task group in
- // StartTaskGroup. At runtime, call StartTaskGroup with the ID and the number of times
- // you'd like the task to be executed. The need to register a task group before use will
- // be removed after we rewrite the scheduler.
- /// \brief Register a "parallel for" task group with the scheduler
- ///
- /// \param task The function implementing the task. Takes the thread_index and
- /// the task index.
- /// \param on_finished The function that gets run once all tasks have been completed.
- /// Takes the thread_index.
- ///
- /// Must be called inside of ExecNode::Init.
- int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
- std::function<Status(size_t)> on_finished);
-
- /// \brief Start the task group with the specified ID. This can only
- /// be called once per task_group_id.
- ///
- /// \param task_group_id The ID of the task group to run
- /// \param num_tasks The number of times to run the task
- Status StartTaskGroup(int task_group_id, int64_t num_tasks);
-
- util::AsyncTaskScheduler* async_scheduler();
-
/// The initial inputs
const NodeVector& sources() const;
@@ -151,25 +100,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
/// \brief Return the plan's attached metadata
std::shared_ptr<const KeyValueMetadata> metadata() const;
- /// \brief Should the plan use a legacy batching strategy
- ///
- /// This is currently in place only to support the Scanner::ToTable
- /// method. This method relies on batch indices from the scanner
- /// remaining consistent. This is impractical in the ExecPlan which
- /// might slice batches as needed (e.g. for a join)
- ///
- /// However, it still works for simple plans and this is the only way
- /// we have at the moment for maintaining implicit order.
- bool UseLegacyBatching() const { return use_legacy_batching_; }
- // For internal use only, see above comment
- void SetUseLegacyBatching(bool value) { use_legacy_batching_ = value; }
-
std::string ToString() const;
-
- protected:
- ExecContext* exec_context_;
- bool use_legacy_batching_ = false;
- explicit ExecPlan(ExecContext* exec_context) : exec_context_(exec_context) {}
};
class ARROW_EXPORT ExecNode {
diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc
index 19d18ca608..8274453b6c 100644
--- a/cpp/src/arrow/compute/exec/filter_node.cc
+++ b/cpp/src/arrow/compute/exec/filter_node.cc
@@ -21,6 +21,7 @@
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/map_node.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/datum.h"
#include "arrow/result.h"
#include "arrow/util/checked_cast.h"
@@ -51,8 +52,9 @@ class FilterNode : public MapNode {
auto filter_expression = filter_options.filter_expression;
if (!filter_expression.IsBound()) {
- ARROW_ASSIGN_OR_RAISE(filter_expression,
- filter_expression.Bind(*schema, plan->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(
+ filter_expression,
+ filter_expression.Bind(*schema, plan->query_context()->exec_context()));
}
if (filter_expression.type()->id() != Type::BOOL) {
@@ -76,8 +78,9 @@ class FilterNode : public MapNode {
{"filter.expression.simplified", simplified_filter.ToString()},
{"filter.length", target.length}});
- ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target,
- plan()->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(
+ Datum mask, ExecuteScalarExpression(simplified_filter, target,
+ plan()->query_context()->exec_context()));
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc
index da1710fe08..ffd93591e6 100644
--- a/cpp/src/arrow/compute/exec/hash_join.cc
+++ b/cpp/src/arrow/compute/exec/hash_join.cc
@@ -39,7 +39,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
struct ThreadLocalState;
public:
- Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
+ Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
@@ -98,7 +98,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (int icol = 0; icol < num_cols; ++icol) {
data_types[icol] = schema_[side]->data_type(projection_handle, icol);
}
- encoder->Init(data_types, ctx_);
+ encoder->Init(data_types, ctx_->exec_context());
encoder->Clear();
}
@@ -296,8 +296,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);
- ARROW_ASSIGN_OR_RAISE(Datum mask,
- ExecuteScalarExpression(filter_, concatenated, ctx_));
+ ARROW_ASSIGN_OR_RAISE(
+ Datum mask, ExecuteScalarExpression(filter_, concatenated, ctx_->exec_context()));
size_t num_probed_rows = match.size() + no_match.size();
if (mask.is_scalar()) {
@@ -397,7 +397,8 @@ class HashJoinBasicImpl : public HashJoinImpl {
ARROW_ASSIGN_OR_RAISE(right_key,
hash_table_keys_.Decode(batch_size_next, opt_right_ids));
// Post process build side keys that use dictionary
- RETURN_NOT_OK(dict_build_.PostDecode(*schema_[1], &right_key, ctx_));
+ RETURN_NOT_OK(
+ dict_build_.PostDecode(*schema_[1], &right_key, ctx_->exec_context()));
}
if (has_right_payload) {
ARROW_ASSIGN_OR_RAISE(right_payload,
@@ -509,13 +510,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
local_state.match_left.clear();
local_state.match_right.clear();
- bool use_key_batch_for_dicts =
- dict_probe_.BatchRemapNeeded(thread_index, *schema_[0], *schema_[1], ctx_);
+ bool use_key_batch_for_dicts = dict_probe_.BatchRemapNeeded(
+ thread_index, *schema_[0], *schema_[1], ctx_->exec_context());
RowEncoder* row_encoder_for_lookups = &local_state.exec_batch_keys;
if (use_key_batch_for_dicts) {
- RETURN_NOT_OK(dict_probe_.EncodeBatch(thread_index, *schema_[0], *schema_[1],
- dict_build_, batch, &row_encoder_for_lookups,
- &batch_key_for_lookups, ctx_));
+ RETURN_NOT_OK(dict_probe_.EncodeBatch(
+ thread_index, *schema_[0], *schema_[1], dict_build_, batch,
+ &row_encoder_for_lookups, &batch_key_for_lookups, ctx_->exec_context()));
}
// Collect information about all nulls in key columns.
@@ -560,7 +561,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
Status BuildHashTable_exec_task(size_t thread_index, int64_t /*task_id*/) {
AccumulationQueue batches = std::move(build_batches_);
- dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_);
+ dict_build_.InitEncoder(*schema_[1], &hash_table_keys_, ctx_->exec_context());
bool has_payload = (schema_[1]->num_cols(HashJoinProjection::PAYLOAD) > 0);
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
@@ -577,11 +578,11 @@ class HashJoinBasicImpl : public HashJoinImpl {
} else if (hash_table_empty_) {
hash_table_empty_ = false;
- RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_));
+ RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_->exec_context()));
}
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
- &hash_table_keys_, ctx_));
+ &hash_table_keys_, ctx_->exec_context()));
if (has_payload) {
RETURN_NOT_OK(
EncodeBatch(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_, batch));
@@ -593,7 +594,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
if (hash_table_empty_) {
- RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_));
+ RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_->exec_context()));
}
return Status::OK();
@@ -734,7 +735,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
// Metadata
//
- ExecContext* ctx_;
+ QueryContext* ctx_;
JoinType join_type_;
size_t num_threads_;
const HashJoinProjectionMaps* schema_[2];
diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h
index 0c5e43467e..bc053b2f1b 100644
--- a/cpp/src/arrow/compute/exec/hash_join.h
+++ b/cpp/src/arrow/compute/exec/hash_join.h
@@ -24,6 +24,7 @@
#include "arrow/compute/exec/accumulation_queue.h"
#include "arrow/compute/exec/bloom_filter.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/schema_util.h"
#include "arrow/compute/exec/task_util.h"
#include "arrow/result.h"
@@ -47,7 +48,7 @@ class HashJoinImpl {
using AbortContinuationImpl = std::function<void()>;
virtual ~HashJoinImpl() = default;
- virtual Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
+ virtual Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
diff --git a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
index a59ec03749..cc85251f8c 100644
--- a/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_benchmark.cc
@@ -125,9 +125,6 @@ class JoinBenchmark {
stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size;
- ctx_ = std::make_unique<ExecContext>(default_memory_pool(),
- arrow::internal::GetCpuThreadPool());
-
schema_mgr_ = std::make_unique<HashJoinSchema>();
Expression filter = literal(true);
DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_with_schema.schema,
@@ -148,6 +145,7 @@ class JoinBenchmark {
};
scheduler_ = TaskScheduler::Make();
+ DCHECK_OK(ctx_.Init(settings.num_threads, nullptr));
auto register_task_group_callback = [&](std::function<Status(size_t, int64_t)> task,
std::function<Status(size_t)> cont) {
@@ -159,11 +157,10 @@ class JoinBenchmark {
};
DCHECK_OK(join_->Init(
- ctx_.get(), settings.join_type, settings.num_threads,
- &(schema_mgr_->proj_maps[0]), &(schema_mgr_->proj_maps[1]), std::move(key_cmp),
- std::move(filter), std::move(register_task_group_callback),
- std::move(start_task_group_callback), [](int64_t, ExecBatch) {},
- [](int64_t x) {}));
+ &ctx_, settings.join_type, settings.num_threads, &(schema_mgr_->proj_maps[0]),
+ &(schema_mgr_->proj_maps[1]), std::move(key_cmp), std::move(filter),
+ std::move(register_task_group_callback), std::move(start_task_group_callback),
+ [](int64_t, ExecBatch) {}, [](int64_t x) {}));
task_group_probe_ = scheduler_->RegisterTaskGroup(
[this](size_t thread_index, int64_t task_id) -> Status {
@@ -199,7 +196,7 @@ class JoinBenchmark {
AccumulationQueue r_batches_;
std::unique_ptr<HashJoinSchema> schema_mgr_;
std::unique_ptr<HashJoinImpl> join_;
- std::unique_ptr<ExecContext> ctx_;
+ QueryContext ctx_;
int task_group_probe_;
struct {
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc
index 666ed92ec0..be3b01eb7f 100644
--- a/cpp/src/arrow/compute/exec/hash_join_node.cc
+++ b/cpp/src/arrow/compute/exec/hash_join_node.cc
@@ -490,7 +490,7 @@ struct BloomFilterPushdownContext {
std::function<Status(size_t, int64_t)>, std::function<Status(size_t)>)>;
using StartTaskGroupCallback = std::function<Status(int, int64_t)>;
using BuildFinishedCallback = std::function<Status(size_t, AccumulationQueue)>;
- using FiltersReceivedCallback = std::function<Status()>;
+ using FiltersReceivedCallback = std::function<Status(size_t)>;
using FilterFinishedCallback = std::function<Status(size_t, AccumulationQueue)>;
void Init(HashJoinNode* owner, size_t num_threads,
RegisterTaskGroupCallback register_task_group_callback,
@@ -498,7 +498,7 @@ struct BloomFilterPushdownContext {
FiltersReceivedCallback on_bloom_filters_received, bool disable_bloom_filter,
bool use_sync_execution);
- Status StartProducing();
+ Status StartProducing(size_t thread_index);
void ExpectBloomFilter() { eval_.num_expected_bloom_filters_ += 1; }
@@ -508,10 +508,11 @@ struct BloomFilterPushdownContext {
BuildFinishedCallback on_finished);
// Sends the Bloom filter to the pushdown target.
- Status PushBloomFilter();
+ Status PushBloomFilter(size_t thread_index);
// Receives a Bloom filter and its associated column map.
- Status ReceiveBloomFilter(std::unique_ptr<BlockedBloomFilter> filter,
+ Status ReceiveBloomFilter(size_t thread_index,
+ std::unique_ptr<BlockedBloomFilter> filter,
std::vector<int> column_map) {
bool proceed;
{
@@ -524,7 +525,7 @@ struct BloomFilterPushdownContext {
ARROW_DCHECK_LE(eval_.received_filters_.size(), eval_.num_expected_bloom_filters_);
}
if (proceed) {
- return eval_.all_received_callback_();
+ return eval_.all_received_callback_(thread_index);
}
return Status::OK();
}
@@ -553,7 +554,8 @@ struct BloomFilterPushdownContext {
std::vector<uint32_t> hashes(batch.length);
std::vector<uint8_t> bv(bit_vector_bytes);
- ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * stack, GetStack(thread_index));
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * stack,
+ ctx_->GetTempStack(thread_index));
// Start with full selection for the current batch
memset(selected.data(), 0xff, bit_vector_bytes);
@@ -587,8 +589,8 @@ struct BloomFilterPushdownContext {
size_t first_nonscalar = batch.values.size();
for (size_t i = 0; i < batch.values.size(); i++) {
if (!batch.values[i].is_scalar()) {
- ARROW_ASSIGN_OR_RAISE(batch.values[i],
- Filter(batch.values[i], selected_datum, options, ctx_));
+ ARROW_ASSIGN_OR_RAISE(batch.values[i], Filter(batch.values[i], selected_datum,
+ options, ctx_->exec_context()));
first_nonscalar = std::min(first_nonscalar, i);
ARROW_DCHECK_EQ(batch.values[i].length(), batch.values[first_nonscalar].length());
}
@@ -619,25 +621,10 @@ struct BloomFilterPushdownContext {
// the disable_bloom_filter_ flag.
std::pair<HashJoinNode*, std::vector<int>> GetPushdownTarget(HashJoinNode* start);
- Result<util::TempVectorStack*> GetStack(size_t thread_index) {
- if (!tld_[thread_index].is_init) {
- RETURN_NOT_OK(tld_[thread_index].stack.Init(
- ctx_->memory_pool(), 4 * util::MiniBatch::kMiniBatchLength * sizeof(uint32_t)));
- tld_[thread_index].is_init = true;
- }
- return &tld_[thread_index].stack;
- }
-
StartTaskGroupCallback start_task_group_callback_;
bool disable_bloom_filter_;
HashJoinSchema* schema_mgr_;
- ExecContext* ctx_;
-
- struct ThreadLocalData {
- bool is_init = false;
- util::TempVectorStack stack;
- };
- std::vector<ThreadLocalData> tld_;
+ QueryContext* ctx_;
struct {
int task_id_;
@@ -736,9 +723,10 @@ class HashJoinNode : public ExecNode {
join_options.output_suffix_for_left, join_options.output_suffix_for_right));
}
- ARROW_ASSIGN_OR_RAISE(Expression filter,
- schema_mgr->BindFilter(join_options.filter, left_schema,
- right_schema, plan->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(
+ Expression filter,
+ schema_mgr->BindFilter(join_options.filter, left_schema, right_schema,
+ plan->query_context()->exec_context()));
// Generate output schema
std::shared_ptr<Schema> output_schema = schema_mgr->MakeOutputSchema(
@@ -786,7 +774,7 @@ class HashJoinNode : public ExecNode {
}
Status OnBloomFilterFinished(size_t thread_index, AccumulationQueue batches) {
- RETURN_NOT_OK(pushdown_context_.PushBloomFilter());
+ RETURN_NOT_OK(pushdown_context_.PushBloomFilter(thread_index));
return impl_->BuildHashTable(
thread_index, std::move(batches),
[this](size_t thread_index) { return OnHashTableFinished(thread_index); });
@@ -837,10 +825,9 @@ class HashJoinNode : public ExecNode {
return Status::OK();
}
- Status OnFiltersReceived() {
+ Status OnFiltersReceived(size_t thread_index) {
std::unique_lock<std::mutex> guard(probe_side_mutex_);
bloom_filters_ready_ = true;
- size_t thread_index = plan_->GetThreadIndex();
AccumulationQueue batches = std::move(probe_accumulator_);
guard.unlock();
return pushdown_context_.FilterBatches(
@@ -869,8 +856,8 @@ class HashJoinNode : public ExecNode {
std::lock_guard<std::mutex> guard(probe_side_mutex_);
queued_batches_to_probe_ = std::move(probe_accumulator_);
}
- return plan_->StartTaskGroup(task_group_probe_,
- queued_batches_to_probe_.batch_count());
+ return plan_->query_context()->StartTaskGroup(task_group_probe_,
+ queued_batches_to_probe_.batch_count());
}
Status OnQueuedBatchesProbed(size_t thread_index) {
@@ -891,7 +878,7 @@ class HashJoinNode : public ExecNode {
return;
}
- size_t thread_index = plan_->GetThreadIndex();
+ size_t thread_index = plan_->query_context()->GetThreadIndex();
int side = (input == inputs_[0]) ? 0 : 1;
EVENT(span_, "InputReceived", {{"batch.length", batch.length}, {"side", side}});
@@ -929,7 +916,7 @@ class HashJoinNode : public ExecNode {
void InputFinished(ExecNode* input, int total_batches) override {
ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end());
- size_t thread_index = plan_->GetThreadIndex();
+ size_t thread_index = plan_->query_context()->GetThreadIndex();
int side = (input == inputs_[0]) ? 0 : 1;
EVENT(span_, "InputFinished", {{"side", side}, {"batches.length", total_batches}});
@@ -947,13 +934,14 @@ class HashJoinNode : public ExecNode {
}
Status Init() override {
- RETURN_NOT_OK(ExecNode::Init());
- if (plan_->UseLegacyBatching()) {
+ QueryContext* ctx = plan_->query_context();
+ if (ctx->options().use_legacy_batching) {
return Status::Invalid(
"The plan was configured to use legacy batching but contained a join node "
"which is incompatible with legacy batching");
}
- bool use_sync_execution = !(plan_->exec_context()->executor());
+
+ bool use_sync_execution = !(ctx->executor());
// TODO(ARROW-15732)
// Each side of join might have an IO thread being called from. Once this is fixed
// we will change it back to just the CPU's thread pool capacity.
@@ -961,32 +949,32 @@ class HashJoinNode : public ExecNode {
pushdown_context_.Init(
this, num_threads,
- [this](std::function<Status(size_t, int64_t)> fn,
- std::function<Status(size_t)> on_finished) {
- return plan_->RegisterTaskGroup(std::move(fn), std::move(on_finished));
+ [ctx](std::function<Status(size_t, int64_t)> fn,
+ std::function<Status(size_t)> on_finished) {
+ return ctx->RegisterTaskGroup(std::move(fn), std::move(on_finished));
},
- [this](int task_group_id, int64_t num_tasks) {
- return plan_->StartTaskGroup(task_group_id, num_tasks);
+ [ctx](int task_group_id, int64_t num_tasks) {
+ return ctx->StartTaskGroup(task_group_id, num_tasks);
},
- [this]() { return OnFiltersReceived(); }, disable_bloom_filter_,
- use_sync_execution);
+ [this](size_t thread_index) { return OnFiltersReceived(thread_index); },
+ disable_bloom_filter_, use_sync_execution);
RETURN_NOT_OK(impl_->Init(
- plan_->exec_context(), join_type_, num_threads, &(schema_mgr_->proj_maps[0]),
+ ctx, join_type_, num_threads, &(schema_mgr_->proj_maps[0]),
&(schema_mgr_->proj_maps[1]), key_cmp_, filter_,
- [this](std::function<Status(size_t, int64_t)> fn,
- std::function<Status(size_t)> on_finished) {
- return plan_->RegisterTaskGroup(std::move(fn), std::move(on_finished));
+ [ctx](std::function<Status(size_t, int64_t)> fn,
+ std::function<Status(size_t)> on_finished) {
+ return ctx->RegisterTaskGroup(std::move(fn), std::move(on_finished));
},
- [this](int task_group_id, int64_t num_tasks) {
- return plan_->StartTaskGroup(task_group_id, num_tasks);
+ [ctx](int task_group_id, int64_t num_tasks) {
+ return ctx->StartTaskGroup(task_group_id, num_tasks);
},
[this](int64_t, ExecBatch batch) { this->OutputBatchCallback(batch); },
[this](int64_t total_num_batches) {
this->FinishedCallback(total_num_batches);
}));
- task_group_probe_ = plan_->RegisterTaskGroup(
+ task_group_probe_ = ctx->RegisterTaskGroup(
[this](size_t thread_index, int64_t task_id) -> Status {
return impl_->ProbeSingleBatch(thread_index,
std::move(queued_batches_to_probe_[task_id]));
@@ -1004,7 +992,8 @@ class HashJoinNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
- RETURN_NOT_OK(pushdown_context_.StartProducing());
+ RETURN_NOT_OK(
+ pushdown_context_.StartProducing(plan_->query_context()->GetThreadIndex()));
return Status::OK();
}
@@ -1084,8 +1073,7 @@ void BloomFilterPushdownContext::Init(
FiltersReceivedCallback on_bloom_filters_received, bool disable_bloom_filter,
bool use_sync_execution) {
schema_mgr_ = owner->schema_mgr_.get();
- ctx_ = owner->plan_->exec_context();
- tld_.resize(num_threads);
+ ctx_ = owner->plan_->query_context();
disable_bloom_filter_ = disable_bloom_filter;
std::tie(push_.pushdown_target_, push_.column_map_) = GetPushdownTarget(owner);
eval_.all_received_callback_ = std::move(on_bloom_filters_received);
@@ -1117,8 +1105,9 @@ void BloomFilterPushdownContext::Init(
start_task_group_callback_ = std::move(start_task_group_callback);
}
-Status BloomFilterPushdownContext::StartProducing() {
- if (eval_.num_expected_bloom_filters_ == 0) return eval_.all_received_callback_();
+Status BloomFilterPushdownContext::StartProducing(size_t thread_index) {
+ if (eval_.num_expected_bloom_filters_ == 0)
+ return eval_.all_received_callback_(thread_index);
return Status::OK();
}
@@ -1132,7 +1121,7 @@ Status BloomFilterPushdownContext::BuildBloomFilter(size_t thread_index,
return build_.on_finished_(thread_index, std::move(build_.batches_));
RETURN_NOT_OK(build_.builder_->Begin(
- /*num_threads=*/tld_.size(), ctx_->cpu_info()->hardware_flags(),
+ /*num_threads=*/ctx_->max_concurrency(), ctx_->cpu_info()->hardware_flags(),
ctx_->memory_pool(), build_.batches_.row_count(), build_.batches_.batch_count(),
push_.bloom_filter_.get()));
@@ -1140,10 +1129,10 @@ Status BloomFilterPushdownContext::BuildBloomFilter(size_t thread_index,
/*num_tasks=*/build_.batches_.batch_count());
}
-Status BloomFilterPushdownContext::PushBloomFilter() {
+Status BloomFilterPushdownContext::PushBloomFilter(size_t thread_index) {
if (!disable_bloom_filter_)
return push_.pushdown_target_->pushdown_context_.ReceiveBloomFilter(
- std::move(push_.bloom_filter_), std::move(push_.column_map_));
+ thread_index, std::move(push_.bloom_filter_), std::move(push_.column_map_));
return Status::OK();
}
@@ -1164,7 +1153,7 @@ Status BloomFilterPushdownContext::BuildBloomFilter_exec_task(size_t thread_inde
}
ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(std::move(key_columns)));
- ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * stack, GetStack(thread_index));
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * stack, ctx_->GetTempStack(thread_index));
util::TempVectorHolder<uint32_t> hash_holder(stack, util::MiniBatch::kMiniBatchLength);
uint32_t* hashes = hash_holder.mutable_data();
for (int64_t i = 0; i < key_batch.length; i += util::MiniBatch::kMiniBatchLength) {
diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc
index 5ce5428a15..5e8c2245a2 100644
--- a/cpp/src/arrow/compute/exec/project_node.cc
+++ b/cpp/src/arrow/compute/exec/project_node.cc
@@ -23,6 +23,7 @@
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/map_node.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/util.h"
#include "arrow/datum.h"
#include "arrow/result.h"
@@ -64,8 +65,8 @@ class ProjectNode : public MapNode {
int i = 0;
for (auto& expr : exprs) {
if (!expr.IsBound()) {
- ARROW_ASSIGN_OR_RAISE(
- expr, expr.Bind(*inputs[0]->output_schema(), plan->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*inputs[0]->output_schema(),
+ plan->query_context()->exec_context()));
}
fields[i] = field(std::move(names[i]), expr.type()->GetSharedPtr());
++i;
@@ -87,8 +88,9 @@ class ProjectNode : public MapNode {
ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
SimplifyWithGuarantee(exprs_[i], target.guarantee));
- ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target,
- plan()->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(
+ values[i], ExecuteScalarExpression(simplified_expr, target,
+ plan()->query_context()->exec_context()));
}
return ExecBatch{std::move(values), target.length};
}
diff --git a/cpp/src/arrow/compute/exec/query_context.cc b/cpp/src/arrow/compute/exec/query_context.cc
new file mode 100644
index 0000000000..7957b42034
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/query_context.cc
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec/query_context.h"
+#include "arrow/util/cpu_info.h"
+#include "arrow/util/io_util.h"
+
+namespace arrow {
+using internal::CpuInfo;
+namespace compute {
+QueryOptions::QueryOptions() : use_legacy_batching(false) {}
+
+QueryContext::QueryContext(QueryOptions opts, ExecContext exec_context)
+ : options_(opts),
+ exec_context_(exec_context),
+ io_context_(exec_context_.memory_pool()) {}
+
+const CpuInfo* QueryContext::cpu_info() const { return CpuInfo::GetInstance(); }
+int64_t QueryContext::hardware_flags() const { return cpu_info()->hardware_flags(); }
+
+Status QueryContext::Init(size_t max_num_threads, util::AsyncTaskScheduler* scheduler) {
+ tld_.resize(max_num_threads);
+ async_scheduler_ = scheduler;
+ return Status::OK();
+}
+
+size_t QueryContext::GetThreadIndex() { return thread_indexer_(); }
+
+size_t QueryContext::max_concurrency() const { return thread_indexer_.Capacity(); }
+
+Result<util::TempVectorStack*> QueryContext::GetTempStack(size_t thread_index) {
+ if (!tld_[thread_index].is_init) {
+ RETURN_NOT_OK(tld_[thread_index].stack.Init(
+ memory_pool(), 8 * util::MiniBatch::kMiniBatchLength * sizeof(uint64_t)));
+ tld_[thread_index].is_init = true;
+ }
+ return &tld_[thread_index].stack;
+}
+
+Result<Future<>> QueryContext::BeginExternalTask() {
+ Future<> completion_future = Future<>::Make();
+ if (async_scheduler_->AddSimpleTask(
+ [completion_future] { return completion_future; })) {
+ return completion_future;
+ }
+ return Future<>{};
+}
+
+Status QueryContext::ScheduleTask(std::function<Status()> fn) {
+ ::arrow::internal::Executor* exec = executor();
+ if (!exec) return fn();
+ // Adds a task which submits fn to the executor and tracks its progress. If we're
+ // already stopping then the task is ignored and fn is not executed.
+ async_scheduler_->AddSimpleTask([exec, fn]() { return exec->Submit(std::move(fn)); });
+ return Status::OK();
+}
+
+Status QueryContext::ScheduleTask(std::function<Status(size_t)> fn) {
+ std::function<Status()> indexed_fn = [this, fn]() {
+ size_t thread_index = GetThreadIndex();
+ return fn(thread_index);
+ };
+ return ScheduleTask(std::move(indexed_fn));
+}
+
+Status QueryContext::ScheduleIOTask(std::function<Status()> fn) {
+ async_scheduler_->AddSimpleTask(
+ [this, fn]() { return io_context_.executor()->Submit(std::move(fn)); });
+ return Status::OK();
+}
+
+int QueryContext::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+ std::function<Status(size_t)> on_finished) {
+ return task_scheduler_->RegisterTaskGroup(std::move(task), std::move(on_finished));
+}
+
+Status QueryContext::StartTaskGroup(int task_group_id, int64_t num_tasks) {
+ return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id, num_tasks);
+}
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/query_context.h b/cpp/src/arrow/compute/exec/query_context.h
new file mode 100644
index 0000000000..12ddbc56fa
--- /dev/null
+++ b/cpp/src/arrow/compute/exec/query_context.h
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/task_util.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/io/interfaces.h"
+#include "arrow/util/async_util.h"
+
+#pragma once
+
+namespace arrow {
+
+using io::IOContext;
+namespace compute {
+struct ARROW_EXPORT QueryOptions {
+ QueryOptions();
+
+ /// \brief Should the plan use a legacy batching strategy
+ ///
+ /// This is currently in place only to support the Scanner::ToTable
+ /// method. This method relies on batch indices from the scanner
+ /// remaining consistent. This is impractical in the ExecPlan which
+ /// might slice batches as needed (e.g. for a join)
+ ///
+ /// However, it still works for simple plans and this is the only way
+ /// we have at the moment for maintaining implicit order.
+ bool use_legacy_batching;
+};
+
+class ARROW_EXPORT QueryContext {
+ public:
+ QueryContext(QueryOptions opts = {},
+ ExecContext exec_context = *default_exec_context());
+
+ Status Init(size_t max_num_threads, util::AsyncTaskScheduler* scheduler);
+
+ const ::arrow::internal::CpuInfo* cpu_info() const;
+ int64_t hardware_flags() const;
+ const QueryOptions& options() const { return options_; }
+ MemoryPool* memory_pool() const { return exec_context_.memory_pool(); }
+ ::arrow::internal::Executor* executor() const { return exec_context_.executor(); }
+ ExecContext* exec_context() { return &exec_context_; }
+ IOContext* io_context() { return &io_context_; }
+ TaskScheduler* scheduler() { return task_scheduler_.get(); }
+ util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_; }
+
+ size_t GetThreadIndex();
+ size_t max_concurrency() const;
+ Result<util::TempVectorStack*> GetTempStack(size_t thread_index);
+
+ /// \brief Start an external task
+ ///
+ /// This should be avoided if possible. It is kept in for now for legacy
+ /// purposes. This should be called before the external task is started. If
+ /// a valid future is returned then it should be marked complete when the
+ /// external task has finished.
+ ///
+ /// \return an invalid future if the plan has already ended, otherwise this
+ /// returns a future that must be completed when the external task
+ /// finishes.
+ Result<Future<>> BeginExternalTask();
+
+ /// \brief Add a single function as a task to the query's task group
+ /// on the compute threadpool.
+ ///
+ /// \param fn The task to run. Takes no arguments and returns a Status.
+ Status ScheduleTask(std::function<Status()> fn);
+ /// \brief Add a single function as a task to the query's task group
+ /// on the compute threadpool.
+ ///
+ /// \param fn The task to run. Takes the thread index and returns a Status.
+ Status ScheduleTask(std::function<Status(size_t)> fn);
+ /// \brief Add a single function as a task to the query's task group on
+ /// the IO thread pool
+ ///
+ /// \param fn The task to run. Returns a status.
+ Status ScheduleIOTask(std::function<Status()> fn);
+
+ // Register/Start TaskGroup is a way of performing a "Parallel For" pattern:
+ // - The task function takes the thread index and the index of the task
+ // - The on_finished function takes the thread index
+ // Returns an integer ID that will be used to reference the task group in
+ // StartTaskGroup. At runtime, call StartTaskGroup with the ID and the number of times
+ // you'd like the task to be executed. The need to register a task group before use will
+ // be removed after we rewrite the scheduler.
+ /// \brief Register a "parallel for" task group with the scheduler
+ ///
+ /// \param task The function implementing the task. Takes the thread_index and
+ /// the task index.
+ /// \param on_finished The function that gets run once all tasks have been completed.
+ /// Takes the thread_index.
+ ///
+ /// Must be called inside of ExecNode::Init.
+ int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
+ std::function<Status(size_t)> on_finished);
+
+ /// \brief Start the task group with the specified ID. This can only
+ /// be called once per task_group_id.
+ ///
+ /// \param task_group_id The ID of the task group to run
+ /// \param num_tasks The number of times to run the task
+ Status StartTaskGroup(int task_group_id, int64_t num_tasks);
+
+ // This is an RAII class for keeping track of in-flight file IO. Useful for getting
+ // an estimate of memory use, and how much memory we expect to be freed soon.
+ // Returned by ReportTempFileIO.
+ struct [[nodiscard]] TempFileIOMark {
+ QueryContext* ctx_;
+ size_t bytes_;
+
+ TempFileIOMark(QueryContext* ctx, size_t bytes) : ctx_(ctx), bytes_(bytes) {
+ ctx_->in_flight_bytes_to_disk_.fetch_add(bytes_, std::memory_order_acquire);
+ }
+
+ ARROW_DISALLOW_COPY_AND_ASSIGN(TempFileIOMark);
+
+ ~TempFileIOMark() {
+ ctx_->in_flight_bytes_to_disk_.fetch_sub(bytes_, std::memory_order_release);
+ }
+ };
+
+ TempFileIOMark ReportTempFileIO(size_t bytes) { return {this, bytes}; }
+
+ size_t GetCurrentTempFileIO() { return in_flight_bytes_to_disk_.load(); }
+
+ private:
+ QueryOptions options_;
+ // To be replaced with Acero-specific context once scheduler is done and
+ // we don't need ExecContext for kernels
+ ExecContext exec_context_;
+ IOContext io_context_;
+
+ util::AsyncTaskScheduler* async_scheduler_ = NULLPTR;
+ std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
+
+ ThreadIndexer thread_indexer_;
+ struct ThreadLocalData {
+ bool is_init = false;
+ util::TempVectorStack stack;
+ };
+ std::vector<ThreadLocalData> tld_;
+
+ std::atomic<size_t> in_flight_bytes_to_disk_{0};
+};
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index f69a2ebfb6..a1bfba945c 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -26,6 +26,7 @@
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/order_by_impl.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/datum.h"
@@ -380,7 +381,7 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
protected:
void Finish(const Status& finish_st) {
- plan_->async_scheduler()->AddSimpleTask([this, &finish_st] {
+ plan_->query_context()->async_scheduler()->AddSimpleTask([this, &finish_st] {
return consumer_->Finish().Then(
[this, finish_st]() {
finished_.MarkFinished(finish_st);
@@ -405,7 +406,7 @@ static Result<ExecNode*> MakeTableConsumingSinkNode(
const compute::ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "TableConsumingSinkNode"));
const auto& sink_options = checked_cast<const TableSinkNodeOptions&>(options);
- MemoryPool* pool = plan->exec_context()->memory_pool();
+ MemoryPool* pool = plan->query_context()->memory_pool();
auto tb_consumer =
std::make_shared<TableSinkNodeConsumer>(sink_options.output_table, pool);
auto consuming_sink_node_options = ConsumingSinkNodeOptions{tb_consumer};
@@ -436,8 +437,8 @@ struct OrderBySinkNode final : public SinkNode {
RETURN_NOT_OK(ValidateOrderByOptions(sink_options));
ARROW_ASSIGN_OR_RAISE(
std::unique_ptr<OrderByImpl> impl,
- OrderByImpl::MakeSort(plan->exec_context(), inputs[0]->output_schema(),
- sink_options.sort_options));
+ OrderByImpl::MakeSort(plan->query_context()->exec_context(),
+ inputs[0]->output_schema(), sink_options.sort_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
sink_options.generator);
}
@@ -466,10 +467,10 @@ struct OrderBySinkNode final : public SinkNode {
return Status::Invalid("Backpressure cannot be applied to an OrderBySinkNode");
}
RETURN_NOT_OK(ValidateSelectKOptions(sink_options));
- ARROW_ASSIGN_OR_RAISE(
- std::unique_ptr<OrderByImpl> impl,
- OrderByImpl::MakeSelectK(plan->exec_context(), inputs[0]->output_schema(),
- sink_options.select_k_options));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr<OrderByImpl> impl,
+ OrderByImpl::MakeSelectK(plan->query_context()->exec_context(),
+ inputs[0]->output_schema(),
+ sink_options.select_k_options));
return plan->EmplaceNode<OrderBySinkNode>(plan, std::move(inputs), std::move(impl),
sink_options.generator);
}
@@ -491,7 +492,7 @@ struct OrderBySinkNode final : public SinkNode {
DCHECK_EQ(input, inputs_[0]);
auto maybe_batch = batch.ToRecordBatch(inputs_[0]->output_schema(),
- plan()->exec_context()->memory_pool());
+ plan()->query_context()->memory_pool());
if (ErrorIfNotOk(maybe_batch.status())) {
StopProducing();
if (input_counter_.Cancel()) {
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 3f8f4c9834..3d21dece97 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -23,6 +23,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/util.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/datum.h"
@@ -91,7 +92,7 @@ struct SourceNode : ExecNode {
}
CallbackOptions options;
- auto executor = plan()->exec_context()->executor();
+ auto executor = plan()->query_context()->executor();
if (executor) {
// These options will transfer execution to the desired Executor if necessary.
// This can happen for in-memory scans where batches didn't require
@@ -100,7 +101,8 @@ struct SourceNode : ExecNode {
options.executor = executor;
options.should_schedule = ShouldSchedule::IfDifferentExecutor;
}
- ARROW_ASSIGN_OR_RAISE(Future<> scan_task, plan_->BeginExternalTask());
+ ARROW_ASSIGN_OR_RAISE(Future<> scan_task,
+ plan_->query_context()->BeginExternalTask());
if (!scan_task.is_valid()) {
finished_.MarkFinished();
// Plan has already been aborted, no need to start scanning
@@ -121,7 +123,8 @@ struct SourceNode : ExecNode {
return Break(batch_count_);
}
lock.unlock();
- bool use_legacy_batching = plan_->UseLegacyBatching();
+ bool use_legacy_batching =
+ plan_->query_context()->options().use_legacy_batching;
ExecBatch morsel = std::move(*maybe_morsel);
int64_t morsel_length = static_cast<int64_t>(morsel.length);
if (use_legacy_batching || morsel_length == 0) {
@@ -133,23 +136,22 @@ struct SourceNode : ExecNode {
bit_util::CeilDiv(morsel_length, ExecPlan::kMaxBatchSize));
batch_count_ += num_batches;
}
- RETURN_NOT_OK(plan_->ScheduleTask(
- [this, use_legacy_batching, morsel, morsel_length]() {
- int64_t offset = 0;
- do {
- int64_t batch_size = std::min<int64_t>(
- morsel_length - offset, ExecPlan::kMaxBatchSize);
- // In order for the legacy batching model to work we must
- // not slice batches from the source
- if (use_legacy_batching) {
- batch_size = morsel_length;
- }
- ExecBatch batch = morsel.Slice(offset, batch_size);
- offset += batch_size;
- outputs_[0]->InputReceived(this, std::move(batch));
- } while (offset < morsel.length);
- return Status::OK();
- }));
+ RETURN_NOT_OK(plan_->query_context()->ScheduleTask([=]() {
+ int64_t offset = 0;
+ do {
+ int64_t batch_size = std::min<int64_t>(
+ morsel_length - offset, ExecPlan::kMaxBatchSize);
+ // In order for the legacy batching model to work we must
+ // not slice batches from the source
+ if (use_legacy_batching) {
+ batch_size = morsel_length;
+ }
+ ExecBatch batch = morsel.Slice(offset, batch_size);
+ offset += batch_size;
+ outputs_[0]->InputReceived(this, std::move(batch));
+ } while (offset < morsel.length);
+ return Status::OK();
+ }));
lock.lock();
if (!backpressure_future_.is_finished()) {
EVENT(span_, "Source paused due to backpressure");
@@ -309,7 +311,7 @@ struct SchemaSourceNode : public SourceNode {
auto io_executor = cast_options.io_executor;
if (io_executor == NULLPTR) {
- io_executor = plan->exec_context()->executor();
+ io_executor = plan->query_context()->exec_context()->executor();
}
auto it = it_maker();
diff --git a/cpp/src/arrow/compute/exec/swiss_join.cc b/cpp/src/arrow/compute/exec/swiss_join.cc
index 5b01edb119..fee3c5f79d 100644
--- a/cpp/src/arrow/compute/exec/swiss_join.cc
+++ b/cpp/src/arrow/compute/exec/swiss_join.cc
@@ -2022,7 +2022,7 @@ Status JoinProbeProcessor::OnFinished() {
class SwissJoin : public HashJoinImpl {
public:
- Status Init(ExecContext* ctx, JoinType join_type, size_t num_threads,
+ Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads,
const HashJoinProjectionMaps* proj_map_left,
const HashJoinProjectionMaps* proj_map_right,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
@@ -2067,8 +2067,6 @@ class SwissJoin : public HashJoinImpl {
for (int i = 0; i < num_threads_; ++i) {
local_states_[i].hash_table_ready = false;
local_states_[i].num_output_batches = 0;
- RETURN_NOT_OK(CancelIfNotOK(local_states_[i].temp_stack.Init(
- pool_, 1024 + 64 * util::MiniBatch::kMiniBatchLength)));
local_states_[i].materialize.Init(pool_, proj_map_left, proj_map_right);
}
@@ -2116,10 +2114,12 @@ class SwissJoin : public HashJoinImpl {
ExecBatch keypayload_batch;
ARROW_ASSIGN_OR_RAISE(keypayload_batch, KeyPayloadFromInput(/*side=*/0, &batch));
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * temp_stack,
+ ctx_->GetTempStack(thread_index));
- return CancelIfNotOK(probe_processor_.OnNextBatch(
- thread_index, keypayload_batch, &local_states_[thread_index].temp_stack,
- &local_states_[thread_index].temp_column_arrays));
+ return CancelIfNotOK(
+ probe_processor_.OnNextBatch(thread_index, keypayload_batch, temp_stack,
+ &local_states_[thread_index].temp_column_arrays));
}
Status ProbingFinished(size_t thread_index) override {
@@ -2225,9 +2225,11 @@ class SwissJoin : public HashJoinImpl {
input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol];
}
}
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * temp_stack,
+ ctx_->GetTempStack(thread_id));
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_.PushNextBatch(
static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : &payload_batch,
- &local_states_[thread_id].temp_stack)));
+ temp_stack)));
// Release input batch
//
@@ -2259,7 +2261,9 @@ class SwissJoin : public HashJoinImpl {
Status MergeFinished(size_t thread_id) {
RETURN_NOT_OK(status());
- hash_table_build_.FinishPrtnMerge(&local_states_[thread_id].temp_stack);
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * temp_stack,
+ ctx_->GetTempStack(thread_id));
+ hash_table_build_.FinishPrtnMerge(temp_stack);
return CancelIfNotOK(OnBuildHashTableFinished(static_cast<int64_t>(thread_id)));
}
@@ -2311,7 +2315,8 @@ class SwissJoin : public HashJoinImpl {
std::min((task_id + 1) * kNumRowsPerScanTask, hash_table_.num_rows());
// Get thread index and related temp vector stack
//
- util::TempVectorStack* temp_stack = &local_states_[thread_id].temp_stack;
+ ARROW_ASSIGN_OR_RAISE(util::TempVectorStack * temp_stack,
+ ctx_->GetTempStack(thread_id));
// Split into mini-batches
//
@@ -2467,7 +2472,7 @@ class SwissJoin : public HashJoinImpl {
static constexpr int kNumRowsPerScanTask = 512 * 1024;
- ExecContext* ctx_;
+ QueryContext* ctx_;
int64_t hardware_flags_;
MemoryPool* pool_;
int num_threads_;
@@ -2489,7 +2494,6 @@ class SwissJoin : public HashJoinImpl {
struct ThreadLocalState {
JoinResultMaterialize materialize;
- util::TempVectorStack temp_stack;
std::vector<KeyColumnArray> temp_column_arrays;
int64_t num_output_batches;
bool hash_table_ready;
diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc
index 30dbd511e6..afff52beaf 100644
--- a/cpp/src/arrow/compute/exec/tpch_node.cc
+++ b/cpp/src/arrow/compute/exec/tpch_node.cc
@@ -31,6 +31,7 @@
#include "arrow/buffer.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/datum.h"
#include "arrow/util/async_util.h"
#include "arrow/util/formatting.h"
@@ -3381,8 +3382,8 @@ class TpchNode : public ExecNode {
Status StartProducing() override {
num_running_++;
- ARROW_RETURN_NOT_OK(generator_->StartProducing(
- plan_->max_concurrency(),
+ RETURN_NOT_OK(generator_->StartProducing(
+ plan_->query_context()->max_concurrency(),
[this](ExecBatch batch) { this->OutputBatchCallback(std::move(batch)); },
[this](int64_t num_batches) { this->FinishedCallback(num_batches); },
[this](std::function<Status(size_t)> func) -> Status {
@@ -3425,7 +3426,7 @@ class TpchNode : public ExecNode {
Status ScheduleTaskCallback(std::function<Status(size_t)> func) {
if (finished_generating_.load()) return Status::OK();
num_running_++;
- return plan_->ScheduleTask([this, func](size_t thread_index) {
+ return plan_->query_context()->ScheduleTask([this, func](size_t thread_index) {
Status status = func(thread_index);
if (!status.ok()) {
StopProducing();
diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h
index 67dc5a278b..70273e38e8 100644
--- a/cpp/src/arrow/compute/type_fwd.h
+++ b/cpp/src/arrow/compute/type_fwd.h
@@ -50,6 +50,8 @@ class ExecNode;
class ExecPlan;
class ExecNodeOptions;
class ExecFactoryRegistry;
+class QueryContext;
+struct QueryOptions;
class SinkNodeConsumer;
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index ffa25332ad..7776a5018b 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -29,6 +29,7 @@
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/exec/forest_internal.h"
#include "arrow/compute/exec/map_node.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/subtree_internal.h"
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/dataset_writer.h"
@@ -401,7 +402,7 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
ARROW_ASSIGN_OR_RAISE(
dataset_writer_,
internal::DatasetWriter::Make(
- write_options_, plan->async_scheduler(),
+ write_options_, plan->query_context()->async_scheduler(),
[backpressure_control] { backpressure_control->Pause(); },
[backpressure_control] { backpressure_control->Resume(); }, [] {}));
return Status::OK();
@@ -516,10 +517,11 @@ class TeeNode : public compute::MapNode {
write_options_(std::move(write_options)) {}
Status StartProducing() override {
- ARROW_ASSIGN_OR_RAISE(dataset_writer_, internal::DatasetWriter::Make(
- write_options_, plan_->async_scheduler(),
- [this] { Pause(); }, [this] { Resume(); },
- [this] { MapNode::Finish(); }));
+ ARROW_ASSIGN_OR_RAISE(
+ dataset_writer_,
+ internal::DatasetWriter::Make(
+ write_options_, plan_->query_context()->async_scheduler(),
+ [this] { Pause(); }, [this] { Resume(); }, [this] { MapNode::Finish(); }));
return MapNode::StartProducing();
}
diff --git a/cpp/src/arrow/dataset/scan_node.cc b/cpp/src/arrow/dataset/scan_node.cc
index 4e644c1de7..07b2b9886c 100644
--- a/cpp/src/arrow/dataset/scan_node.cc
+++ b/cpp/src/arrow/dataset/scan_node.cc
@@ -24,6 +24,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/expression.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/compute/exec/util.h"
#include "arrow/dataset/scanner.h"
#include "arrow/record_batch.h"
@@ -163,8 +164,9 @@ class ScanNode : public cp::ExecNode {
const cp::ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, "ScanNode"));
const auto& scan_options = checked_cast<const ScanV2Options&>(options);
- ARROW_ASSIGN_OR_RAISE(ScanV2Options normalized_options,
- NormalizeAndValidate(scan_options, plan->exec_context()));
+ ARROW_ASSIGN_OR_RAISE(
+ ScanV2Options normalized_options,
+ NormalizeAndValidate(scan_options, plan->query_context()->exec_context()));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> output_schema,
OutputSchemaFromOptions(normalized_options));
return plan->EmplaceNode<ScanNode>(plan, std::move(normalized_options),
@@ -216,7 +218,7 @@ class ScanNode : public cp::ExecNode {
compute::ExecBatch evolved_batch,
scan_->fragment_evolution->EvolveBatch(batch, node_->options_.columns,
scan_->scan_request.columns));
- return node_->plan_->ScheduleTask(
+ return node_->plan_->query_context()->ScheduleTask(
[node = node_, evolved_batch = std::move(evolved_batch)] {
node->outputs_[0]->InputReceived(node, std::move(evolved_batch));
return Status::OK();
@@ -237,7 +239,8 @@ class ScanNode : public cp::ExecNode {
Result<Future<>> operator()() override {
return fragment
- ->InspectFragment(node->options_.format_options, node->plan_->exec_context())
+ ->InspectFragment(node->options_.format_options,
+ node->plan_->query_context()->exec_context())
.Then([this](const std::shared_ptr<InspectedFragment>& inspected_fragment) {
return BeginScan(inspected_fragment);
});
@@ -252,7 +255,8 @@ class ScanNode : public cp::ExecNode {
ARROW_RETURN_NOT_OK(InitFragmentScanRequest());
return fragment
->BeginScan(scan_state->scan_request, *inspected_fragment,
- node->options_.format_options, node->plan_->exec_context())
+ node->options_.format_options,
+ node->plan_->query_context()->exec_context())
.Then([this](const std::shared_ptr<FragmentScanner>& fragment_scanner) {
return AddScanTasks(fragment_scanner);
});
@@ -312,7 +316,7 @@ class ScanNode : public cp::ExecNode {
void ScanFragments(const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
std::shared_ptr<util::AsyncTaskScheduler> fragment_tasks =
util::MakeThrottledAsyncTaskGroup(
- plan_->async_scheduler(), options_.fragment_readahead + 1,
+ plan_->query_context()->async_scheduler(), options_.fragment_readahead + 1,
/*queue=*/nullptr, [this]() {
outputs_[0]->InputFinished(this, num_batches_.load());
finished_.MarkFinished();
@@ -334,8 +338,8 @@ class ScanNode : public cp::ExecNode {
{"node.detail", ToString()}});
END_SPAN_ON_FUTURE_COMPLETION(span_, finished_);
batches_throttle_ = util::ThrottledAsyncTaskScheduler::Make(
- plan_->async_scheduler(), options_.target_bytes_readahead + 1);
- plan_->async_scheduler()->AddSimpleTask([this] {
+ plan_->query_context()->async_scheduler(), options_.target_bytes_readahead + 1);
+ plan_->query_context()->async_scheduler()->AddSimpleTask([this] {
return GetFragments(options_.dataset.get(), options_.filter)
.Then([this](const AsyncGenerator<std::shared_ptr<Fragment>>& frag_gen) {
ScanFragments(frag_gen);
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 198c867cb5..c3d016a375 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -32,6 +32,7 @@
#include "arrow/compute/cast.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/query_context.h"
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/dataset_internal.h"
#include "arrow/dataset/plan.h"
@@ -422,8 +423,11 @@ Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
auto exec_context =
std::make_shared<compute::ExecContext>(scan_options_->pool, cpu_executor);
- ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get()));
- plan->SetUseLegacyBatching(use_legacy_batching);
+ compute::QueryOptions query_options;
+ query_options.use_legacy_batching = use_legacy_batching;
+
+ ARROW_ASSIGN_OR_RAISE(auto plan,
+ compute::ExecPlan::Make(query_options, exec_context.get()));
AsyncGenerator<std::optional<compute::ExecBatch>> sink_gen;
auto exprs = scan_options_->projection.call()->arguments;
diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc
index 1f117de7b2..742865424d 100644
--- a/cpp/src/arrow/util/io_util.cc
+++ b/cpp/src/arrow/util/io_util.cc
@@ -108,8 +108,10 @@
#elif __APPLE__
#include <mach/mach.h>
+#include <sys/sysctl.h>
#elif __linux__
+#include <sys/sysinfo.h>
#include <fstream>
#endif
@@ -2164,5 +2166,43 @@ int64_t GetCurrentRSS() {
#endif
}
+// Report the total RAM size in bytes as given by the operating system.
+// Returns -1 (after logging a warning) when the platform query fails and 0 on
+// platforms other than Windows, Apple and Linux.
+int64_t GetTotalMemoryBytes() {
+#if defined(_WIN32)
+  ULONGLONG result_kb;
+  if (!GetPhysicallyInstalledSystemMemory(&result_kb)) {
+    // GetLastError() returns a Win32 error code, not an errno value, so
+    // std::strerror() would print an unrelated message; log the numeric code.
+    const DWORD error_code = GetLastError();
+    ARROW_LOG(WARNING) << "Failed to resolve total RAM size: Windows error code "
+                       << error_code;
+    return -1;
+  }
+  // The value is reported in kilobytes; convert it to bytes.
+  return static_cast<int64_t>(result_kb * 1024);
+#elif defined(__APPLE__)
+  int64_t result;
+  size_t size = sizeof(result);
+  if (sysctlbyname("hw.memsize", &result, &size, nullptr, 0) == -1) {
+    ARROW_LOG(WARNING) << "Failed to resolve total RAM size: " << std::strerror(errno);
+    return -1;
+  }
+  return result;
+#elif defined(__linux__)
+  struct sysinfo info;
+  if (sysinfo(&info) == -1) {
+    ARROW_LOG(WARNING) << "Failed to resolve total RAM size: " << std::strerror(errno);
+    return -1;
+  }
+  // Widen before multiplying: totalram is an unsigned long (32 bits on 32-bit
+  // targets), so the product could otherwise wrap around before the cast.
+  return static_cast<int64_t>(info.totalram) * static_cast<int64_t>(info.mem_unit);
+#else
+  return 0;
+#endif
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/io_util.h b/cpp/src/arrow/util/io_util.h
index df63de47e8..43d85ec24e 100644
--- a/cpp/src/arrow/util/io_util.h
+++ b/cpp/src/arrow/util/io_util.h
@@ -410,5 +410,11 @@ uint64_t GetThreadId();
ARROW_EXPORT
int64_t GetCurrentRSS();
+/// \brief Get the total memory available to the system in bytes
+///
+/// Supported on Windows, Linux, and Mac; returns 0 on other platforms and -1 on failure
+ARROW_EXPORT
+int64_t GetTotalMemoryBytes();
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/io_util_test.cc b/cpp/src/arrow/util/io_util_test.cc
index bb5113440b..2599c92d82 100644
--- a/cpp/src/arrow/util/io_util_test.cc
+++ b/cpp/src/arrow/util/io_util_test.cc
@@ -1080,5 +1080,17 @@ TEST(CpuInfo, Basic) {
ASSERT_EQ(ci->hardware_flags(), original_hardware_flags);
}
+// GetTotalMemoryBytes() has a real implementation on Windows, Apple and Linux,
+// where a positive byte count is expected. On every other platform it is
+// expected to return exactly 0. The supported platforms share one assertion,
+// so they are grouped under a single preprocessor condition.
+TEST(Memory, TotalMemory) {
+#if defined(_WIN32) || defined(__APPLE__) || defined(__linux__)
+  ASSERT_GT(GetTotalMemoryBytes(), 0);
+#else
+  ASSERT_EQ(GetTotalMemoryBytes(), 0);
+#endif
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/rle_encoding.h b/cpp/src/arrow/util/rle_encoding.h
index cc90f658f0..09b2cda91e 100644
--- a/cpp/src/arrow/util/rle_encoding.h
+++ b/cpp/src/arrow/util/rle_encoding.h
@@ -139,7 +139,7 @@ class RleDecoder {
int64_t valid_bits_offset);
protected:
- bit_util::BitReader bit_reader_;
+ ::arrow::bit_util::BitReader bit_reader_;
/// Number of bits needed to encode the value. Must be between 0 and 64.
int bit_width_;
uint64_t current_value_;
@@ -186,12 +186,12 @@ class RleEncoder {
/// It is not valid to pass a buffer less than this length.
static int MinBufferSize(int bit_width) {
/// 1 indicator byte and MAX_VALUES_PER_LITERAL_RUN 'bit_width' values.
- int max_literal_run_size =
- 1 +
- static_cast<int>(bit_util::BytesForBits(MAX_VALUES_PER_LITERAL_RUN * bit_width));
+ int max_literal_run_size = 1 + static_cast<int>(::arrow::bit_util::BytesForBits(
+ MAX_VALUES_PER_LITERAL_RUN * bit_width));
/// Up to kMaxVlqByteLength indicator and a single 'bit_width' value.
- int max_repeated_run_size = bit_util::BitReader::kMaxVlqByteLength +
- static_cast<int>(bit_util::BytesForBits(bit_width));
+ int max_repeated_run_size =
+ ::arrow::bit_util::BitReader::kMaxVlqByteLength +
+ static_cast<int>(::arrow::bit_util::BytesForBits(bit_width));
return std::max(max_literal_run_size, max_repeated_run_size);
}
@@ -201,15 +201,16 @@ class RleEncoder {
// and then a repeated run of length 8".
// 8 values per smallest run, 8 bits per byte
int bytes_per_run = bit_width;
- int num_runs = static_cast<int>(bit_util::CeilDiv(num_values, 8));
+ int num_runs = static_cast<int>(::arrow::bit_util::CeilDiv(num_values, 8));
int literal_max_size = num_runs + num_runs * bytes_per_run;
// In the very worst case scenario, the data is a concatenation of repeated
// runs of 8 values. Repeated run has a 1 byte varint followed by the
// bit-packed repeated value
- int min_repeated_run_size = 1 + static_cast<int>(bit_util::BytesForBits(bit_width));
- int repeated_max_size =
- static_cast<int>(bit_util::CeilDiv(num_values, 8)) * min_repeated_run_size;
+ int min_repeated_run_size =
+ 1 + static_cast<int>(::arrow::bit_util::BytesForBits(bit_width));
+ int repeated_max_size = static_cast<int>(::arrow::bit_util::CeilDiv(num_values, 8)) *
+ min_repeated_run_size;
return std::max(literal_max_size, repeated_max_size);
}
@@ -259,7 +260,7 @@ class RleEncoder {
const int bit_width_;
/// Underlying buffer.
- bit_util::BitWriter bit_writer_;
+ ::arrow::bit_util::BitWriter bit_writer_;
/// If true, the buffer is full and subsequent Put()'s will fail.
bool buffer_full_;
@@ -660,8 +661,8 @@ bool RleDecoder::NextCounts() {
}
repeat_count_ = count;
T value = {};
- if (!bit_reader_.GetAligned<T>(static_cast<int>(bit_util::CeilDiv(bit_width_, 8)),
- &value)) {
+ if (!bit_reader_.GetAligned<T>(
+ static_cast<int>(::arrow::bit_util::CeilDiv(bit_width_, 8)), &value)) {
return false;
}
current_value_ = static_cast<uint64_t>(value);
@@ -738,8 +739,8 @@ inline void RleEncoder::FlushRepeatedRun() {
// The lsb of 0 indicates this is a repeated run
int32_t indicator_value = repeat_count_ << 1 | 0;
result &= bit_writer_.PutVlqInt(static_cast<uint32_t>(indicator_value));
- result &= bit_writer_.PutAligned(current_value_,
- static_cast<int>(bit_util::CeilDiv(bit_width_, 8)));
+ result &= bit_writer_.PutAligned(
+ current_value_, static_cast<int>(::arrow::bit_util::CeilDiv(bit_width_, 8)));
DCHECK(result);
num_buffered_values_ = 0;
repeat_count_ = 0;
diff --git a/cpp/src/arrow/util/type_fwd.h b/cpp/src/arrow/util/type_fwd.h
index 976a22bb0b..76e685ffa6 100644
--- a/cpp/src/arrow/util/type_fwd.h
+++ b/cpp/src/arrow/util/type_fwd.h
@@ -34,6 +34,7 @@ namespace internal {
class Executor;
class TaskGroup;
class ThreadPool;
+class CpuInfo;
} // namespace internal