You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/02/28 01:34:55 UTC

[arrow] branch master updated: ARROW-3121: [C++] Mean aggregate kernel

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 29aa925  ARROW-3121: [C++] Mean aggregate kernel
29aa925 is described below

commit 29aa925683870d2644bb6334610d164aaeef6d10
Author: François Saint-Jacques <fs...@gmail.com>
AuthorDate: Wed Feb 27 19:34:45 2019 -0600

    ARROW-3121: [C++] Mean aggregate kernel
    
    Implements the mean (average) aggregate kernel for numeric columns. The output type is always double. Refactored the Sum kernel implementation so both kernels share common parts: notably, the Consume step is identical, and only Finalize and the output type differ.
    
    Author: François Saint-Jacques <fs...@gmail.com>
    
    Closes #3708 from fsaintjacques/ARROW-3121-mean-aggregate and squashes the following commits:
    
    0d39c1f4 <François Saint-Jacques> reformat
    d41db22e <François Saint-Jacques> Refactor with ternary
    d1191fd8 <François Saint-Jacques> Deal with NaN values in sum
    c448bcbc <François Saint-Jacques> Add documentation per review
    79291402 <François Saint-Jacques> Implement mean aggregate
    3a1a0cd8 <François Saint-Jacques> Refactor sum implementation
    8bc293f1 <François Saint-Jacques> Move TypeTraits into sum-internal.h
---
 cpp/src/arrow/CMakeLists.txt                       |   1 +
 cpp/src/arrow/compute/kernels/aggregate-test.cc    | 191 +++++++++++++--------
 cpp/src/arrow/compute/kernels/mean.cc              | 115 +++++++++++++
 cpp/src/arrow/compute/kernels/{sum.h => mean.h}    |  48 ++----
 .../compute/kernels/{sum.cc => sum-internal.h}     | 154 +++++------------
 cpp/src/arrow/compute/kernels/sum.cc               | 181 ++-----------------
 cpp/src/arrow/compute/kernels/sum.h                |  38 +---
 cpp/src/arrow/compute/test-util.h                  |  35 ++++
 cpp/src/arrow/util/bit-util-test.cc                |   9 +
 cpp/src/arrow/util/bit-util.h                      |  15 ++
 10 files changed, 372 insertions(+), 415 deletions(-)

diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 94eba0c..3c6a399 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -147,6 +147,7 @@ if(ARROW_COMPUTE)
       compute/kernels/boolean.cc
       compute/kernels/cast.cc
       compute/kernels/hash.cc
+      compute/kernels/mean.cc
       compute/kernels/sum.cc
       compute/kernels/util-internal.cc)
 endif()
diff --git a/cpp/src/arrow/compute/kernels/aggregate-test.cc b/cpp/src/arrow/compute/kernels/aggregate-test.cc
index ca44744..bdf50f5 100644
--- a/cpp/src/arrow/compute/kernels/aggregate-test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate-test.cc
@@ -15,13 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <algorithm>
+#include <memory>
 #include <string>
 #include <type_traits>
+#include <utility>
 
 #include <gtest/gtest.h>
 
 #include "arrow/array.h"
 #include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/mean.h"
+#include "arrow/compute/kernels/sum-internal.h"
 #include "arrow/compute/kernels/sum.h"
 #include "arrow/compute/test-util.h"
 #include "arrow/type.h"
@@ -38,63 +43,16 @@ using std::vector;
 namespace arrow {
 namespace compute {
 
-template <typename Type, typename Enable = void>
-struct DatumEqual {
-  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {}
-};
-
-template <typename Type>
-struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::Value>::type> {
-  static constexpr double kArbitraryDoubleErrorBound = 1.0;
-  using ScalarType = typename TypeTraits<Type>::ScalarType;
-
-  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
-    ASSERT_EQ(lhs.kind(), rhs.kind());
-    if (lhs.kind() == Datum::SCALAR) {
-      auto left = static_cast<const ScalarType*>(lhs.scalar().get());
-      auto right = static_cast<const ScalarType*>(rhs.scalar().get());
-      ASSERT_EQ(left->type->id(), right->type->id());
-      ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
-    }
-  }
-};
-
-template <typename Type>
-struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
-  using ScalarType = typename TypeTraits<Type>::ScalarType;
-  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
-    ASSERT_EQ(lhs.kind(), rhs.kind());
-    if (lhs.kind() == Datum::SCALAR) {
-      auto left = static_cast<const ScalarType*>(lhs.scalar().get());
-      auto right = static_cast<const ScalarType*>(rhs.scalar().get());
-      ASSERT_EQ(left->type->id(), right->type->id());
-      ASSERT_EQ(left->value, right->value);
-    }
-  }
-};
-
 template <typename ArrowType>
-void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
-  using OutputType = typename FindAccumulatorType<ArrowType>::Type;
-  Datum result;
-  ASSERT_OK(Sum(ctx, input, &result));
-  DatumEqual<OutputType>::EnsureEqual(result, expected);
-}
+using SumResult =
+    std::pair<typename FindAccumulatorType<ArrowType>::Type::c_type, size_t>;
 
 template <typename ArrowType>
