You are viewing a plain-text version of this content; the canonical link was a hyperlink on the original archive page and is not preserved here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/02/28 01:34:55 UTC
[arrow] branch master updated: ARROW-3121: [C++] Mean aggregate kernel
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 29aa925 ARROW-3121: [C++] Mean aggregate kernel
29aa925 is described below
commit 29aa925683870d2644bb6334610d164aaeef6d10
Author: François Saint-Jacques <fs...@gmail.com>
AuthorDate: Wed Feb 27 19:34:45 2019 -0600
ARROW-3121: [C++] Mean aggregate kernel
Implements the mean (average) aggregate kernel for numeric columns. The output type is always double. Refactored the Sum kernel implementation to share common parts; notably, the Consume part is identical, and only Finalize and the output type differ.
Author: François Saint-Jacques <fs...@gmail.com>
Closes #3708 from fsaintjacques/ARROW-3121-mean-aggregate and squashes the following commits:
0d39c1f4 <François Saint-Jacques> reformat
d41db22e <François Saint-Jacques> Refactor with ternary
d1191fd8 <François Saint-Jacques> Deal with NaN values in sum
c448bcbc <François Saint-Jacques> Add documentation per review
79291402 <François Saint-Jacques> Implement mean aggregate
3a1a0cd8 <François Saint-Jacques> Refactor sum implementation
8bc293f1 <François Saint-Jacques> Move TypeTraits into sum-internal.h
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/compute/kernels/aggregate-test.cc | 191 +++++++++++++--------
cpp/src/arrow/compute/kernels/mean.cc | 115 +++++++++++++
cpp/src/arrow/compute/kernels/{sum.h => mean.h} | 48 ++----
.../compute/kernels/{sum.cc => sum-internal.h} | 154 +++++------------
cpp/src/arrow/compute/kernels/sum.cc | 181 ++-----------------
cpp/src/arrow/compute/kernels/sum.h | 38 +---
cpp/src/arrow/compute/test-util.h | 35 ++++
cpp/src/arrow/util/bit-util-test.cc | 9 +
cpp/src/arrow/util/bit-util.h | 15 ++
10 files changed, 372 insertions(+), 415 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 94eba0c..3c6a399 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -147,6 +147,7 @@ if(ARROW_COMPUTE)
compute/kernels/boolean.cc
compute/kernels/cast.cc
compute/kernels/hash.cc
+ compute/kernels/mean.cc
compute/kernels/sum.cc
compute/kernels/util-internal.cc)
endif()
diff --git a/cpp/src/arrow/compute/kernels/aggregate-test.cc b/cpp/src/arrow/compute/kernels/aggregate-test.cc
index ca44744..bdf50f5 100644
--- a/cpp/src/arrow/compute/kernels/aggregate-test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate-test.cc
@@ -15,13 +15,18 @@
// specific language governing permissions and limitations
// under the License.
+#include <algorithm>
+#include <memory>
#include <string>
#include <type_traits>
+#include <utility>
#include <gtest/gtest.h>
#include "arrow/array.h"
#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/mean.h"
+#include "arrow/compute/kernels/sum-internal.h"
#include "arrow/compute/kernels/sum.h"
#include "arrow/compute/test-util.h"
#include "arrow/type.h"
@@ -38,63 +43,16 @@ using std::vector;
namespace arrow {
namespace compute {
-template <typename Type, typename Enable = void>
-struct DatumEqual {
- static void EnsureEqual(const Datum& lhs, const Datum& rhs) {}
-};
-
-template <typename Type>
-struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::Value>::type> {
- static constexpr double kArbitraryDoubleErrorBound = 1.0;
- using ScalarType = typename TypeTraits<Type>::ScalarType;
-
- static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
- ASSERT_EQ(lhs.kind(), rhs.kind());
- if (lhs.kind() == Datum::SCALAR) {
- auto left = static_cast<const ScalarType*>(lhs.scalar().get());
- auto right = static_cast<const ScalarType*>(rhs.scalar().get());
- ASSERT_EQ(left->type->id(), right->type->id());
- ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
- }
- }
-};
-
-template <typename Type>
-struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
- using ScalarType = typename TypeTraits<Type>::ScalarType;
- static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
- ASSERT_EQ(lhs.kind(), rhs.kind());
- if (lhs.kind() == Datum::SCALAR) {
- auto left = static_cast<const ScalarType*>(lhs.scalar().get());
- auto right = static_cast<const ScalarType*>(rhs.scalar().get());
- ASSERT_EQ(left->type->id(), right->type->id());
- ASSERT_EQ(left->value, right->value);
- }
- }
-};
-
template <typename ArrowType>
-void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
- using OutputType = typename FindAccumulatorType<ArrowType>::Type;
- Datum result;
- ASSERT_OK(Sum(ctx, input, &result));
- DatumEqual<OutputType>::EnsureEqual(result, expected);
-}
+using SumResult =
+ std::pair<typename FindAccumulatorType<ArrowType>::Type::c_type, size_t>;
template <typename ArrowType>
-void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
- auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
- ValidateSum<ArrowType>(ctx, *array, expected);
-}
-
-template <typename ArrowType>
-static Datum DummySum(const Array& array) {
+static SumResult<ArrowType> NaiveSumPartial(const Array& array) {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
- using SumType = typename FindAccumulatorType<ArrowType>::Type;
- using SumScalarType = typename TypeTraits<SumType>::ScalarType;
+ using ResultType = SumResult<ArrowType>;
- typename SumType::c_type sum = 0;
- int64_t count = 0;
+ ResultType result;
auto data = array.data();
internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
@@ -102,30 +60,52 @@ static Datum DummySum(const Array& array) {
const auto values = array_numeric.raw_values();
for (int64_t i = 0; i < array.length(); i++) {
if (reader.IsSet()) {
- sum += values[i];
- count++;
+ result.first += values[i];
+ result.second++;
}
reader.Next();
}
- if (count > 0) {
- return Datum(std::make_shared<SumScalarType>(sum));
- } else {
- return Datum(std::make_shared<SumScalarType>(0, false));
- }
+ return result;
+}
+
+template <typename ArrowType>
+static Datum NaiveSum(const Array& array) {
+ using SumType = typename FindAccumulatorType<ArrowType>::Type;
+ using SumScalarType = typename TypeTraits<SumType>::ScalarType;
+
+ auto result = NaiveSumPartial<ArrowType>(array);
+ bool is_valid = result.second > 0;
+
+ return Datum(std::make_shared<SumScalarType>(result.first, is_valid));
+}
+
+template <typename ArrowType>
+void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
+ using OutputType = typename FindAccumulatorType<ArrowType>::Type;
+
+ Datum result;
+ ASSERT_OK(Sum(ctx, input, &result));
+ DatumEqual<OutputType>::EnsureEqual(result, expected);
+}
+
+template <typename ArrowType>
+void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
+ auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateSum<ArrowType>(ctx, *array, expected);
}
template <typename ArrowType>
void ValidateSum(FunctionContext* ctx, const Array& array) {
- ValidateSum<ArrowType>(ctx, array, DummySum<ArrowType>(array));
+ ValidateSum<ArrowType>(ctx, array, NaiveSum<ArrowType>(array));
}
template <typename ArrowType>
-class TestSumKernelNumeric : public ComputeFixture, public TestBase {};
+class TestNumericSumKernel : public ComputeFixture, public TestBase {};
-TYPED_TEST_CASE(TestSumKernelNumeric, NumericArrowTypes);
-TYPED_TEST(TestSumKernelNumeric, SimpleSum) {
+TYPED_TEST_CASE(TestNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericSumKernel, SimpleSum) {
using SumType = typename FindAccumulatorType<TypeParam>::Type;
using ScalarType = typename TypeTraits<SumType>::ScalarType;
using T = typename TypeParam::c_type;
@@ -145,10 +125,10 @@ TYPED_TEST(TestSumKernelNumeric, SimpleSum) {
}
template <typename ArrowType>
-class TestRandomSumKernelNumeric : public ComputeFixture, public TestBase {};
+class TestRandomNumericSumKernel : public ComputeFixture, public TestBase {};
-TYPED_TEST_CASE(TestRandomSumKernelNumeric, NumericArrowTypes);
-TYPED_TEST(TestRandomSumKernelNumeric, RandomArraySum) {
+TYPED_TEST_CASE(TestRandomNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericSumKernel, RandomArraySum) {
auto rand = random::RandomArrayGenerator(0x5487655);
for (size_t i = 3; i < 14; i++) {
for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
@@ -161,7 +141,7 @@ TYPED_TEST(TestRandomSumKernelNumeric, RandomArraySum) {
}
}
-TYPED_TEST(TestRandomSumKernelNumeric, RandomSliceArraySum) {
+TYPED_TEST(TestRandomNumericSumKernel, RandomSliceArraySum) {
auto arithmetic = ArrayFromJSON(TypeTraits<TypeParam>::type_singleton(),
"[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]");
ValidateSum<TypeParam>(&this->ctx_, *arithmetic);
@@ -175,12 +155,87 @@ TYPED_TEST(TestRandomSumKernelNumeric, RandomSliceArraySum) {
const int64_t length = 1U << 6;
auto array = rand.Numeric<TypeParam>(length, 0, 10, 0.5);
for (size_t i = 1; i < 16; i++) {
- for (size_t j = 1; i < 16; i++) {
+ for (size_t j = 1; j < 16; j++) {
auto slice = array->Slice(i, length - j);
ValidateSum<TypeParam>(&this->ctx_, *slice);
}
}
}
+template <typename ArrowType>
+static Datum NaiveMean(const Array& array) {
+ using MeanScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ const auto result = NaiveSumPartial<ArrowType>(array);
+ const double mean = static_cast<double>(result.first) /
+ static_cast<double>(result.second ? result.second : 1UL);
+ const bool is_valid = result.second > 0;
+
+ return Datum(std::make_shared<MeanScalarType>(mean, is_valid));
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const Array& input, Datum expected) {
+ using OutputType = typename FindAccumulatorType<DoubleType>::Type;
+
+ Datum result;
+ ASSERT_OK(Mean(ctx, input, &result));
+ DatumEqual<OutputType>::EnsureEqual(result, expected);
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const char* json, Datum expected) {
+ auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+ ValidateMean<ArrowType>(ctx, *array, expected);
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const Array& array) {
+ ValidateMean<ArrowType>(ctx, array, NaiveMean<ArrowType>(array));
+}
+
+template <typename ArrowType>
+class TestMeanKernelNumeric : public ComputeFixture, public TestBase {};
+
+TYPED_TEST_CASE(TestMeanKernelNumeric, NumericArrowTypes);
+TYPED_TEST(TestMeanKernelNumeric, SimpleMean) {
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ ValidateMean<TypeParam>(&this->ctx_, "[]",
+ Datum(std::make_shared<ScalarType>(0.0, false)));
+
+ ValidateMean<TypeParam>(&this->ctx_, "[null]",
+ Datum(std::make_shared<ScalarType>(0.0, false)));
+
+ ValidateMean<TypeParam>(&this->ctx_, "[1, null, 1]",
+ Datum(std::make_shared<ScalarType>(1.0)));
+
+ ValidateMean<TypeParam>(&this->ctx_, "[1, 2, 3, 4, 5, 6, 7, 8]",
+ Datum(std::make_shared<ScalarType>(4.5)));
+
+ ValidateMean<TypeParam>(&this->ctx_, "[0, 0, 0, 0, 0, 0, 0, 0]",
+ Datum(std::make_shared<ScalarType>(0.0)));
+
+ ValidateMean<TypeParam>(&this->ctx_, "[1, 1, 1, 1, 1, 1, 1, 1]",
+ Datum(std::make_shared<ScalarType>(1.0)));
+}
+
+template <typename ArrowType>
+class TestRandomNumericMeanKernel : public ComputeFixture, public TestBase {};
+
+TYPED_TEST_CASE(TestRandomNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMean) {
+ auto rand = random::RandomArrayGenerator(0x8afc055);
+ for (size_t i = 3; i < 14; i++) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+ int64_t length = (1UL << i) + length_adjust;
+ auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
+ ValidateMean<TypeParam>(&this->ctx_, *array);
+ }
+ }
+ }
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/mean.cc b/cpp/src/arrow/compute/kernels/mean.cc
new file mode 100644
index 0000000..d1eaf15
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/mean.cc
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/mean.h"
+
+#include <algorithm>
+
+#include "arrow/compute/kernels/sum-internal.h"
+
+namespace arrow {
+namespace compute {
+
+template <typename ArrowType,
+ typename SumType = typename FindAccumulatorType<ArrowType>::Type>
+struct MeanState {
+ using ThisType = MeanState<ArrowType, SumType>;
+
+ ThisType operator+(const ThisType& rhs) const {
+ return ThisType(this->count + rhs.count, this->sum + rhs.sum);
+ }
+
+ ThisType& operator+=(const ThisType& rhs) {
+ this->count += rhs.count;
+ this->sum += rhs.sum;
+
+ return *this;
+ }
+
+ std::shared_ptr<Scalar> Finalize() const {
+ using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+ const bool is_valid = count > 0;
+ const double divisor = static_cast<double>(is_valid ? count : 1UL);
+ const double mean = static_cast<double>(sum) / divisor;
+
+ return std::make_shared<ScalarType>(mean, is_valid);
+ }
+
+ static std::shared_ptr<DataType> out_type() {
+ return TypeTraits<DoubleType>::type_singleton();
+ }
+
+ size_t count = 0;
+ typename SumType::c_type sum = 0;
+};
+
+#define MEAN_AGG_FN_CASE(T) \
+ case T::type_id: \
+ return std::static_pointer_cast<AggregateFunction>( \
+ std::make_shared<SumAggregateFunction<T, MeanState<T>>>());
+
+std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
+ FunctionContext* ctx) {
+ switch (type.id()) {
+ MEAN_AGG_FN_CASE(UInt8Type);
+ MEAN_AGG_FN_CASE(Int8Type);
+ MEAN_AGG_FN_CASE(UInt16Type);
+ MEAN_AGG_FN_CASE(Int16Type);
+ MEAN_AGG_FN_CASE(UInt32Type);
+ MEAN_AGG_FN_CASE(Int32Type);
+ MEAN_AGG_FN_CASE(UInt64Type);
+ MEAN_AGG_FN_CASE(Int64Type);
+ MEAN_AGG_FN_CASE(FloatType);
+ MEAN_AGG_FN_CASE(DoubleType);
+ default:
+ return nullptr;
+ }
+
+#undef MEAN_AGG_FN_CASE
+}
+
+static Status GetMeanKernel(FunctionContext* ctx, const DataType& type,
+ std::shared_ptr<AggregateUnaryKernel>& kernel) {
+ std::shared_ptr<AggregateFunction> aggregate = MakeMeanAggregateFunction(type, ctx);
+ if (!aggregate) return Status::Invalid("No mean for type ", type);
+
+ kernel = std::make_shared<AggregateUnaryKernel>(aggregate);
+
+ return Status::OK();
+}
+
+Status Mean(FunctionContext* ctx, const Datum& value, Datum* out) {
+ std::shared_ptr<AggregateUnaryKernel> kernel;
+
+ auto data_type = value.type();
+ if (data_type == nullptr)
+ return Status::Invalid("Datum must be array-like");
+ else if (!is_integer(data_type->id()) && !is_floating(data_type->id()))
+ return Status::Invalid("Datum must contain a NumericType");
+
+ RETURN_NOT_OK(GetMeanKernel(ctx, *data_type, kernel));
+
+ return kernel->Call(ctx, value, out);
+}
+
+Status Mean(FunctionContext* ctx, const Array& array, Datum* out) {
+ return Mean(ctx, array.data(), out);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.h b/cpp/src/arrow/compute/kernels/mean.h
similarity index 51%
copy from cpp/src/arrow/compute/kernels/sum.h
copy to cpp/src/arrow/compute/kernels/mean.h
index 88da2ac..5074d4e 100644
--- a/cpp/src/arrow/compute/kernels/sum.h
+++ b/cpp/src/arrow/compute/kernels/mean.h
@@ -15,8 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-#ifndef ARROW_COMPUTE_KERNELS_SUM_H
-#define ARROW_COMPUTE_KERNELS_SUM_H
+#pragma once
#include <memory>
#include <type_traits>
@@ -33,58 +32,35 @@ class DataType;
namespace compute {
-// Find the largest compatible primitive type for a primitive type.
-template <typename I, typename Enable = void>
-struct FindAccumulatorType {
- using Type = DoubleType;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsSignedInt<I>::value>::type> {
- using Type = Int64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsUnsignedInt<I>::value>::type> {
- using Type = UInt64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsFloatingPoint<I>::value>::type> {
- using Type = DoubleType;
-};
-
struct Datum;
class FunctionContext;
class AggregateFunction;
ARROW_EXPORT
-std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
- FunctionContext* context);
+std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
+ FunctionContext* context);
-/// \brief Sum values of a numeric array.
+/// \brief Compute the mean of a numeric array.
///
/// \param[in] context the FunctionContext
-/// \param[in] value datum to sum, expecting Array or ChunkedArray
-/// \param[out] out resulting datum
+/// \param[in] value datum to compute the mean, expecting Array
+/// \param[out] mean datum of the computed mean as a DoubleScalar
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
-Status Sum(FunctionContext* context, const Datum& value, Datum* out);
+Status Mean(FunctionContext* context, const Datum& value, Datum* mean);
-/// \brief Sum values of a numeric array.
+/// \brief Compute the mean of a numeric array.
///
/// \param[in] context the FunctionContext
-/// \param[in] array to sum
-/// \param[out] out resulting datum
+/// \param[in] array to compute the mean
+/// \param[out] mean datum of the computed mean as a DoubleScalar
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
-Status Sum(FunctionContext* context, const Array& array, Datum* out);
+Status Mean(FunctionContext* context, const Array& array, Datum* mean);
} // namespace compute
-} // namespace arrow
-
-#endif // ARROW_COMPUTE_KERNELS_CAST_H
+}; // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.cc b/cpp/src/arrow/compute/kernels/sum-internal.h
similarity index 56%
copy from cpp/src/arrow/compute/kernels/sum.cc
copy to cpp/src/arrow/compute/kernels/sum-internal.h
index 1799941..a4e7ea6 100644
--- a/cpp/src/arrow/compute/kernels/sum.cc
+++ b/cpp/src/arrow/compute/kernels/sum-internal.h
@@ -1,7 +1,7 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
-// returnGegarding copyright ownership. The ASF licenses this file
+// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
@@ -15,63 +15,56 @@
// specific language governing permissions and limitations
// under the License.
-#include "arrow/compute/kernels/sum.h"
+#pragma once
+
+#include <memory>
+#include <type_traits>
-#include "arrow/array.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/aggregate.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/logging.h"
-#include "arrow/visitor_inline.h"
namespace arrow {
-namespace compute {
-template <typename ArrowType,
- typename SumType = typename FindAccumulatorType<ArrowType>::Type>
-struct SumState {
- using ThisType = SumState<ArrowType, SumType>;
-
- ThisType operator+(const ThisType& rhs) const {
- return ThisType(this->count + rhs.count, this->sum + rhs.sum);
- }
+class Array;
+class DataType;
- ThisType& operator+=(const ThisType& rhs) {
- this->count += rhs.count;
- this->sum += rhs.sum;
-
- return *this;
- }
+namespace compute {
- std::shared_ptr<Scalar> AsScalar() const {
- using ScalarType = typename TypeTraits<SumType>::ScalarType;
- return std::make_shared<ScalarType>(this->sum);
- }
+// Find the largest compatible primitive type for a primitive type.
+template <typename I, typename Enable = void>
+struct FindAccumulatorType {};
- size_t count = 0;
- typename SumType::c_type sum = 0;
+template <typename I>
+struct FindAccumulatorType<I, enable_if_signed_integer<I>> {
+ using Type = Int64Type;
};
-constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
- return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
-}
+template <typename I>
+struct FindAccumulatorType<I, enable_if_unsigned_integer<I>> {
+ using Type = UInt64Type;
+};
-static_assert(CoveringBytes(0, 8) == 1, "");
-static_assert(CoveringBytes(0, 9) == 2, "");
-static_assert(CoveringBytes(1, 7) == 1, "");
-static_assert(CoveringBytes(1, 8) == 2, "");
-static_assert(CoveringBytes(2, 19) == 3, "");
-static_assert(CoveringBytes(7, 18) == 4, "");
+template <typename I>
+struct FindAccumulatorType<I, enable_if_floating_point<I>> {
+ using Type = DoubleType;
+};
-template <typename ArrowType, typename StateType = SumState<ArrowType>>
+template <typename ArrowType, typename StateType>
class SumAggregateFunction final : public AggregateFunctionStaticState<StateType> {
using CType = typename TypeTraits<ArrowType>::CType;
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  // A small number of elements, rounded to the next cacheline. This amounts
+  // to at most 4 cachelines when dealing with 8-byte elements.
static constexpr int64_t kTinyThreshold = 32;
- static_assert(kTinyThreshold > 18,
- "ConsumeSparse requires at least 18 elements to fit 3 bytes");
+ static_assert(kTinyThreshold >= (2 * CHAR_BIT) + 1,
+ "ConsumeSparse requires 3 bytes of null bitmap, and 17 is the"
+ "required minimum number of bits/elements to cover 3 bytes.");
public:
Status Consume(const Array& input, StateType* state) const override {
@@ -96,18 +89,11 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
}
Status Finalize(const StateType& src, Datum* output) const override {
- auto boxed = src.AsScalar();
- if (src.count == 0) {
- // TODO(wesm): Currently null, but fix this
- boxed->is_valid = false;
- }
- *output = boxed;
+ *output = src.Finalize();
return Status::OK();
}
- std::shared_ptr<DataType> out_type() const override {
- return TypeTraits<typename FindAccumulatorType<ArrowType>::Type>::type_singleton();
- }
+ std::shared_ptr<DataType> out_type() const override { return StateType::out_type(); }
private:
StateType ConsumeDense(const ArrayType& array) const {
@@ -141,22 +127,20 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
return local;
}
+  // While this is not branchless, gcc needs this to be in a different function
+  // for it to generate cmov, which tends to be slightly faster than multiplying
+  // by the validity bit and, unlike it, is safe when null slots hold NaN doubles.
+ inline CType MaskedValue(bool valid, CType value) const { return valid ? value : 0; }
+
inline StateType UnrolledSum(uint8_t bits, const CType* values) const {
StateType local;
if (bits < 0xFF) {
-#define SUM_SHIFT(ITEM) values[ITEM] * static_cast<CType>(((bits >> ITEM) & 1U))
// Some nulls
- local.sum += SUM_SHIFT(0);
- local.sum += SUM_SHIFT(1);
- local.sum += SUM_SHIFT(2);
- local.sum += SUM_SHIFT(3);
- local.sum += SUM_SHIFT(4);
- local.sum += SUM_SHIFT(5);
- local.sum += SUM_SHIFT(6);
- local.sum += SUM_SHIFT(7);
+ for (size_t i = 0; i < 8; i++) {
+ local.sum += MaskedValue(bits & (1U << i), values[i]);
+ }
local.count += BitUtil::kBytePopcount[bits];
-#undef SUM_SHIFT
} else {
// No nulls
for (size_t i = 0; i < 8; i++) {
@@ -189,7 +173,8 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
// The number of bytes covering the range, this includes partial bytes.
// This number bounded by `<= (length / 8) + 2`, e.g. a possible extra byte
// on the left, and on the right.
- const int64_t covering_bytes = CoveringBytes(offset, length);
+ const int64_t covering_bytes = BitUtil::CoveringBytes(offset, length);
+ DCHECK_GE(covering_bytes, 3);
// Align values to the first batch of 8 elements. Note that raw_values() is
// already adjusted with the offset, thus we rewind a little to align to
@@ -216,60 +201,7 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
return local;
}
-};
-
-#define SUM_AGG_FN_CASE(T) \
- case T::type_id: \
- return std::static_pointer_cast<AggregateFunction>( \
- std::make_shared<SumAggregateFunction<T>>());
-
-std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
- FunctionContext* ctx) {
- switch (type.id()) {
- SUM_AGG_FN_CASE(UInt8Type);
- SUM_AGG_FN_CASE(Int8Type);
- SUM_AGG_FN_CASE(UInt16Type);
- SUM_AGG_FN_CASE(Int16Type);
- SUM_AGG_FN_CASE(UInt32Type);
- SUM_AGG_FN_CASE(Int32Type);
- SUM_AGG_FN_CASE(UInt64Type);
- SUM_AGG_FN_CASE(Int64Type);
- SUM_AGG_FN_CASE(FloatType);
- SUM_AGG_FN_CASE(DoubleType);
- default:
- return nullptr;
- }
-
-#undef SUM_AGG_FN_CASE
-}
-
-static Status GetSumKernel(FunctionContext* ctx, const DataType& type,
- std::shared_ptr<AggregateUnaryKernel>& kernel) {
- std::shared_ptr<AggregateFunction> aggregate = MakeSumAggregateFunction(type, ctx);
- if (!aggregate) return Status::Invalid("No sum for type ", type);
-
- kernel = std::make_shared<AggregateUnaryKernel>(aggregate);
-
- return Status::OK();
-}
-
-Status Sum(FunctionContext* ctx, const Datum& value, Datum* out) {
- std::shared_ptr<AggregateUnaryKernel> kernel;
-
- auto data_type = value.type();
- if (data_type == nullptr)
- return Status::Invalid("Datum must be array-like");
- else if (!is_integer(data_type->id()) && !is_floating(data_type->id()))
- return Status::Invalid("Datum must contain a NumericType");
-
- RETURN_NOT_OK(GetSumKernel(ctx, *data_type, kernel));
-
- return kernel->Call(ctx, value, out);
-}
-
-Status Sum(FunctionContext* ctx, const Array& array, Datum* out) {
- return Sum(ctx, array.data(), out);
-}
+};  // class SumAggregateFunction
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.cc b/cpp/src/arrow/compute/kernels/sum.cc
index 1799941..14b999c 100644
--- a/cpp/src/arrow/compute/kernels/sum.cc
+++ b/cpp/src/arrow/compute/kernels/sum.cc
@@ -16,14 +16,7 @@
// under the License.
#include "arrow/compute/kernels/sum.h"
-
-#include "arrow/array.h"
-#include "arrow/compute/kernel.h"
-#include "arrow/compute/kernels/aggregate.h"
-#include "arrow/type_traits.h"
-#include "arrow/util/bit-util.h"
-#include "arrow/util/logging.h"
-#include "arrow/visitor_inline.h"
+#include "arrow/compute/kernels/sum-internal.h"
namespace arrow {
namespace compute {
@@ -44,184 +37,30 @@ struct SumState {
return *this;
}
- std::shared_ptr<Scalar> AsScalar() const {
+ std::shared_ptr<Scalar> Finalize() const {
using ScalarType = typename TypeTraits<SumType>::ScalarType;
- return std::make_shared<ScalarType>(this->sum);
- }
-
- size_t count = 0;
- typename SumType::c_type sum = 0;
-};
-
-constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
- return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
-}
-
-static_assert(CoveringBytes(0, 8) == 1, "");
-static_assert(CoveringBytes(0, 9) == 2, "");
-static_assert(CoveringBytes(1, 7) == 1, "");
-static_assert(CoveringBytes(1, 8) == 2, "");
-static_assert(CoveringBytes(2, 19) == 3, "");
-static_assert(CoveringBytes(7, 18) == 4, "");
-
-template <typename ArrowType, typename StateType = SumState<ArrowType>>
-class SumAggregateFunction final : public AggregateFunctionStaticState<StateType> {
- using CType = typename TypeTraits<ArrowType>::CType;
- using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
-
- static constexpr int64_t kTinyThreshold = 32;
- static_assert(kTinyThreshold > 18,
- "ConsumeSparse requires at least 18 elements to fit 3 bytes");
-
- public:
- Status Consume(const Array& input, StateType* state) const override {
- const ArrayType& array = static_cast<const ArrayType&>(input);
- if (input.null_count() == 0) {
- *state = ConsumeDense(array);
- } else if (input.length() <= kTinyThreshold) {
- // In order to simplify ConsumeSparse implementation (requires at least 3
- // bytes of bitmap data), small arrays are handled differently.
- *state = ConsumeTiny(array);
- } else {
- *state = ConsumeSparse(array);
- }
-
- return Status::OK();
- }
-
- Status Merge(const StateType& src, StateType* dst) const override {
- *dst += src;
- return Status::OK();
- }
-
- Status Finalize(const StateType& src, Datum* output) const override {
- auto boxed = src.AsScalar();
- if (src.count == 0) {
+ auto boxed = std::make_shared<ScalarType>(this->sum);
+ if (count == 0) {
// TODO(wesm): Currently null, but fix this
boxed->is_valid = false;
}
- *output = boxed;
- return Status::OK();
- }
-
- std::shared_ptr<DataType> out_type() const override {
- return TypeTraits<typename FindAccumulatorType<ArrowType>::Type>::type_singleton();
- }
-
- private:
- StateType ConsumeDense(const ArrayType& array) const {
- StateType local;
-
- const auto values = array.raw_values();
- const int64_t length = array.length();
- for (int64_t i = 0; i < length; i++) {
- local.sum += values[i];
- }
-
- local.count = length;
-
- return local;
- }
-
- StateType ConsumeTiny(const ArrayType& array) const {
- StateType local;
-
- internal::BitmapReader reader(array.null_bitmap_data(), array.offset(),
- array.length());
- const auto values = array.raw_values();
- for (int64_t i = 0; i < array.length(); i++) {
- if (reader.IsSet()) {
- local.sum += values[i];
- local.count++;
- }
- reader.Next();
- }
- return local;
+ return boxed;
}
- inline StateType UnrolledSum(uint8_t bits, const CType* values) const {
- StateType local;
-
- if (bits < 0xFF) {
-#define SUM_SHIFT(ITEM) values[ITEM] * static_cast<CType>(((bits >> ITEM) & 1U))
- // Some nulls
- local.sum += SUM_SHIFT(0);
- local.sum += SUM_SHIFT(1);
- local.sum += SUM_SHIFT(2);
- local.sum += SUM_SHIFT(3);
- local.sum += SUM_SHIFT(4);
- local.sum += SUM_SHIFT(5);
- local.sum += SUM_SHIFT(6);
- local.sum += SUM_SHIFT(7);
- local.count += BitUtil::kBytePopcount[bits];
-#undef SUM_SHIFT
- } else {
- // No nulls
- for (size_t i = 0; i < 8; i++) {
- local.sum += values[i];
- }
- local.count += 8;
- }
-
- return local;
+ static std::shared_ptr<DataType> out_type() {
+ return TypeTraits<SumType>::type_singleton();
}
- StateType ConsumeSparse(const ArrayType& array) const {
- StateType local;
-
- // Sliced bitmaps on non-byte positions induce problem with the branchless
- // unrolled technique. Thus extra padding is added on both left and right
- // side of the slice such that both ends are byte-aligned. The first and
- // last bitmap are properly masked to ignore extra values induced by
- // padding.
- //
- // The execution is divided in 3 sections.
- //
- // 1. Compute the sum of the first masked byte.
- // 2. Compute the sum of the middle bytes
- // 3. Compute the sum of the last masked byte.
-
- const int64_t length = array.length();
- const int64_t offset = array.offset();
-
- // The number of bytes covering the range, this includes partial bytes.
- // This number bounded by `<= (length / 8) + 2`, e.g. a possible extra byte
- // on the left, and on the right.
- const int64_t covering_bytes = CoveringBytes(offset, length);
-
- // Align values to the first batch of 8 elements. Note that raw_values() is
- // already adjusted with the offset, thus we rewind a little to align to
- // the closest 8-batch offset.
- const auto values = array.raw_values() - (offset % 8);
-
- // Align bitmap at the first consumable byte.
- const auto bitmap = array.null_bitmap_data() + BitUtil::RoundDown(offset, 8) / 8;
-
- // Consume the first (potentially partial) byte.
- const uint8_t first_mask = BitUtil::kTrailingBitmask[offset % 8];
- local += UnrolledSum(bitmap[0] & first_mask, values);
-
- // Consume the (full) middle bytes. The loop iterates in unit of
- // batches of 8 values and 1 byte of bitmap.
- for (int64_t i = 1; i < covering_bytes - 1; i++) {
- local += UnrolledSum(bitmap[i], &values[i * 8]);
- }
-
- // Consume the last (potentially partial) byte.
- const int64_t last_idx = covering_bytes - 1;
- const uint8_t last_mask = BitUtil::kPrecedingWrappingBitmask[(offset + length) % 8];
- local += UnrolledSum(bitmap[last_idx] & last_mask, &values[last_idx * 8]);
-
- return local;
- }
+ size_t count = 0;
+ typename SumType::c_type sum = 0;
};
#define SUM_AGG_FN_CASE(T) \
case T::type_id: \
return std::static_pointer_cast<AggregateFunction>( \
- std::make_shared<SumAggregateFunction<T>>());
+ std::make_shared<SumAggregateFunction<T, SumState<T>>>());
std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
FunctionContext* ctx) {
diff --git a/cpp/src/arrow/compute/kernels/sum.h b/cpp/src/arrow/compute/kernels/sum.h
index 88da2ac..e6f9549 100644
--- a/cpp/src/arrow/compute/kernels/sum.h
+++ b/cpp/src/arrow/compute/kernels/sum.h
@@ -15,49 +15,31 @@
// specific language governing permissions and limitations
// under the License.
-#ifndef ARROW_COMPUTE_KERNELS_SUM_H
-#define ARROW_COMPUTE_KERNELS_SUM_H
+#pragma once
#include <memory>
-#include <type_traits>
-#include "arrow/status.h"
-#include "arrow/type.h"
-#include "arrow/type_traits.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
+class Status;
namespace compute {
-// Find the largest compatible primitive type for a primitive type.
-template <typename I, typename Enable = void>
-struct FindAccumulatorType {
- using Type = DoubleType;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsSignedInt<I>::value>::type> {
- using Type = Int64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsUnsignedInt<I>::value>::type> {
- using Type = UInt64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsFloatingPoint<I>::value>::type> {
- using Type = DoubleType;
-};
-
struct Datum;
class FunctionContext;
class AggregateFunction;
+/// \brief Return an AggregateFunction computing the sum of an array's values
+///
+/// \param[in] type the value type of the input; selects the specialized kernel
+/// \param[in] context the FunctionContext
+/// \return the Sum aggregate kernel specialized for `type`
+/// \since 0.13.0
+/// \note API not yet finalized
ARROW_EXPORT
std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
FunctionContext* context);
@@ -86,5 +68,3 @@ Status Sum(FunctionContext* context, const Array& array, Datum* out);
} // namespace compute
} // namespace arrow
-
-#endif // ARROW_COMPUTE_KERNELS_CAST_H
diff --git a/cpp/src/arrow/compute/test-util.h b/cpp/src/arrow/compute/test-util.h
index e90a034..bec54cc 100644
--- a/cpp/src/arrow/compute/test-util.h
+++ b/cpp/src/arrow/compute/test-util.h
@@ -69,6 +69,54 @@ std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
return result;
}
+/// \brief Assert that two Datums hold equal values of arrow type `Type`.
+///
+/// Both datums must have the same kind. Two NONE datums are equal; SCALAR
+/// datums are compared field by field; any other kind fails the test instead
+/// of passing unchecked. Floating point values are compared with ASSERT_NEAR,
+/// all other values with ASSERT_EQ.
+template <typename Type, typename Enable = void>
+struct DatumEqual {};
+
+template <typename Type>
+struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::value>::type> {
+  // NOTE(review): 1.0 is a very loose absolute bound (e.g. for means of small
+  // values); presumably it absorbs summation-order rounding -- confirm/tighten.
+  static constexpr double kArbitraryDoubleErrorBound = 1.0;
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+
+  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
+    ASSERT_EQ(lhs.kind(), rhs.kind());
+    if (lhs.kind() == Datum::SCALAR) {
+      auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
+      auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+      ASSERT_EQ(left->is_valid, right->is_valid);
+      ASSERT_EQ(left->type->id(), right->type->id());
+      ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
+    } else if (lhs.kind() != Datum::NONE) {
+      FAIL() << "DatumEqual only supports SCALAR (or NONE) datums";
+    }
+  }
+};
+
+template <typename Type>
+struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+
+  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
+    ASSERT_EQ(lhs.kind(), rhs.kind());
+    if (lhs.kind() == Datum::SCALAR) {
+      auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
+      auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+      ASSERT_EQ(left->is_valid, right->is_valid);
+      ASSERT_EQ(left->type->id(), right->type->id());
+      ASSERT_EQ(left->value, right->value);
+    } else if (lhs.kind() != Datum::NONE) {
+      FAIL() << "DatumEqual only supports SCALAR (or NONE) datums";
+    }
+  }
+};
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/util/bit-util-test.cc b/cpp/src/arrow/util/bit-util-test.cc
index e8f32d2..774d3bf 100644
--- a/cpp/src/arrow/util/bit-util-test.cc
+++ b/cpp/src/arrow/util/bit-util-test.cc
@@ -750,6 +750,15 @@ TEST(BitUtil, RoundDown) {
}
}
+TEST(BitUtil, CoveringBytes) {
+ EXPECT_EQ(BitUtil::CoveringBytes(0, 8), 1);   // bits [0, 8)  -> byte 0
+ EXPECT_EQ(BitUtil::CoveringBytes(0, 9), 2);   // bits [0, 9)  -> bytes 0..1
+ EXPECT_EQ(BitUtil::CoveringBytes(1, 7), 1);   // bits [1, 8)  -> byte 0
+ EXPECT_EQ(BitUtil::CoveringBytes(1, 8), 2);   // bits [1, 9)  -> bytes 0..1
+ EXPECT_EQ(BitUtil::CoveringBytes(2, 19), 3);  // bits [2, 21) -> bytes 0..2
+ EXPECT_EQ(BitUtil::CoveringBytes(7, 18), 4);  // bits [7, 25) -> bytes 0..3
+}
+
TEST(BitUtil, TrailingBits) {
EXPECT_EQ(BitUtil::TrailingBits(BOOST_BINARY(1 1 1 1 1 1 1 1), 0), 0);
EXPECT_EQ(BitUtil::TrailingBits(BOOST_BINARY(1 1 1 1 1 1 1 1), 1), 1);
diff --git a/cpp/src/arrow/util/bit-util.h b/cpp/src/arrow/util/bit-util.h
index 53b5588..6724c29 100644
--- a/cpp/src/arrow/util/bit-util.h
+++ b/cpp/src/arrow/util/bit-util.h
@@ -154,6 +154,21 @@ constexpr int64_t RoundUpToMultipleOf64(int64_t num) {
return RoundUpToPowerOf2(num, 64);
}
+// Returns the number of bytes spanned by the bits [offset, offset + length) once
+// both ends are widened to byte boundaries (a partial byte counts as a full one).
+//
+// The following example represents a slice (offset=10, length=9):
+//
+// 0       8       16      24
+// |-------|-------|-------|
+//           [       ]        (slice: offset=10, length=9)
+//         [              ]   (byte-aligned: bits 8..23, length=16)
+//
+// Here the result is (24 - 8) / 8 = 2. Note: length=0 yields 1 if offset % 8 != 0.
+constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
+ return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
+}
+
// Returns the 'num_bits' least-significant bits of 'v'.
static inline uint64_t TrailingBits(uint64_t v, int num_bits) {
if (ARROW_PREDICT_FALSE(num_bits == 0)) return 0;