You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/07/12 19:49:19 UTC

[GitHub] [arrow] lidavidm commented on a change in pull request #10660: ARROW-12759: [C++][Compute] Add ExecNode for group by

lidavidm commented on a change in pull request #10660:
URL: https://github.com/apache/arrow/pull/10660#discussion_r667957128



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -748,67 +752,128 @@ struct GrouperFastImpl : Grouper {
 /// Implementations should be default constructible and perform initialization in
 /// Init().
 struct GroupedAggregator : KernelState {
-  virtual Status Init(ExecContext*, const FunctionOptions*,
-                      const std::shared_ptr<DataType>&) = 0;
+  virtual Status Init(ExecContext*, const FunctionOptions*) = 0;
+
+  virtual Status Resize(int64_t new_num_groups) = 0;
 
   virtual Status Consume(const ExecBatch& batch) = 0;
 
-  virtual Result<Datum> Finalize() = 0;
+  virtual Status Merge(GroupedAggregator&& other, const ArrayData& group_id_mapping) = 0;
 
-  template <typename Reserve>
-  Status MaybeReserve(int64_t old_num_groups, const ExecBatch& batch,
-                      const Reserve& reserve) {
-    int64_t new_num_groups = batch[2].scalar_as<UInt32Scalar>().value;
-    if (new_num_groups <= old_num_groups) {
-      return Status::OK();
-    }
-    return reserve(new_num_groups - old_num_groups);
-  }
+  virtual Result<Datum> Finalize() = 0;
 
   virtual std::shared_ptr<DataType> out_type() const = 0;
 };
 
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
+                                                       const KernelInitArgs& args) {
+  auto impl = ::arrow::internal::make_unique<Impl>();
+  RETURN_NOT_OK(impl->Init(ctx->exec_context(), args.options));
+  return std::move(impl);
+}
+
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+  HashAggregateKernel kernel;
+
+  kernel.init = std::move(init);
+
+  kernel.signature = KernelSignature::Make(
+      {std::move(argument_type), InputType::Array(Type::UINT32)},
+      OutputType(
+          [](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
+            return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
+          }));
+
+  kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
+  };
+
+  kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
+  };
+
+  kernel.merge = [](KernelContext* ctx, KernelState&& other,
+                    const ArrayData& group_id_mapping) {
+    return checked_cast<GroupedAggregator*>(ctx->state())
+        ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
+  };
+
+  kernel.finalize = [](KernelContext* ctx, Datum* out) {
+    ARROW_ASSIGN_OR_RAISE(*out,
+                          checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+    return Status::OK();
+  };
+
+  return kernel;
+}
+
+Status AddHashAggKernels(
+    const std::vector<std::shared_ptr<DataType>>& types,
+    Result<HashAggregateKernel> make_kernel(const std::shared_ptr<DataType>&),
+    HashAggregateFunction* function) {
+  for (const auto& ty : types) {
+    ARROW_ASSIGN_OR_RAISE(auto kernel, make_kernel(ty));
+    RETURN_NOT_OK(function->AddKernel(std::move(kernel)));
+  }
+  return Status::OK();
+}
+
 // ----------------------------------------------------------------------
 // Count implementation
 
 struct GroupedCountImpl : public GroupedAggregator {
-  Status Init(ExecContext* ctx, const FunctionOptions* options,
-              const std::shared_ptr<DataType>&) override {
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     options_ = checked_cast<const ScalarAggregateOptions&>(*options);
     counts_ = BufferBuilder(ctx->memory_pool());
     return Status::OK();
   }
 
-  Status Consume(const ExecBatch& batch) override {
-    RETURN_NOT_OK(MaybeReserve(num_groups_, batch, [&](int64_t added_groups) {
-      num_groups_ += added_groups;
-      return counts_.Append(added_groups * sizeof(int64_t), 0);
-    }));
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return counts_.Append(added_groups * sizeof(int64_t), 0);
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedCountImpl*>(&raw_other);
 
-    auto group_ids = batch[1].array()->GetValues<uint32_t>(1);
-    auto raw_counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto other_counts = reinterpret_cast<const int64_t*>(other->counts_.mutable_data());
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      counts[*g] += other_counts[other_g];
+    }
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
 
     const auto& input = batch[0].array();
 
-    if (!options_.skip_nulls) {
-      if (input->GetNullCount() != 0) {
-        for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
-          auto g = group_ids[i];
-          raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
-        }
+    if (options_.skip_nulls) {
+      auto g_begin =
+          reinterpret_cast<const uint32_t*>(batch[1].array()->buffers[1]->data());
+
+      arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
+                                           input->length,
+                                           [&](int64_t offset, int64_t length) {
+                                             auto g = g_begin + offset;
+                                             for (int64_t i = 0; i < length; ++i, ++g) {
+                                               counts[*g] += 1;
+                                             }
+                                           });
+    } else if (input->MayHaveNulls()) {

Review comment:
       Doesn't this fail to count anything when `!skip_nulls && !input->MayHaveNulls()`, since neither branch is taken in that case?

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -748,67 +752,128 @@ struct GrouperFastImpl : Grouper {
 /// Implementations should be default constructible and perform initialization in
 /// Init().
 struct GroupedAggregator : KernelState {
-  virtual Status Init(ExecContext*, const FunctionOptions*,
-                      const std::shared_ptr<DataType>&) = 0;
+  virtual Status Init(ExecContext*, const FunctionOptions*) = 0;
+
+  virtual Status Resize(int64_t new_num_groups) = 0;
 
   virtual Status Consume(const ExecBatch& batch) = 0;
 
-  virtual Result<Datum> Finalize() = 0;
+  virtual Status Merge(GroupedAggregator&& other, const ArrayData& group_id_mapping) = 0;
 
-  template <typename Reserve>
-  Status MaybeReserve(int64_t old_num_groups, const ExecBatch& batch,
-                      const Reserve& reserve) {
-    int64_t new_num_groups = batch[2].scalar_as<UInt32Scalar>().value;
-    if (new_num_groups <= old_num_groups) {
-      return Status::OK();
-    }
-    return reserve(new_num_groups - old_num_groups);
-  }
+  virtual Result<Datum> Finalize() = 0;
 
   virtual std::shared_ptr<DataType> out_type() const = 0;
 };
 
+template <typename Impl>
+Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
+                                                       const KernelInitArgs& args) {
+  auto impl = ::arrow::internal::make_unique<Impl>();
+  RETURN_NOT_OK(impl->Init(ctx->exec_context(), args.options));
+  return std::move(impl);
+}
+
+HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
+  HashAggregateKernel kernel;
+
+  kernel.init = std::move(init);
+
+  kernel.signature = KernelSignature::Make(
+      {std::move(argument_type), InputType::Array(Type::UINT32)},
+      OutputType(
+          [](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
+            return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
+          }));
+
+  kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
+  };
+
+  kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
+    return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
+  };
+
+  kernel.merge = [](KernelContext* ctx, KernelState&& other,
+                    const ArrayData& group_id_mapping) {
+    return checked_cast<GroupedAggregator*>(ctx->state())
+        ->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
+  };
+
+  kernel.finalize = [](KernelContext* ctx, Datum* out) {
+    ARROW_ASSIGN_OR_RAISE(*out,
+                          checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+    return Status::OK();
+  };
+
+  return kernel;
+}
+
+Status AddHashAggKernels(
+    const std::vector<std::shared_ptr<DataType>>& types,
+    Result<HashAggregateKernel> make_kernel(const std::shared_ptr<DataType>&),
+    HashAggregateFunction* function) {
+  for (const auto& ty : types) {
+    ARROW_ASSIGN_OR_RAISE(auto kernel, make_kernel(ty));
+    RETURN_NOT_OK(function->AddKernel(std::move(kernel)));
+  }
+  return Status::OK();
+}
+
 // ----------------------------------------------------------------------
 // Count implementation
 
 struct GroupedCountImpl : public GroupedAggregator {
-  Status Init(ExecContext* ctx, const FunctionOptions* options,
-              const std::shared_ptr<DataType>&) override {
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     options_ = checked_cast<const ScalarAggregateOptions&>(*options);
     counts_ = BufferBuilder(ctx->memory_pool());
     return Status::OK();
   }
 
-  Status Consume(const ExecBatch& batch) override {
-    RETURN_NOT_OK(MaybeReserve(num_groups_, batch, [&](int64_t added_groups) {
-      num_groups_ += added_groups;
-      return counts_.Append(added_groups * sizeof(int64_t), 0);
-    }));
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    return counts_.Append(added_groups * sizeof(int64_t), 0);
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    auto other = checked_cast<GroupedCountImpl*>(&raw_other);
 
-    auto group_ids = batch[1].array()->GetValues<uint32_t>(1);
-    auto raw_counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto other_counts = reinterpret_cast<const int64_t*>(other->counts_.mutable_data());
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      counts[*g] += other_counts[other_g];
+    }
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override {
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
 
     const auto& input = batch[0].array();
 
-    if (!options_.skip_nulls) {
-      if (input->GetNullCount() != 0) {
-        for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
-          auto g = group_ids[i];
-          raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
-        }
+    if (options_.skip_nulls) {
+      auto g_begin =
+          reinterpret_cast<const uint32_t*>(batch[1].array()->buffers[1]->data());
+
+      arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
+                                           input->length,
+                                           [&](int64_t offset, int64_t length) {
+                                             auto g = g_begin + offset;
+                                             for (int64_t i = 0; i < length; ++i, ++g) {
+                                               counts[*g] += 1;
+                                             }
+                                           });
+    } else if (input->MayHaveNulls()) {

Review comment:
       Ah, sorry, I misunderstood what skip_nulls means. Thanks.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org