You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/07/26 13:51:43 UTC

[GitHub] [arrow] cyb70289 commented on a change in pull request #10792: ARROW-13295: [C++] add hash_mean, hash_variance, hash_stddev kernels

cyb70289 commented on a change in pull request #10792:
URL: https://github.com/apache/arrow/pull/10792#discussion_r676583689



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -1005,6 +1007,325 @@ struct GroupedSumFactory {
   InputType argument_type;
 };
 
+// ----------------------------------------------------------------------
+// Mean implementation
+
+template <typename Type>
+struct GroupedMeanImpl : public GroupedSumImpl<Type> {
+  Result<Datum> Finalize() override {
+    using SumType = typename GroupedSumImpl<Type>::SumType;
+    std::shared_ptr<Buffer> null_bitmap;
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+                          AllocateBuffer(num_groups_ * sizeof(double), pool_));
+    int64_t null_count = 0;
+
+    const int64_t* counts = reinterpret_cast<const int64_t*>(counts_.data());
+    const auto* sums = reinterpret_cast<const SumType*>(sums_.data());
+    double* means = reinterpret_cast<double*>(values->mutable_data());
+    for (int64_t i = 0; i < num_groups_; ++i) {
+      if (counts[i] > 0) {
+        means[i] = static_cast<double>(sums[i] / counts[i]);

Review comment:
       Should it be `static_cast<double>(sums[i]) / counts[i]`?

##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -1005,6 +1007,325 @@ struct GroupedSumFactory {
   InputType argument_type;
 };
 
+// ----------------------------------------------------------------------
+// Mean implementation
+
+template <typename Type>
+struct GroupedMeanImpl : public GroupedSumImpl<Type> {
+  Result<Datum> Finalize() override {
+    using SumType = typename GroupedSumImpl<Type>::SumType;
+    std::shared_ptr<Buffer> null_bitmap;
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+                          AllocateBuffer(num_groups_ * sizeof(double), pool_));
+    int64_t null_count = 0;
+
+    const int64_t* counts = reinterpret_cast<const int64_t*>(counts_.data());
+    const auto* sums = reinterpret_cast<const SumType*>(sums_.data());
+    double* means = reinterpret_cast<double*>(values->mutable_data());
+    for (int64_t i = 0; i < num_groups_; ++i) {
+      if (counts[i] > 0) {
+        means[i] = static_cast<double>(sums[i] / counts[i]);
+        continue;
+      }
+      means[i] = 0;
+
+      if (null_bitmap == nullptr) {
+        ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+        BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+      }
+
+      null_count += 1;
+      BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+    }
+
+    return ArrayData::Make(float64(), num_groups_,
+                           {std::move(null_bitmap), std::move(values)}, null_count);
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return float64(); }
+
+  using GroupedSumImpl<Type>::num_groups_;
+  using GroupedSumImpl<Type>::pool_;
+  using GroupedSumImpl<Type>::counts_;
+  using GroupedSumImpl<Type>::sums_;
+};
+
+struct GroupedMeanFactory {
+  template <typename T, typename AccType = typename FindAccumulatorType<T>::Type>
+  Status Visit(const T&) {
+    kernel = MakeKernel(std::move(argument_type), HashAggregateInit<GroupedMeanImpl<T>>);
+    return Status::OK();
+  }
+
+  Status Visit(const HalfFloatType& type) {
+    return Status::NotImplemented("Computing mean of type ", type);
+  }
+
+  Status Visit(const DataType& type) {
+    return Status::NotImplemented("Computing mean of type ", type);
+  }
+
+  static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+    GroupedMeanFactory factory;
+    factory.argument_type = InputType::Array(type);
+    RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+    return std::move(factory.kernel);
+  }
+
+  HashAggregateKernel kernel;
+  InputType argument_type;
+};
+
+// Variance/Stdev implementation
+
+using arrow::internal::int128_t;
+
+template <typename Type>
+struct GroupedVarStdImpl : public GroupedAggregator {
+  using CType = typename Type::c_type;
+
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+    options_ = *checked_cast<const VarianceOptions*>(options);
+    ctx_ = ctx;
+    pool_ = ctx->memory_pool();
+    counts_ = BufferBuilder(pool_);
+    means_ = BufferBuilder(pool_);
+    m2s_ = BufferBuilder(pool_);
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0));
+    RETURN_NOT_OK(means_.Append(added_groups * sizeof(double), 0));
+    RETURN_NOT_OK(m2s_.Append(added_groups * sizeof(double), 0));
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override { return ConsumeImpl(batch); }
+
+  // float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm`
+  // (see aggregate_var_std.cc)
+  template <typename T = Type>
+  enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4), Status> ConsumeImpl(
+      const ExecBatch& batch) {
+    using SumType =
+        typename std::conditional<is_floating_type<T>::value, double, int128_t>::type;
+
+    int64_t* counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    double* means = reinterpret_cast<double*>(means_.mutable_data());
+    double* m2s = reinterpret_cast<double*>(m2s_.mutable_data());

Review comment:
       Will these buffers always be 8-byte aligned?
   Not a real issue. Just to satisfy ubsan.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org