You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/07/12 20:09:38 UTC

[GitHub] [arrow] bkietz commented on a change in pull request #10705: ARROW-13313: [C++][Compute] Add scalar aggregate node

bkietz commented on a change in pull request #10705:
URL: https://github.com/apache/arrow/pull/10705#discussion_r668221912



##########
File path: cpp/src/arrow/compute/exec/exec_plan.cc
##########
@@ -601,5 +618,215 @@ AsyncGenerator<util::optional<ExecBatch>> MakeSinkNode(ExecNode* input,
   return out;
 }
 
+// Exposes the output of `input` as a RecordBatchReader.
+//
+// A sink is attached to `input` with MakeSinkNode(input, label) and the generator it
+// returns is adapted with MakeGeneratorIterator. Every ExecBatch pulled from that
+// iterator is converted with ExecBatch::ToRecordBatch, using the output schema of
+// `input` and the memory pool of the plan's ExecContext.
+std::shared_ptr<RecordBatchReader> MakeSinkNodeReader(ExecNode* input,
+                                                      std::string label) {
+  struct SinkReader : RecordBatchReader {
+    std::shared_ptr<Schema> schema() const override { return schema_; }
+
+    // Stores the next converted batch in *record_batch, or the IterationEnd value
+    // once the iterator yields an empty optional. Errors raised by the iterator or
+    // by the conversion are returned to the caller.
+    Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
+      ARROW_ASSIGN_OR_RAISE(auto maybe_batch, iterator_.Next());
+      if (!maybe_batch) {
+        *record_batch = IterationEnd<std::shared_ptr<RecordBatch>>();
+        return Status::OK();
+      }
+      ARROW_ASSIGN_OR_RAISE(*record_batch, maybe_batch->ToRecordBatch(schema_, pool_));
+      return Status::OK();
+    }
+
+    MemoryPool* pool_;
+    std::shared_ptr<Schema> schema_;
+    Iterator<util::optional<ExecBatch>> iterator_;
+  };
+
+  auto reader = std::make_shared<SinkReader>();
+  reader->pool_ = input->plan()->exec_context()->memory_pool();
+  reader->schema_ = input->output_schema();
+  reader->iterator_ = MakeGeneratorIterator(MakeSinkNode(input, std::move(label)));
+  return reader;
+}
+
+// ExecNode which reduces all batches of its single input to one output batch.
+//
+// kernels_[i] is applied to column i of every input batch. Each distinct thread which
+// delivers batches gets its own row of states_ (states_[thread][kernel]), so batches
+// are consumed without holding mutex_. Once the number of consumed batches equals the
+// total reported by InputFinished, the per-thread states are merged into states_[0],
+// finalized, and emitted downstream as a single batch of length 1.
+struct ScalarAggregateNode : ExecNode {
+  // input:         the node whose batches are aggregated
+  // label:         forwarded to ExecNode
+  // output_schema: forwarded to ExecNode
+  // kernels:       kernels[i] consumes column i of each input batch
+  // states:        states[t][i] is the state used by kernels[i] on the t-th distinct
+  //                thread which calls InputReceived
+  ScalarAggregateNode(ExecNode* input, std::string label,
+                      std::shared_ptr<Schema> output_schema,
+                      std::vector<const ScalarAggregateKernel*> kernels,
+                      std::vector<std::vector<std::unique_ptr<KernelState>>> states)
+      : ExecNode(input->plan(), std::move(label), {input}, {"target"},
+                 /*output_schema=*/std::move(output_schema),
+                 /*num_outputs=*/1),
+        kernels_(std::move(kernels)),
+        states_(std::move(states)) {}
+
+  const char* kind_name() override { return "ScalarAggregateNode"; }
+
+  // Feeds column i of `batch` to kernels_[i] using states[i], for every i.
+  // Returns the first kernel error, if any.
+  Status DoConsume(const ExecBatch& batch,
+                   const std::vector<std::unique_ptr<KernelState>>& states) {
+    for (size_t i = 0; i < states.size(); ++i) {
+      KernelContext batch_ctx{plan()->exec_context()};
+      batch_ctx.SetState(states[i].get());
+      ExecBatch single_column_batch{{batch.values[i]}, batch.length};
+      RETURN_NOT_OK(kernels_[i]->consume(&batch_ctx, single_column_batch));
+    }
+    return Status::OK();
+  }
+
+  // Consumes `batch` into the calling thread's states, then finishes the node if this
+  // was the last outstanding batch. Errors are forwarded to the output.
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+
+    std::unique_lock<std::mutex> lock(mutex_);
+    const auto thread_id = std::this_thread::get_id();
+    auto it = thread_indices_.find(thread_id);
+    if (it == thread_indices_.end()) {
+      // Only register a new thread if a row of states exists for it; otherwise both
+      // the lookup below and the merge loop in MaybeFinish would index past states_.
+      if (thread_indices_.size() >= states_.size()) {
+        const size_t num_states = states_.size();
+        lock.unlock();
+        outputs_[0]->ErrorReceived(
+            this, Status::Invalid("ScalarAggregateNode received batches from more than ",
+                                  num_states, " distinct threads"));
+        return;
+      }
+      it = thread_indices_.emplace(thread_id, thread_indices_.size()).first;
+    }
+    const size_t thread_index = it->second;
+
+    // The states at thread_index are only touched by this thread until the merge in
+    // MaybeFinish, so the kernels can run without the lock.
+    lock.unlock();
+
+    Status st = DoConsume(batch, states_[thread_index]);
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+      return;
+    }
+
+    lock.lock();
+    // Count the batch only after it has been consumed. Counting it earlier lets
+    // another thread observe num_received_ == num_total_ and merge/finalize the states
+    // while this thread is still writing to them.
+    ++num_received_;
+    st = MaybeFinish(&lock);
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+    }
+  }
+
+  void ErrorReceived(ExecNode* input, Status error) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->ErrorReceived(this, std::move(error));
+  }
+
+  // Records `seq` as the number of batches to expect and finishes the node if all of
+  // them have already been consumed.
+  void InputFinished(ExecNode* input, int seq) override {
+    DCHECK_EQ(input, inputs_[0]);
+    std::unique_lock<std::mutex> lock(mutex_);
+    num_total_ = seq;
+    Status st = MaybeFinish(&lock);
+
+    if (!st.ok()) {
+      outputs_[0]->ErrorReceived(this, std::move(st));
+    }
+  }
+
+  Status StartProducing() override {
+    finished_ = Future<>::Make();
+    // Scalar aggregates will only output a single batch
+    outputs_[0]->InputFinished(this, 1);
+    return Status::OK();
+  }
+
+  void PauseProducing(ExecNode* output) override {}
+
+  void ResumeProducing(ExecNode* output) override {}
+
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    StopProducing();
+  }
+
+  // NOTE(review): finished_ may already have been marked finished here (by MaybeFinish,
+  // or because StartProducing was never called) - confirm that marking a Future
+  // finished twice is permitted.
+  void StopProducing() override {
+    inputs_[0]->StopProducing(this);
+    finished_.MarkFinished();
+  }
+
+  Future<> finished() override { return finished_; }
+
+ private:
+  // Must be called with `lock` held on mutex_. Does nothing unless every expected
+  // batch has been consumed and the node has not finished yet. Otherwise merges the
+  // states of all registered threads into states_[0], finalizes each kernel, releases
+  // `lock`, pushes the resulting batch downstream and marks finished_.
+  // If a merge or finalize fails, the error is returned with `lock` still held.
+  Status MaybeFinish(std::unique_lock<std::mutex>* lock) {
+    if (num_received_ != num_total_) return Status::OK();
+
+    if (finished_.is_finished()) return Status::OK();
+
+    ExecBatch batch{{}, 1};
+    batch.values.resize(kernels_.size());
+
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      KernelContext ctx{plan()->exec_context()};
+      ctx.SetState(states_[0][i].get());
+
+      for (size_t thread_index = 1; thread_index < thread_indices_.size();
+           ++thread_index) {
+        RETURN_NOT_OK(
+            kernels_[i]->merge(&ctx, std::move(*states_[thread_index][i]), ctx.state()));
+      }
+      RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
+    }
+    lock->unlock();
+
+    // `batch` is not used again, so hand it over instead of copying it.
+    outputs_[0]->InputReceived(this, 0, std::move(batch));
+
+    finished_.MarkFinished();
+    return Status::OK();
+  }
+
+  Future<> finished_ = Future<>::MakeFinished();
+  std::vector<const ScalarAggregateKernel*> kernels_;
+  std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
+  // Maps each thread which has called InputReceived to its row in states_.
+  std::unordered_map<std::thread::id, size_t> thread_indices_;
+  // Guards thread_indices_, num_received_ and num_total_.
+  std::mutex mutex_;
+  // Number of batches which have been fully consumed.
+  int num_received_ = 0;
+  // Number of batches to expect; -1 until InputFinished supplies it, so it can never
+  // compare equal to num_received_ before then.
+  int num_total_ = -1;
+};
+
+Result<ExecNode*> MakeScalarAggregateNode(ExecNode* input, std::string label,
+                                          std::vector<internal::Aggregate> aggregates) {
+  if (input->output_schema()->num_fields() != static_cast<int>(aggregates.size())) {
+    return Status::Invalid("Provided ", aggregates.size(),

Review comment:
       That'd be a ProjectNode whose expressions repeat a field reference, for example:
   
   ```c++
     ASSERT_OK_AND_ASSIGN(auto projection,
                          MakeProjectNode(source, "project", {field_ref("i32"), field_ref("i32")}));
   
     ASSERT_OK_AND_ASSIGN(auto scalar_agg,
                          MakeScalarAggregateNode(projection, "scalar_agg",
                                                  {{"sum", nullptr}, {"mean", nullptr}}));
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org