Posted to github@arrow.apache.org by "westonpace (via GitHub)" <gi...@apache.org> on 2023/04/10 22:07:37 UTC

[GitHub] [arrow] westonpace commented on a diff in pull request #34912: GH-34911: [C++] [WIP] Add first and last aggregator

westonpace commented on code in PR #34912:
URL: https://github.com/apache/arrow/pull/34912#discussion_r1162119127


##########
cpp/src/arrow/compute/kernels/aggregate_basic_internal.h:
##########
@@ -272,8 +273,120 @@ struct MeanKernelInit : public SumLikeInit<KernelClass> {
 };
 
 // ----------------------------------------------------------------------
-// MinMax implementation
+// FirstLast implementation
+template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void>
+struct FirstLastState {};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> {
+  using ThisType = FirstLastState<ArrowType, SimdLevel>;
+  using T = typename ArrowType::c_type;
+  using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+  ThisType& operator+=(const ThisType& rhs) {
+    this->has_nulls |= rhs.has_nulls;
+    this->first = this->first.has_value() ? this->first : rhs.first;
+    this->last = rhs.last.has_value() ? rhs.last : this->last;
+    return *this;
+  }
+
+  void MergeOne(T value) {
+    if (!this->first.has_value()) {
+      this->first = value;
+    }
+    this->last = value;
+  }
+
+  std::optional<T> first = std::nullopt;
+  std::optional<T> last = std::nullopt;
+  bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastImpl : public ScalarAggregator {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  using ThisType = FirstLastImpl<ArrowType, SimdLevel>;
+  using StateType = FirstLastState<ArrowType, SimdLevel>;
+
+  FirstLastImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options)
+      : out_type(std::move(out_type)), options(std::move(options)), count(0) {
+    this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
+  }
+
+  Status Consume(KernelContext*, const ExecSpan& batch) override {
+    if (batch[0].is_array()) {

Review Comment:
   Since this is a unary function, it should be impossible for this to be false.
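
   A minimal sketch of that simplification (reusing the `FirstLastImpl` members from the diff above; an illustration, not the PR's code) would assert the invariant instead of branching on it:

   ```
   Status Consume(KernelContext*, const ExecSpan& batch) override {
     // Unary scalar aggregate: the single input should always be an array span.
     DCHECK(batch[0].is_array());
     return ConsumeArray(batch[0].array);
   }
   ```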



##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -1251,6 +1251,210 @@ HashAggregateKernel MakeApproximateMedianKernel(HashAggregateFunction* tdigest_f
   return kernel;
 }
 
+// ----------------------------------------------------------------------
+// FirstLast implementation
+
+template <typename CType>
+struct NullSentinel {
+  static constexpr CType value() { return std::numeric_limits<CType>::min(); }
+};
+
+template <>
+struct NullSentinel<float> {
+  static constexpr float value() { return std::numeric_limits<float>::infinity(); }
+};
+
+template <>
+struct NullSentinel<double> {
+  static constexpr double value() { return std::numeric_limits<double>::infinity(); }
+};
+
+template <typename Type, typename Enable = void>
+struct GroupedFirstLastImpl final : public GroupedAggregator {
+  using CType = typename TypeTraits<Type>::CType;
+  using GetSet = GroupedValueTraits<Type>;
+  using ArrType =
+      typename std::conditional<is_boolean_type<Type>::value, uint8_t, CType>::type;
+
+  Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+    options_ = *checked_cast<const ScalarAggregateOptions*>(args.options);
+
+    firsts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+    lasts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+    has_values_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    RETURN_NOT_OK(firsts_.Append(added_groups, NullSentinel<CType>::value()));
+    RETURN_NOT_OK(lasts_.Append(added_groups, NullSentinel<CType>::value()));
+    RETURN_NOT_OK(has_values_.Append(added_groups, false));
+    RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
+    return Status::OK();
+  }
+
+  Status Consume(const ExecSpan& batch) override {
+    auto raw_firsts = firsts_.mutable_data();
+    auto raw_lasts = lasts_.mutable_data();
+    auto raw_has_values = has_values_.mutable_data();
+
+    VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, CType val) {
+          if (!bit_util::GetBit(raw_has_values, g)) {
+            GetSet::Set(raw_firsts, g, val);
+            bit_util::SetBit(raw_has_values, g);
+          }
+          GetSet::Set(raw_lasts, g, val);
+          DCHECK(bit_util::GetBit(has_values_.mutable_data(), g));
+        },
+        [&](uint32_t g) { bit_util::SetBit(has_nulls_.mutable_data(), g); });
+    return Status::OK();
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    // The merge is asymmetric: "first" from this state gets picked over "first" from
+    // the other state, and "last" from the other state gets picked over "last" from
+    // this state. This way, when used with segmented aggregation, we still get the
+    // correct "first" and "last" values for the entire segment.
+    auto other = checked_cast<GroupedFirstLastImpl*>(&raw_other);
+
+    auto raw_firsts = firsts_.mutable_data();
+    auto raw_lasts = lasts_.mutable_data();
+    auto raw_has_values = has_values_.mutable_data();
+    auto raw_has_nulls = has_nulls_.mutable_data();
+
+    auto other_raw_firsts = other->firsts_.mutable_data();
+    auto other_raw_lasts = other->lasts_.mutable_data();
+    auto other_raw_has_values = other->has_values_.mutable_data();
+    auto other_raw_has_nulls = other->has_nulls_.mutable_data();
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < group_id_mapping.length;
+         ++other_g, ++g) {
+      if (!bit_util::GetBit(raw_has_values, *g)) {
+        if (bit_util::GetBit(other_raw_has_values, other_g)) {
+          GetSet::Set(raw_firsts, *g, GetSet::Get(other_raw_firsts, other_g));
+        }
+      }
+
+      if (bit_util::GetBit(other_raw_has_values, other_g)) {
+        GetSet::Set(raw_lasts, *g, GetSet::Get(other_raw_lasts, other_g));
+      }
+
+      if (bit_util::GetBit(other_raw_has_values, other_g)) {
+        bit_util::SetBit(raw_has_values, *g);
+      }
+      if (bit_util::GetBit(other_raw_has_nulls, other_g)) {
+        bit_util::SetBit(raw_has_nulls, *g);
+      }
+    }
+    return Status::OK();
+  }
+
+  Result<Datum> Finalize() override {
+    ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
+
+    if (!options_.skip_nulls) {
+      return Status::NotImplemented("First/last does not support skip_nulls=false");

Review Comment:
   You don't actually have to compare values, do you?  This seems like it would be pretty straightforward.
   
   `first([null, 5, 7])` is `null`.
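
   A self-contained sketch of those semantics, matching the example above (illustrative only; `FirstValue` is a hypothetical helper, not part of the PR or of Arrow):

   ```
   #include <optional>
   #include <vector>

   // With skip_nulls=true, nulls are ignored; with skip_nulls=false, the first
   // element decides the result even if it is null, so first([null, 5, 7]) is null.
   template <typename T>
   std::optional<T> FirstValue(const std::vector<std::optional<T>>& values,
                               bool skip_nulls) {
     for (const auto& v : values) {
       if (v.has_value()) return v;           // first non-null value
       if (!skip_nulls) return std::nullopt;  // a leading null decides the result
     }
     return std::nullopt;  // empty input or all nulls
   }

   // FirstValue<int>({std::nullopt, 5, 7}, /*skip_nulls=*/false) -> nullopt
   // FirstValue<int>({std::nullopt, 5, 7}, /*skip_nulls=*/true)  -> 5
   ```

   So no value comparisons are needed, only tracking whether a value or a null was seen first.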



##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -1251,6 +1251,210 @@ HashAggregateKernel MakeApproximateMedianKernel(HashAggregateFunction* tdigest_f
   return kernel;
 }
 
+// ----------------------------------------------------------------------
+// FirstLast implementation
+
+template <typename CType>
+struct NullSentinel {
+  static constexpr CType value() { return std::numeric_limits<CType>::min(); }
+};
+
+template <>
+struct NullSentinel<float> {
+  static constexpr float value() { return std::numeric_limits<float>::infinity(); }
+};
+
+template <>
+struct NullSentinel<double> {
+  static constexpr double value() { return std::numeric_limits<double>::infinity(); }
+};

Review Comment:
   Maybe `UninitializedSentinel`?



##########
cpp/src/arrow/compute/kernels/aggregate_basic_internal.h:
##########
@@ -272,8 +273,120 @@ struct MeanKernelInit : public SumLikeInit<KernelClass> {
 };
 
 // ----------------------------------------------------------------------
-// MinMax implementation
+// FirstLast implementation
+template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable = void>
+struct FirstLastState {};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastState<ArrowType, SimdLevel, enable_if_floating_point<ArrowType>> {
+  using ThisType = FirstLastState<ArrowType, SimdLevel>;
+  using T = typename ArrowType::c_type;
+  using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+  ThisType& operator+=(const ThisType& rhs) {
+    this->has_nulls |= rhs.has_nulls;
+    this->first = this->first.has_value() ? this->first : rhs.first;
+    this->last = rhs.last.has_value() ? rhs.last : this->last;
+    return *this;
+  }
+
+  void MergeOne(T value) {
+    if (!this->first.has_value()) {
+      this->first = value;
+    }
+    this->last = value;
+  }
+
+  std::optional<T> first = std::nullopt;
+  std::optional<T> last = std::nullopt;
+  bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastImpl : public ScalarAggregator {
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  using ThisType = FirstLastImpl<ArrowType, SimdLevel>;
+  using StateType = FirstLastState<ArrowType, SimdLevel>;
+
+  FirstLastImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions options)
+      : out_type(std::move(out_type)), options(std::move(options)), count(0) {
+    this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
+  }
+
+  Status Consume(KernelContext*, const ExecSpan& batch) override {
+    if (batch[0].is_array()) {
+      return ConsumeArray(batch[0].array);
+    }
+    return ConsumeScalar(*batch[0].scalar);
+  }
+
+  Status ConsumeScalar(const Scalar& scalar) {
+    return Status::NotImplemented("Consume scalar");
+  }
+
+  Status ConsumeArray(const ArraySpan& arr_span) {
+    StateType local;
+
+    ArrayType arr(arr_span.ToArrayData());
+    const auto null_count = arr.null_count();
+    local.has_nulls = null_count > 0;
+    this->count += arr.length() - null_count;
+
+    if (!local.has_nulls) {
+      for (int64_t i = 0; i < arr.length(); i++) {
+        local.MergeOne(arr.GetView(i));
+      }

Review Comment:
   Wouldn't you break as soon as you encounter a value?  Why do you need to iterate the entire array?  I suppose if you want both first AND last then you might need to iterate from both directions.  Something like...
   
   ```
   if (!has_first) {
     int index = 0;
     while (index < length) {
       if (arr[index] != null || !skip_nulls) {
         has_first = true;
         first = arr[index];
         break;
       } else {
         index++;
       }
     }
   }
   // No need to check has_last here since we always assume the current batch is replacing the last
   int index = length - 1;
   while (index >= 0) {
     if (arr[index] != null || !skip_nulls) {
       last = arr[index];
       break;
     } else {
       index--;
     }
   }  
   ```
   
   Also, it appears that `last` carries quite a bit more cost than `first`.  Imagine you were searching for `first` with `skip_nulls=false`.  All you need to do is look at one value and you can skip all future batches.
   
   Given this, I'm not sure we want to combine first/last into a single kernel.  Or at least, we should make it possible in some way to skip data if `last` isn't needed.
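
   Translating that idea into a self-contained sketch (illustrative only; `ScanFirstLast`, `FirstLastIndices`, and the `std::vector<bool>` validity input are hypothetical stand-ins, not the PR's types):

   ```
   #include <cstdint>
   #include <optional>
   #include <vector>

   struct FirstLastIndices {
     std::optional<int64_t> first;  // position that decides "first", if any
     std::optional<int64_t> last;   // position that decides "last", if any
   };

   FirstLastIndices ScanFirstLast(const std::vector<bool>& valid, bool skip_nulls,
                                  bool already_has_first) {
     FirstLastIndices out;
     const int64_t length = static_cast<int64_t>(valid.size());
     if (!already_has_first) {
       for (int64_t i = 0; i < length; ++i) {
         // With skip_nulls=false a null also decides "first", so we can stop early.
         if (valid[i] || !skip_nulls) {
           out.first = i;
           break;
         }
       }
     }
     // No "has_last" check: the current batch always replaces "last".
     for (int64_t i = length - 1; i >= 0; --i) {
       if (valid[i] || !skip_nulls) {
         out.last = i;
         break;
       }
     }
     return out;
   }
   ```

   If `last` isn't needed, the backward scan (and, once `first` is found, all remaining batches) could be skipped entirely, which is the cost asymmetry mentioned above.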



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org