-void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
-  auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
-  ValidateSum<ArrowType>(ctx, *array, expected);
-}
-
-template <typename ArrowType>
-static Datum DummySum(const Array& array) {
+static SumResult<ArrowType> NaiveSumPartial(const Array& array) {
   using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
-  using SumType = typename FindAccumulatorType<ArrowType>::Type;
-  using SumScalarType = typename TypeTraits<SumType>::ScalarType;
+  using ResultType = SumResult<ArrowType>;
 
-  typename SumType::c_type sum = 0;
-  int64_t count = 0;
+  ResultType result;
 
   auto data = array.data();
   internal::BitmapReader reader(array.null_bitmap_data(), array.offset(), array.length());
@@ -102,30 +60,52 @@ static Datum DummySum(const Array& array) {
   const auto values = array_numeric.raw_values();
   for (int64_t i = 0; i < array.length(); i++) {
     if (reader.IsSet()) {
-      sum += values[i];
-      count++;
+      result.first += values[i];
+      result.second++;
     }
 
     reader.Next();
   }
 
-  if (count > 0) {
-    return Datum(std::make_shared<SumScalarType>(sum));
-  } else {
-    return Datum(std::make_shared<SumScalarType>(0, false));
-  }
+  return result;
+}
+
+template <typename ArrowType>
+static Datum NaiveSum(const Array& array) {
+  using SumType = typename FindAccumulatorType<ArrowType>::Type;
+  using SumScalarType = typename TypeTraits<SumType>::ScalarType;
+
+  auto result = NaiveSumPartial<ArrowType>(array);
+  bool is_valid = result.second > 0;
+
+  return Datum(std::make_shared<SumScalarType>(result.first, is_valid));
+}
+
+template <typename ArrowType>
+void ValidateSum(FunctionContext* ctx, const Array& input, Datum expected) {
+  using OutputType = typename FindAccumulatorType<ArrowType>::Type;
+
+  Datum result;
+  ASSERT_OK(Sum(ctx, input, &result));
+  DatumEqual<OutputType>::EnsureEqual(result, expected);
+}
+
+template <typename ArrowType>
+void ValidateSum(FunctionContext* ctx, const char* json, Datum expected) {
+  auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+  ValidateSum<ArrowType>(ctx, *array, expected);
 }
 
 template <typename ArrowType>
 void ValidateSum(FunctionContext* ctx, const Array& array) {
-  ValidateSum<ArrowType>(ctx, array, DummySum<ArrowType>(array));
+  ValidateSum<ArrowType>(ctx, array, NaiveSum<ArrowType>(array));
 }
 
 template <typename ArrowType>
-class TestSumKernelNumeric : public ComputeFixture, public TestBase {};
+class TestNumericSumKernel : public ComputeFixture, public TestBase {};
 
-TYPED_TEST_CASE(TestSumKernelNumeric, NumericArrowTypes);
-TYPED_TEST(TestSumKernelNumeric, SimpleSum) {
+TYPED_TEST_CASE(TestNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericSumKernel, SimpleSum) {
   using SumType = typename FindAccumulatorType<TypeParam>::Type;
   using ScalarType = typename TypeTraits<SumType>::ScalarType;
   using T = typename TypeParam::c_type;
@@ -145,10 +125,10 @@ TYPED_TEST(TestSumKernelNumeric, SimpleSum) {
 }
 
 template <typename ArrowType>
-class TestRandomSumKernelNumeric : public ComputeFixture, public TestBase {};
+class TestRandomNumericSumKernel : public ComputeFixture, public TestBase {};
 
-TYPED_TEST_CASE(TestRandomSumKernelNumeric, NumericArrowTypes);
-TYPED_TEST(TestRandomSumKernelNumeric, RandomArraySum) {
+TYPED_TEST_CASE(TestRandomNumericSumKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericSumKernel, RandomArraySum) {
   auto rand = random::RandomArrayGenerator(0x5487655);
   for (size_t i = 3; i < 14; i++) {
     for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
@@ -161,7 +141,7 @@ TYPED_TEST(TestRandomSumKernelNumeric, RandomArraySum) {
   }
 }
 
-TYPED_TEST(TestRandomSumKernelNumeric, RandomSliceArraySum) {
+TYPED_TEST(TestRandomNumericSumKernel, RandomSliceArraySum) {
   auto arithmetic = ArrayFromJSON(TypeTraits<TypeParam>::type_singleton(),
                                   "[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]");
   ValidateSum<TypeParam>(&this->ctx_, *arithmetic);
@@ -175,12 +155,87 @@ TYPED_TEST(TestRandomSumKernelNumeric, RandomSliceArraySum) {
   const int64_t length = 1U << 6;
   auto array = rand.Numeric<TypeParam>(length, 0, 10, 0.5);
   for (size_t i = 1; i < 16; i++) {
-    for (size_t j = 1; i < 16; i++) {
+    for (size_t j = 1; j < 16; j++) {
       auto slice = array->Slice(i, length - j);
       ValidateSum<TypeParam>(&this->ctx_, *slice);
     }
   }
 }
 
+template <typename ArrowType>
+static Datum NaiveMean(const Array& array) {
+  using MeanScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+  const auto result = NaiveSumPartial<ArrowType>(array);
+  const double mean = static_cast<double>(result.first) /
+                      static_cast<double>(result.second ? result.second : 1UL);
+  const bool is_valid = result.second > 0;
+
+  return Datum(std::make_shared<MeanScalarType>(mean, is_valid));
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const Array& input, Datum expected) {
+  using OutputType = typename FindAccumulatorType<DoubleType>::Type;
+
+  Datum result;
+  ASSERT_OK(Mean(ctx, input, &result));
+  DatumEqual<OutputType>::EnsureEqual(result, expected);
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const char* json, Datum expected) {
+  auto array = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), json);
+  ValidateMean<ArrowType>(ctx, *array, expected);
+}
+
+template <typename ArrowType>
+void ValidateMean(FunctionContext* ctx, const Array& array) {
+  ValidateMean<ArrowType>(ctx, array, NaiveMean<ArrowType>(array));
+}
+
+template <typename ArrowType>
+class TestMeanKernelNumeric : public ComputeFixture, public TestBase {};
+
+TYPED_TEST_CASE(TestMeanKernelNumeric, NumericArrowTypes);
+TYPED_TEST(TestMeanKernelNumeric, SimpleMean) {
+  using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+  ValidateMean<TypeParam>(&this->ctx_, "[]",
+                          Datum(std::make_shared<ScalarType>(0.0, false)));
+
+  ValidateMean<TypeParam>(&this->ctx_, "[null]",
+                          Datum(std::make_shared<ScalarType>(0.0, false)));
+
+  ValidateMean<TypeParam>(&this->ctx_, "[1, null, 1]",
+                          Datum(std::make_shared<ScalarType>(1.0)));
+
+  ValidateMean<TypeParam>(&this->ctx_, "[1, 2, 3, 4, 5, 6, 7, 8]",
+                          Datum(std::make_shared<ScalarType>(4.5)));
+
+  ValidateMean<TypeParam>(&this->ctx_, "[0, 0, 0, 0, 0, 0, 0, 0]",
+                          Datum(std::make_shared<ScalarType>(0.0)));
+
+  ValidateMean<TypeParam>(&this->ctx_, "[1, 1, 1, 1, 1, 1, 1, 1]",
+                          Datum(std::make_shared<ScalarType>(1.0)));
+}
+
+template <typename ArrowType>
+class TestRandomNumericMeanKernel : public ComputeFixture, public TestBase {};
+
+TYPED_TEST_CASE(TestRandomNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMean) {
+  auto rand = random::RandomArrayGenerator(0x8afc055);
+  for (size_t i = 3; i < 14; i++) {
+    for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+      for (auto length_adjust : {-2, -1, 0, 1, 2}) {
+        int64_t length = (1UL << i) + length_adjust;
+        auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
+        ValidateMean<TypeParam>(&this->ctx_, *array);
+      }
+    }
+  }
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/mean.cc b/cpp/src/arrow/compute/kernels/mean.cc
new file mode 100644
index 0000000..d1eaf15
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/mean.cc
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/kernels/mean.h"
+
+#include <algorithm>
+
+#include "arrow/compute/kernels/sum-internal.h"
+
+namespace arrow {
+namespace compute {
+
+template <typename ArrowType,
+          typename SumType = typename FindAccumulatorType<ArrowType>::Type>
+struct MeanState {
+  using ThisType = MeanState<ArrowType, SumType>;
+
+  ThisType operator+(const ThisType& rhs) const {
+    return ThisType(this->count + rhs.count, this->sum + rhs.sum);
+  }
+
+  ThisType& operator+=(const ThisType& rhs) {
+    this->count += rhs.count;
+    this->sum += rhs.sum;
+
+    return *this;
+  }
+
+  std::shared_ptr<Scalar> Finalize() const {
+    using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
+
+    const bool is_valid = count > 0;
+    const double divisor = static_cast<double>(is_valid ? count : 1UL);
+    const double mean = static_cast<double>(sum) / divisor;
+
+    return std::make_shared<ScalarType>(mean, is_valid);
+  }
+
+  static std::shared_ptr<DataType> out_type() {
+    return TypeTraits<DoubleType>::type_singleton();
+  }
+
+  size_t count = 0;
+  typename SumType::c_type sum = 0;
+};
+
+#define MEAN_AGG_FN_CASE(T)                             \
+  case T::type_id:                                      \
+    return std::static_pointer_cast<AggregateFunction>( \
+        std::make_shared<SumAggregateFunction<T, MeanState<T>>>());
+
+std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
+                                                             FunctionContext* ctx) {
+  switch (type.id()) {
+    MEAN_AGG_FN_CASE(UInt8Type);
+    MEAN_AGG_FN_CASE(Int8Type);
+    MEAN_AGG_FN_CASE(UInt16Type);
+    MEAN_AGG_FN_CASE(Int16Type);
+    MEAN_AGG_FN_CASE(UInt32Type);
+    MEAN_AGG_FN_CASE(Int32Type);
+    MEAN_AGG_FN_CASE(UInt64Type);
+    MEAN_AGG_FN_CASE(Int64Type);
+    MEAN_AGG_FN_CASE(FloatType);
+    MEAN_AGG_FN_CASE(DoubleType);
+    default:
+      return nullptr;
+  }
+
+#undef MEAN_AGG_FN_CASE
+}
+
+static Status GetMeanKernel(FunctionContext* ctx, const DataType& type,
+                            std::shared_ptr<AggregateUnaryKernel>& kernel) {
+  std::shared_ptr<AggregateFunction> aggregate = MakeMeanAggregateFunction(type, ctx);
+  if (!aggregate) return Status::Invalid("No mean for type ", type);
+
+  kernel = std::make_shared<AggregateUnaryKernel>(aggregate);
+
+  return Status::OK();
+}
+
+Status Mean(FunctionContext* ctx, const Datum& value, Datum* out) {
+  std::shared_ptr<AggregateUnaryKernel> kernel;
+
+  auto data_type = value.type();
+  if (data_type == nullptr)
+    return Status::Invalid("Datum must be array-like");
+  else if (!is_integer(data_type->id()) && !is_floating(data_type->id()))
+    return Status::Invalid("Datum must contain a NumericType");
+
+  RETURN_NOT_OK(GetMeanKernel(ctx, *data_type, kernel));
+
+  return kernel->Call(ctx, value, out);
+}
+
+Status Mean(FunctionContext* ctx, const Array& array, Datum* out) {
+  return Mean(ctx, array.data(), out);
+}
+
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.h b/cpp/src/arrow/compute/kernels/mean.h
similarity index 51%
copy from cpp/src/arrow/compute/kernels/sum.h
copy to cpp/src/arrow/compute/kernels/mean.h
index 88da2ac..5074d4e 100644
--- a/cpp/src/arrow/compute/kernels/sum.h
+++ b/cpp/src/arrow/compute/kernels/mean.h
@@ -15,8 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#ifndef ARROW_COMPUTE_KERNELS_SUM_H
-#define ARROW_COMPUTE_KERNELS_SUM_H
+#pragma once
 
 #include <memory>
 #include <type_traits>
@@ -33,58 +32,35 @@ class DataType;
 
 namespace compute {
 
-// Find the largest compatible primitive type for a primitive type.
-template <typename I, typename Enable = void>
-struct FindAccumulatorType {
-  using Type = DoubleType;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsSignedInt<I>::value>::type> {
-  using Type = Int64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsUnsignedInt<I>::value>::type> {
-  using Type = UInt64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsFloatingPoint<I>::value>::type> {
-  using Type = DoubleType;
-};
-
 struct Datum;
 class FunctionContext;
 class AggregateFunction;
 
 ARROW_EXPORT
-std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
-                                                            FunctionContext* context);
+std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
+                                                             FunctionContext* context);
 
-/// \brief Sum values of a numeric array.
+/// \brief Compute the mean of a numeric array.
 ///
 /// \param[in] context the FunctionContext
-/// \param[in] value datum to sum, expecting Array or ChunkedArray
-/// \param[out] out resulting datum
+/// \param[in] value datum to compute the mean of, expecting Array
+/// \param[out] mean datum of the computed mean as a DoubleScalar
 ///
 /// \since 0.13.0
 /// \note API not yet finalized
 ARROW_EXPORT
-Status Sum(FunctionContext* context, const Datum& value, Datum* out);
+Status Mean(FunctionContext* context, const Datum& value, Datum* mean);
 
-/// \brief Sum values of a numeric array.
+/// \brief Compute the mean of a numeric array.
 ///
 /// \param[in] context the FunctionContext
-/// \param[in] array to sum
-/// \param[out] out resulting datum
+/// \param[in] array to compute the mean of
+/// \param[out] mean datum of the computed mean as a DoubleScalar
 ///
 /// \since 0.13.0
 /// \note API not yet finalized
 ARROW_EXPORT
-Status Sum(FunctionContext* context, const Array& array, Datum* out);
+Status Mean(FunctionContext* context, const Array& array, Datum* mean);
 
 }  // namespace compute
-}  // namespace arrow
-
-#endif  // ARROW_COMPUTE_KERNELS_CAST_H
+};  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.cc b/cpp/src/arrow/compute/kernels/sum-internal.h
similarity index 56%
copy from cpp/src/arrow/compute/kernels/sum.cc
copy to cpp/src/arrow/compute/kernels/sum-internal.h
index 1799941..a4e7ea6 100644
--- a/cpp/src/arrow/compute/kernels/sum.cc
+++ b/cpp/src/arrow/compute/kernels/sum-internal.h
@@ -1,7 +1,7 @@
 // Licensed to the Apache Software Foundation (ASF) under one
 // or more contributor license agreements.  See the NOTICE file
 // distributed with this work for additional information
-// returnGegarding copyright ownership.  The ASF licenses this file
+// regarding copyright ownership.  The ASF licenses this file
 // to you under the Apache License, Version 2.0 (the
 // "License"); you may not use this file except in compliance
 // with the License.  You may obtain a copy of the License at
@@ -15,63 +15,56 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#include "arrow/compute/kernels/sum.h"
+#pragma once
+
+#include <memory>
+#include <type_traits>
 
-#include "arrow/array.h"
 #include "arrow/compute/kernel.h"
 #include "arrow/compute/kernels/aggregate.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/bit-util.h"
 #include "arrow/util/logging.h"
-#include "arrow/visitor_inline.h"
 
 namespace arrow {
-namespace compute {
 
-template <typename ArrowType,
-          typename SumType = typename FindAccumulatorType<ArrowType>::Type>
-struct SumState {
-  using ThisType = SumState<ArrowType, SumType>;
-
-  ThisType operator+(const ThisType& rhs) const {
-    return ThisType(this->count + rhs.count, this->sum + rhs.sum);
-  }
+class Array;
+class DataType;
 
-  ThisType& operator+=(const ThisType& rhs) {
-    this->count += rhs.count;
-    this->sum += rhs.sum;
-
-    return *this;
-  }
+namespace compute {
 
-  std::shared_ptr<Scalar> AsScalar() const {
-    using ScalarType = typename TypeTraits<SumType>::ScalarType;
-    return std::make_shared<ScalarType>(this->sum);
-  }
+// Find the largest compatible primitive type for a primitive type.
+template <typename I, typename Enable = void>
+struct FindAccumulatorType {};
 
-  size_t count = 0;
-  typename SumType::c_type sum = 0;
+template <typename I>
+struct FindAccumulatorType<I, enable_if_signed_integer<I>> {
+  using Type = Int64Type;
 };
 
-constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
-  return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
-}
+template <typename I>
+struct FindAccumulatorType<I, enable_if_unsigned_integer<I>> {
+  using Type = UInt64Type;
+};
 
-static_assert(CoveringBytes(0, 8) == 1, "");
-static_assert(CoveringBytes(0, 9) == 2, "");
-static_assert(CoveringBytes(1, 7) == 1, "");
-static_assert(CoveringBytes(1, 8) == 2, "");
-static_assert(CoveringBytes(2, 19) == 3, "");
-static_assert(CoveringBytes(7, 18) == 4, "");
+template <typename I>
+struct FindAccumulatorType<I, enable_if_floating_point<I>> {
+  using Type = DoubleType;
+};
 
-template <typename ArrowType, typename StateType = SumState<ArrowType>>
+template <typename ArrowType, typename StateType>
 class SumAggregateFunction final : public AggregateFunctionStaticState<StateType> {
   using CType = typename TypeTraits<ArrowType>::CType;
   using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
 
+  // A small number of elements, rounded up to a whole number of cachelines.
+  // This amounts to at most 4 cachelines when dealing with 8-byte elements.
   static constexpr int64_t kTinyThreshold = 32;
-  static_assert(kTinyThreshold > 18,
-                "ConsumeSparse requires at least 18 elements to fit 3 bytes");
+  static_assert(kTinyThreshold >= (2 * CHAR_BIT) + 1,
+                "ConsumeSparse requires 3 bytes of null bitmap, and 17 is the"
+                "required minimum number of bits/elements to cover 3 bytes.");
 
  public:
   Status Consume(const Array& input, StateType* state) const override {
@@ -96,18 +89,11 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
   }
 
   Status Finalize(const StateType& src, Datum* output) const override {
-    auto boxed = src.AsScalar();
-    if (src.count == 0) {
-      // TODO(wesm): Currently null, but fix this
-      boxed->is_valid = false;
-    }
-    *output = boxed;
+    *output = src.Finalize();
     return Status::OK();
   }
 
-  std::shared_ptr<DataType> out_type() const override {
-    return TypeTraits<typename FindAccumulatorType<ArrowType>::Type>::type_singleton();
-  }
+  std::shared_ptr<DataType> out_type() const override { return StateType::out_type(); }
 
  private:
   StateType ConsumeDense(const ArrayType& array) const {
@@ -141,22 +127,20 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
     return local;
   }
 
+  // While this is not branchless, gcc needs this to be in a separate function
+  // for it to generate cmov, which tends to be slightly faster than masking by
+  // multiplication and is also NaN-safe for doubles (0 * NaN is still NaN).
+  inline CType MaskedValue(bool valid, CType value) const { return valid ? value : 0; }
+
   inline StateType UnrolledSum(uint8_t bits, const CType* values) const {
     StateType local;
 
     if (bits < 0xFF) {
-#define SUM_SHIFT(ITEM) values[ITEM] * static_cast<CType>(((bits >> ITEM) & 1U))
       // Some nulls
-      local.sum += SUM_SHIFT(0);
-      local.sum += SUM_SHIFT(1);
-      local.sum += SUM_SHIFT(2);
-      local.sum += SUM_SHIFT(3);
-      local.sum += SUM_SHIFT(4);
-      local.sum += SUM_SHIFT(5);
-      local.sum += SUM_SHIFT(6);
-      local.sum += SUM_SHIFT(7);
+      for (size_t i = 0; i < 8; i++) {
+        local.sum += MaskedValue(bits & (1U << i), values[i]);
+      }
       local.count += BitUtil::kBytePopcount[bits];
-#undef SUM_SHIFT
     } else {
       // No nulls
       for (size_t i = 0; i < 8; i++) {
@@ -189,7 +173,8 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
     // The number of bytes covering the range, this includes partial bytes.
     // This number bounded by `<= (length / 8) + 2`, e.g. a possible extra byte
     // on the left, and on the right.
-    const int64_t covering_bytes = CoveringBytes(offset, length);
+    const int64_t covering_bytes = BitUtil::CoveringBytes(offset, length);
+    DCHECK_GE(covering_bytes, 3);
 
     // Align values to the first batch of 8 elements. Note that raw_values() is
     // already adjusted with the offset, thus we rewind a little to align to
@@ -216,60 +201,7 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
 
     return local;
   }
-};
-
-#define SUM_AGG_FN_CASE(T)                              \
-  case T::type_id:                                      \
-    return std::static_pointer_cast<AggregateFunction>( \
-        std::make_shared<SumAggregateFunction<T>>());
-
-std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
-                                                            FunctionContext* ctx) {
-  switch (type.id()) {
-    SUM_AGG_FN_CASE(UInt8Type);
-    SUM_AGG_FN_CASE(Int8Type);
-    SUM_AGG_FN_CASE(UInt16Type);
-    SUM_AGG_FN_CASE(Int16Type);
-    SUM_AGG_FN_CASE(UInt32Type);
-    SUM_AGG_FN_CASE(Int32Type);
-    SUM_AGG_FN_CASE(UInt64Type);
-    SUM_AGG_FN_CASE(Int64Type);
-    SUM_AGG_FN_CASE(FloatType);
-    SUM_AGG_FN_CASE(DoubleType);
-    default:
-      return nullptr;
-  }
-
-#undef SUM_AGG_FN_CASE
-}
-
-static Status GetSumKernel(FunctionContext* ctx, const DataType& type,
-                           std::shared_ptr<AggregateUnaryKernel>& kernel) {
-  std::shared_ptr<AggregateFunction> aggregate = MakeSumAggregateFunction(type, ctx);
-  if (!aggregate) return Status::Invalid("No sum for type ", type);
-
-  kernel = std::make_shared<AggregateUnaryKernel>(aggregate);
-
-  return Status::OK();
-}
-
-Status Sum(FunctionContext* ctx, const Datum& value, Datum* out) {
-  std::shared_ptr<AggregateUnaryKernel> kernel;
-
-  auto data_type = value.type();
-  if (data_type == nullptr)
-    return Status::Invalid("Datum must be array-like");
-  else if (!is_integer(data_type->id()) && !is_floating(data_type->id()))
-    return Status::Invalid("Datum must contain a NumericType");
-
-  RETURN_NOT_OK(GetSumKernel(ctx, *data_type, kernel));
-
-  return kernel->Call(ctx, value, out);
-}
-
-Status Sum(FunctionContext* ctx, const Array& array, Datum* out) {
-  return Sum(ctx, array.data(), out);
-}
+};  // class SumAggregateFunction
 
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/sum.cc b/cpp/src/arrow/compute/kernels/sum.cc
index 1799941..14b999c 100644
--- a/cpp/src/arrow/compute/kernels/sum.cc
+++ b/cpp/src/arrow/compute/kernels/sum.cc
@@ -16,14 +16,7 @@
 // under the License.
 
 #include "arrow/compute/kernels/sum.h"
-
-#include "arrow/array.h"
-#include "arrow/compute/kernel.h"
-#include "arrow/compute/kernels/aggregate.h"
-#include "arrow/type_traits.h"
-#include "arrow/util/bit-util.h"
-#include "arrow/util/logging.h"
-#include "arrow/visitor_inline.h"
+#include "arrow/compute/kernels/sum-internal.h"
 
 namespace arrow {
 namespace compute {
@@ -44,184 +37,30 @@ struct SumState {
     return *this;
   }
 
-  std::shared_ptr<Scalar> AsScalar() const {
+  std::shared_ptr<Scalar> Finalize() const {
     using ScalarType = typename TypeTraits<SumType>::ScalarType;
-    return std::make_shared<ScalarType>(this->sum);
-  }
-
-  size_t count = 0;
-  typename SumType::c_type sum = 0;
-};
-
-constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
-  return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
-}
-
-static_assert(CoveringBytes(0, 8) == 1, "");
-static_assert(CoveringBytes(0, 9) == 2, "");
-static_assert(CoveringBytes(1, 7) == 1, "");
-static_assert(CoveringBytes(1, 8) == 2, "");
-static_assert(CoveringBytes(2, 19) == 3, "");
-static_assert(CoveringBytes(7, 18) == 4, "");
-
-template <typename ArrowType, typename StateType = SumState<ArrowType>>
-class SumAggregateFunction final : public AggregateFunctionStaticState<StateType> {
-  using CType = typename TypeTraits<ArrowType>::CType;
-  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
-
-  static constexpr int64_t kTinyThreshold = 32;
-  static_assert(kTinyThreshold > 18,
-                "ConsumeSparse requires at least 18 elements to fit 3 bytes");
-
- public:
-  Status Consume(const Array& input, StateType* state) const override {
-    const ArrayType& array = static_cast<const ArrayType&>(input);
 
-    if (input.null_count() == 0) {
-      *state = ConsumeDense(array);
-    } else if (input.length() <= kTinyThreshold) {
-      // In order to simplify ConsumeSparse implementation (requires at least 3
-      // bytes of bitmap data), small arrays are handled differently.
-      *state = ConsumeTiny(array);
-    } else {
-      *state = ConsumeSparse(array);
-    }
-
-    return Status::OK();
-  }
-
-  Status Merge(const StateType& src, StateType* dst) const override {
-    *dst += src;
-    return Status::OK();
-  }
-
-  Status Finalize(const StateType& src, Datum* output) const override {
-    auto boxed = src.AsScalar();
-    if (src.count == 0) {
+    auto boxed = std::make_shared<ScalarType>(this->sum);
+    if (count == 0) {
       // TODO(wesm): Currently null, but fix this
       boxed->is_valid = false;
     }
-    *output = boxed;
-    return Status::OK();
-  }
-
-  std::shared_ptr<DataType> out_type() const override {
-    return TypeTraits<typename FindAccumulatorType<ArrowType>::Type>::type_singleton();
-  }
-
- private:
-  StateType ConsumeDense(const ArrayType& array) const {
-    StateType local;
-
-    const auto values = array.raw_values();
-    const int64_t length = array.length();
-    for (int64_t i = 0; i < length; i++) {
-      local.sum += values[i];
-    }
-
-    local.count = length;
-
-    return local;
-  }
-
-  StateType ConsumeTiny(const ArrayType& array) const {
-    StateType local;
-
-    internal::BitmapReader reader(array.null_bitmap_data(), array.offset(),
-                                  array.length());
-    const auto values = array.raw_values();
-    for (int64_t i = 0; i < array.length(); i++) {
-      if (reader.IsSet()) {
-        local.sum += values[i];
-        local.count++;
-      }
-      reader.Next();
-    }
 
-    return local;
+    return boxed;
   }
 
-  inline StateType UnrolledSum(uint8_t bits, const CType* values) const {
-    StateType local;
-
-    if (bits < 0xFF) {
-#define SUM_SHIFT(ITEM) values[ITEM] * static_cast<CType>(((bits >> ITEM) & 1U))
-      // Some nulls
-      local.sum += SUM_SHIFT(0);
-      local.sum += SUM_SHIFT(1);
-      local.sum += SUM_SHIFT(2);
-      local.sum += SUM_SHIFT(3);
-      local.sum += SUM_SHIFT(4);
-      local.sum += SUM_SHIFT(5);
-      local.sum += SUM_SHIFT(6);
-      local.sum += SUM_SHIFT(7);
-      local.count += BitUtil::kBytePopcount[bits];
-#undef SUM_SHIFT
-    } else {
-      // No nulls
-      for (size_t i = 0; i < 8; i++) {
-        local.sum += values[i];
-      }
-      local.count += 8;
-    }
-
-    return local;
+  static std::shared_ptr<DataType> out_type() {
+    return TypeTraits<SumType>::type_singleton();
   }
 
-  StateType ConsumeSparse(const ArrayType& array) const {
-    StateType local;
-
-    // Sliced bitmaps on non-byte positions induce problem with the branchless
-    // unrolled technique. Thus extra padding is added on both left and right
-    // side of the slice such that both ends are byte-aligned. The first and
-    // last bitmap are properly masked to ignore extra values induced by
-    // padding.
-    //
-    // The execution is divided in 3 sections.
-    //
-    // 1. Compute the sum of the first masked byte.
-    // 2. Compute the sum of the middle bytes
-    // 3. Compute the sum of the last masked byte.
-
-    const int64_t length = array.length();
-    const int64_t offset = array.offset();
-
-    // The number of bytes covering the range, this includes partial bytes.
-    // This number bounded by `<= (length / 8) + 2`, e.g. a possible extra byte
-    // on the left, and on the right.
-    const int64_t covering_bytes = CoveringBytes(offset, length);
-
-    // Align values to the first batch of 8 elements. Note that raw_values() is
-    // already adjusted with the offset, thus we rewind a little to align to
-    // the closest 8-batch offset.
-    const auto values = array.raw_values() - (offset % 8);
-
-    // Align bitmap at the first consumable byte.
-    const auto bitmap = array.null_bitmap_data() + BitUtil::RoundDown(offset, 8) / 8;
-
-    // Consume the first (potentially partial) byte.
-    const uint8_t first_mask = BitUtil::kTrailingBitmask[offset % 8];
-    local += UnrolledSum(bitmap[0] & first_mask, values);
-
-    // Consume the (full) middle bytes. The loop iterates in unit of
-    // batches of 8 values and 1 byte of bitmap.
-    for (int64_t i = 1; i < covering_bytes - 1; i++) {
-      local += UnrolledSum(bitmap[i], &values[i * 8]);
-    }
-
-    // Consume the last (potentially partial) byte.
-    const int64_t last_idx = covering_bytes - 1;
-    const uint8_t last_mask = BitUtil::kPrecedingWrappingBitmask[(offset + length) % 8];
-    local += UnrolledSum(bitmap[last_idx] & last_mask, &values[last_idx * 8]);
-
-    return local;
-  }
+  size_t count = 0;
+  typename SumType::c_type sum = 0;
 };
 
 #define SUM_AGG_FN_CASE(T)                              \
   case T::type_id:                                      \
     return std::static_pointer_cast<AggregateFunction>( \
-        std::make_shared<SumAggregateFunction<T>>());
+        std::make_shared<SumAggregateFunction<T, SumState<T>>>());
 
 std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
                                                             FunctionContext* ctx) {
diff --git a/cpp/src/arrow/compute/kernels/sum.h b/cpp/src/arrow/compute/kernels/sum.h
index 88da2ac..e6f9549 100644
--- a/cpp/src/arrow/compute/kernels/sum.h
+++ b/cpp/src/arrow/compute/kernels/sum.h
@@ -15,49 +15,31 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#ifndef ARROW_COMPUTE_KERNELS_SUM_H
-#define ARROW_COMPUTE_KERNELS_SUM_H
+#pragma once
 
 #include <memory>
-#include <type_traits>
 
-#include "arrow/status.h"
-#include "arrow/type.h"
-#include "arrow/type_traits.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
 
 class Array;
 class DataType;
+class Status;
 
 namespace compute {
 
-// Find the largest compatible primitive type for a primitive type.
-template <typename I, typename Enable = void>
-struct FindAccumulatorType {
-  using Type = DoubleType;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsSignedInt<I>::value>::type> {
-  using Type = Int64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsUnsignedInt<I>::value>::type> {
-  using Type = UInt64Type;
-};
-
-template <typename I>
-struct FindAccumulatorType<I, typename std::enable_if<IsFloatingPoint<I>::value>::type> {
-  using Type = DoubleType;
-};
-
 struct Datum;
 class FunctionContext;
 class AggregateFunction;
 
+/// \brief Return an AggregateFunction summing values of the given type
+///
+/// \param[in] type the input value type, used to specialize the function
+/// \param[in] context the FunctionContext
+///
+/// \since 0.13.0
+/// \note API not yet finalized
 ARROW_EXPORT
 std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
                                                             FunctionContext* context);
@@ -86,5 +68,3 @@ Status Sum(FunctionContext* context, const Array& array, Datum* out);
 
 }  // namespace compute
 }  // namespace arrow
-
-#endif  // ARROW_COMPUTE_KERNELS_CAST_H
diff --git a/cpp/src/arrow/compute/test-util.h b/cpp/src/arrow/compute/test-util.h
index e90a034..bec54cc 100644
--- a/cpp/src/arrow/compute/test-util.h
+++ b/cpp/src/arrow/compute/test-util.h
@@ -69,6 +69,41 @@ std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
   return result;
 }
 
+template <typename Type, typename Enable = void>
+struct DatumEqual {};  // specialized below for floating-point and other types
+// Floating-point types: scalar values are compared with ASSERT_NEAR.
+template <typename Type>
+struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::value>::type> {
+  static constexpr double kArbitraryDoubleErrorBound = 1.0;  // absolute tolerance
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  // Asserts equal kinds; scalars must also agree on validity, type id and value.
+  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
+    ASSERT_EQ(lhs.kind(), rhs.kind());
+    if (lhs.kind() == Datum::SCALAR) {  // non-scalar datums: only kind is checked
+      auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
+      auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+      ASSERT_EQ(left->is_valid, right->is_valid);
+      ASSERT_EQ(left->type->id(), right->type->id());
+      ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
+    }
+  }
+};
+// Other types: scalar values must compare exactly equal (ASSERT_EQ).
+template <typename Type>
+struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
+    ASSERT_EQ(lhs.kind(), rhs.kind());
+    if (lhs.kind() == Datum::SCALAR) {  // non-scalar datums: only kind is checked
+      auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
+      auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+      ASSERT_EQ(left->is_valid, right->is_valid);
+      ASSERT_EQ(left->type->id(), right->type->id());
+      ASSERT_EQ(left->value, right->value);
+    }
+  }
+};
+
 }  // namespace compute
 }  // namespace arrow
 
diff --git a/cpp/src/arrow/util/bit-util-test.cc b/cpp/src/arrow/util/bit-util-test.cc
index e8f32d2..774d3bf 100644
--- a/cpp/src/arrow/util/bit-util-test.cc
+++ b/cpp/src/arrow/util/bit-util-test.cc
@@ -750,6 +750,15 @@ TEST(BitUtil, RoundDown) {
   }
 }
 
+TEST(BitUtil, CoveringBytes) {
+  EXPECT_EQ(BitUtil::CoveringBytes(0, 8), 1);   // bits [0, 8): byte 0
+  EXPECT_EQ(BitUtil::CoveringBytes(0, 9), 2);   // bits [0, 9): bytes 0-1
+  EXPECT_EQ(BitUtil::CoveringBytes(1, 7), 1);   // bits [1, 8): byte 0
+  EXPECT_EQ(BitUtil::CoveringBytes(1, 8), 2);   // bits [1, 9): bytes 0-1
+  EXPECT_EQ(BitUtil::CoveringBytes(2, 19), 3);  // bits [2, 21): bytes 0-2
+  EXPECT_EQ(BitUtil::CoveringBytes(7, 18), 4);  // bits [7, 25): bytes 0-3
+}
+
 TEST(BitUtil, TrailingBits) {
   EXPECT_EQ(BitUtil::TrailingBits(BOOST_BINARY(1 1 1 1 1 1 1 1), 0), 0);
   EXPECT_EQ(BitUtil::TrailingBits(BOOST_BINARY(1 1 1 1 1 1 1 1), 1), 1);
diff --git a/cpp/src/arrow/util/bit-util.h b/cpp/src/arrow/util/bit-util.h
index 53b5588..6724c29 100644
--- a/cpp/src/arrow/util/bit-util.h
+++ b/cpp/src/arrow/util/bit-util.h
@@ -154,6 +154,21 @@ constexpr int64_t RoundUpToMultipleOf64(int64_t num) {
   return RoundUpToPowerOf2(num, 64);
 }
 
+// Returns the number of bytes covering a sliced bitmap, i.e. the length in
+// bytes of the slice once both of its ends are extended to byte boundaries.
+//
+// The following example represents a slice (offset=10, length=9)
+//
+// 0       8       16      24
+// |-------|-------|-------|
+//           [       ]          (slice)
+//         [              ]     (same slice aligned to byte boundaries, length=16)
+// Here CoveringBytes(10, 9) == 16 / 8 == 2. Offset and length are in bits. An empty
+// slice at a non-byte-aligned offset still covers 1 byte: CoveringBytes(3, 0) == 1.
+constexpr int64_t CoveringBytes(int64_t offset, int64_t length) {
+  return (BitUtil::RoundUp(length + offset, 8) - BitUtil::RoundDown(offset, 8)) / 8;
+}
+
 // Returns the 'num_bits' least-significant bits of 'v'.
 static inline uint64_t TrailingBits(uint64_t v, int num_bits) {
   if (ARROW_PREDICT_FALSE(num_bits == 0)) return 0;