You are viewing a plain-text version of this content; the link to the canonical version is not included in this copy.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/06/27 20:45:42 UTC
[arrow] branch master updated: ARROW-2104: [C++] take kernel functions for nested types
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new da752fd ARROW-2104: [C++] take kernel functions for nested types
da752fd is described below
commit da752fddab34d71e5c5f648b2cb20740c16ce11e
Author: Benjamin Kietzman <be...@gmail.com>
AuthorDate: Thu Jun 27 15:45:14 2019 -0500
ARROW-2104: [C++] take kernel functions for nested types
Take now supports gathering from List, FixedSizeList, Map, and Struct arrays. Union is not yet supported.
Author: Benjamin Kietzman <be...@gmail.com>
Closes #4531 from bkietz/2104-Implement-take-kernel-functions-nested-a and squashes the following commits:
73262bd44 <Benjamin Kietzman> clang-format
eaf8302ea <Benjamin Kietzman> add benchmarks for Take()
5981ee8d4 <Benjamin Kietzman> rewrite Filter(string array) benchmark to respect memory budget
d60ff7c0d <Benjamin Kietzman> cast size_t -> int16_t, update fixed_size_binary(0) test
30d587252 <Benjamin Kietzman> add LiteralType constructor for gcc 4.8
e73c1ec23 <Benjamin Kietzman> validate arrays in pyarrow's Take() test
ac0e391aa <Benjamin Kietzman> add benchmark for filtering a StringArray
0fe81648e <Benjamin Kietzman> added requested tests and ValidateArray calls
55854836d <Benjamin Kietzman> add doccomments
e6081b027 <Benjamin Kietzman> remove redundant bounds checking in Struct case
d9c4a1a64 <Benjamin Kietzman> add Take() permutation inversion test
abc1733bd <Benjamin Kietzman> simplify looping through IndexSequences
c3e812982 <Benjamin Kietzman> rewrite python Take() test
c7f2e4021 <Benjamin Kietzman> repair bounds checking
6a14c93e8 <Benjamin Kietzman> clang-format, explicit cast
6c453f334 <Benjamin Kietzman> lint fixes
227ea5516 <Benjamin Kietzman> add tests for Take(nested types)
65dcd9075 <Benjamin Kietzman> refactor Take and Filter to share code through Taker<>
---
cpp/src/arrow/array-test.cc | 2 +-
cpp/src/arrow/array.cc | 13 +-
cpp/src/arrow/array.h | 8 +-
cpp/src/arrow/array/builder_primitive.cc | 4 +-
cpp/src/arrow/buffer-builder.h | 3 +
cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 +
cpp/src/arrow/compute/kernels/filter-benchmark.cc | 31 ++
cpp/src/arrow/compute/kernels/filter-test.cc | 57 ++-
cpp/src/arrow/compute/kernels/filter.cc | 426 ++---------------
cpp/src/arrow/compute/kernels/filter.h | 9 +-
cpp/src/arrow/compute/kernels/take-benchmark.cc | 147 ++++++
cpp/src/arrow/compute/kernels/take-internal.h | 553 ++++++++++++++++++++++
cpp/src/arrow/compute/kernels/take-test.cc | 386 +++++++++++++--
cpp/src/arrow/compute/kernels/take.cc | 226 +++------
cpp/src/arrow/compute/kernels/take.h | 32 +-
cpp/src/arrow/compute/kernels/util-internal.h | 2 +-
python/pyarrow/tests/test_compute.py | 20 +-
17 files changed, 1303 insertions(+), 617 deletions(-)
diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc
index 606ca71..2005a0d 100644
--- a/cpp/src/arrow/array-test.cc
+++ b/cpp/src/arrow/array-test.cc
@@ -1311,7 +1311,7 @@ TEST_F(TestFWBinaryArray, ZeroSize) {
const auto& fw_array = checked_cast<const FixedSizeBinaryArray&>(*array);
// data is never allocated
- ASSERT_TRUE(fw_array.values() == nullptr);
+ ASSERT_EQ(fw_array.values()->size(), 0);
ASSERT_EQ(0, fw_array.byte_width());
ASSERT_EQ(6, array->length());
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 95acc6b..9b66af2 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -301,12 +301,21 @@ MapArray::MapArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length,
const std::shared_ptr<Buffer>& offsets,
- const std::shared_ptr<Array>& keys,
const std::shared_ptr<Array>& values,
const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
int64_t offset) {
+ SetData(ArrayData::Make(type, length, {null_bitmap, offsets}, {values->data()},
+ null_count, offset));
+}
+
+MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& offsets,
+ const std::shared_ptr<Array>& keys,
+ const std::shared_ptr<Array>& items,
+ const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count,
+ int64_t offset) {
auto pair_data = ArrayData::Make(type->children()[0]->type(), keys->data()->length,
- {nullptr}, {keys->data(), values->data()}, 0, offset);
+ {nullptr}, {keys->data(), items->data()}, 0, offset);
auto map_data = ArrayData::Make(type, length, {null_bitmap, offsets}, {pair_data},
null_count, offset);
SetData(map_data);
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 5cca9db..1e163b7 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -565,7 +565,13 @@ class ARROW_EXPORT MapArray : public ListArray {
MapArray(const std::shared_ptr<DataType>& type, int64_t length,
const std::shared_ptr<Buffer>& value_offsets,
- const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& values,
+ const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items,
+ const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
+ int64_t null_count = kUnknownNullCount, int64_t offset = 0);
+
+ MapArray(const std::shared_ptr<DataType>& type, int64_t length,
+ const std::shared_ptr<Buffer>& value_offsets,
+ const std::shared_ptr<Array>& values,
const std::shared_ptr<Buffer>& null_bitmap = NULLPTR,
int64_t null_count = kUnknownNullCount, int64_t offset = 0);
diff --git a/cpp/src/arrow/array/builder_primitive.cc b/cpp/src/arrow/array/builder_primitive.cc
index 34d198e..c7d934f 100644
--- a/cpp/src/arrow/array/builder_primitive.cc
+++ b/cpp/src/arrow/array/builder_primitive.cc
@@ -65,9 +65,9 @@ Status BooleanBuilder::Resize(int64_t capacity) {
}
Status BooleanBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) {
- std::shared_ptr<Buffer> data, null_bitmap;
- RETURN_NOT_OK(data_builder_.Finish(&data));
+ std::shared_ptr<Buffer> null_bitmap, data;
RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
+ RETURN_NOT_OK(data_builder_.Finish(&data));
*out = ArrayData::Make(boolean(), length_, {null_bitmap, data}, null_count_);
diff --git a/cpp/src/arrow/buffer-builder.h b/cpp/src/arrow/buffer-builder.h
index f069ea4..85f36ee 100644
--- a/cpp/src/arrow/buffer-builder.h
+++ b/cpp/src/arrow/buffer-builder.h
@@ -145,6 +145,9 @@ class ARROW_EXPORT BufferBuilder {
ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit));
if (size_ != 0) buffer_->ZeroPadding();
*out = buffer_;
+ if (*out == NULLPTR) {
+ ARROW_RETURN_NOT_OK(AllocateBuffer(pool_, 0, out));
+ }
Reset();
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 1bbb5bc..3d9da8b 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -34,3 +34,4 @@ add_arrow_benchmark(compare-benchmark PREFIX "arrow-compute")
add_arrow_test(take-test PREFIX "arrow-compute")
add_arrow_test(filter-test PREFIX "arrow-compute")
add_arrow_benchmark(filter-benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(take-benchmark PREFIX "arrow-compute")
diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc
index 3eb460a..0ae528b 100644
--- a/cpp/src/arrow/compute/kernels/filter-benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc
@@ -68,6 +68,30 @@ static void FilterFixedSizeList1Int64(benchmark::State& state) {
}
}
+static void FilterString(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ int32_t string_min_length = 0, string_max_length = 128;
+ int32_t string_mean_length = (string_max_length + string_min_length) / 2;
+ // for an array of 50% null strings, we need to generate twice as many strings
+ // to ensure that they have an average of args.size total characters
+ auto array_size =
+ static_cast<int64_t>(args.size / string_mean_length / (1 - args.null_proportion));
+
+ auto rand = random::RandomArrayGenerator(kSeed);
+ auto array = std::static_pointer_cast<StringArray>(rand.String(
+ array_size, string_min_length, string_max_length, args.null_proportion));
+ auto filter = std::static_pointer_cast<BooleanArray>(
+ rand.Boolean(array_size, 0.75, args.null_proportion));
+
+ FunctionContext ctx;
+ for (auto _ : state) {
+ Datum out;
+ ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out));
+ benchmark::DoNotOptimize(out);
+ }
+}
+
BENCHMARK(FilterInt64)
->Apply(RegressionSetArgs)
->Args({1 << 20, 1})
@@ -82,5 +106,12 @@ BENCHMARK(FilterFixedSizeList1Int64)
->MinTime(1.0)
->Unit(benchmark::TimeUnit::kNanosecond);
+BENCHMARK(FilterString)
+ ->Apply(RegressionSetArgs)
+ ->Args({1 << 20, 1})
+ ->Args({1 << 23, 1})
+ ->MinTime(1.0)
+ ->Unit(benchmark::TimeUnit::kNanosecond);
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc
index 7b34949..033efee 100644
--- a/cpp/src/arrow/compute/kernels/filter-test.cc
+++ b/cpp/src/arrow/compute/kernels/filter-test.cc
@@ -34,6 +34,8 @@ namespace compute {
using internal::checked_pointer_cast;
using util::string_view;
+constexpr auto kSeed = 0x0ff1ce;
+
template <typename ArrowType>
class TestFilterKernel : public ComputeFixture, public TestBase {
protected:
@@ -42,23 +44,29 @@ class TestFilterKernel : public ComputeFixture, public TestBase {
const std::shared_ptr<Array>& expected) {
std::shared_ptr<Array> actual;
ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter, &actual));
+ ASSERT_OK(ValidateArray(*actual));
AssertArraysEqual(*expected, *actual);
}
+
void AssertFilter(const std::shared_ptr<DataType>& type, const std::string& values,
const std::string& filter, const std::string& expected) {
std::shared_ptr<Array> actual;
ASSERT_OK(this->Filter(type, values, filter, &actual));
+ ASSERT_OK(ValidateArray(*actual));
AssertArraysEqual(*ArrayFromJSON(type, expected), *actual);
}
+
Status Filter(const std::shared_ptr<DataType>& type, const std::string& values,
const std::string& filter, std::shared_ptr<Array>* out) {
return arrow::compute::Filter(&this->ctx_, *ArrayFromJSON(type, values),
*ArrayFromJSON(boolean(), filter), out);
}
+
void ValidateFilter(const std::shared_ptr<Array>& values,
const std::shared_ptr<Array>& filter_boxed) {
std::shared_ptr<Array> filtered;
ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter_boxed, &filtered));
+ ASSERT_OK(ValidateArray(*filtered));
auto filter = checked_pointer_cast<BooleanArray>(filter_boxed);
int64_t values_i = 0, filtered_i = 0;
@@ -84,11 +92,13 @@ class TestFilterKernelWithNull : public TestFilterKernel<NullType> {
protected:
void AssertFilter(const std::string& values, const std::string& filter,
const std::string& expected) {
- TestFilterKernel<NullType>::AssertFilter(utf8(), values, filter, expected);
+ TestFilterKernel<NullType>::AssertFilter(null(), values, filter, expected);
}
};
TEST_F(TestFilterKernelWithNull, FilterNull) {
+ this->AssertFilter("[]", "[]", "[]");
+
this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]");
this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]");
}
@@ -102,6 +112,8 @@ class TestFilterKernelWithBoolean : public TestFilterKernel<BooleanType> {
};
TEST_F(TestFilterKernelWithBoolean, FilterBoolean) {
+ this->AssertFilter("[]", "[]", "[]");
+
this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]");
this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]");
this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]");
@@ -114,6 +126,7 @@ class TestFilterKernelWithNumeric : public TestFilterKernel<ArrowType> {
const std::string& expected) {
TestFilterKernel<ArrowType>::AssertFilter(type_singleton(), values, filter, expected);
}
+
std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
}
@@ -135,13 +148,16 @@ TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) {
this->AssertFilter("[null, 8, 9]", "[0, 1, 0]", "[8]");
this->AssertFilter("[7, 8, 9]", "[null, 1, 0]", "[null, 8]");
this->AssertFilter("[7, 8, 9]", "[1, null, 1]", "[7, null, 9]");
+
+ std::shared_ptr<Array> arr;
+ ASSERT_RAISES(Invalid, this->Filter(this->type_singleton(), "[7, 8, 9]", "[]", &arr));
}
TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) {
- auto rand = random::RandomArrayGenerator(0x5416447);
+ auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
- for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto null_probability : {0.0, 0.01, 0.25, 1.0}) {
for (auto filter_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
auto values = rand.Numeric<TypeParam>(length, 0, 127, null_probability);
auto filter = rand.Boolean(length, filter_probability, null_probability);
@@ -191,7 +207,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
using CType = typename TypeTraits<TypeParam>::CType;
- auto rand = random::RandomArrayGenerator(0x5416447);
+ auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
// TODO(bkietz) rewrite with some nulls
@@ -206,6 +222,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
&selection));
ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered));
auto filtered_array = filtered.make_array();
+ ASSERT_OK(ValidateArray(*filtered_array));
auto expected =
CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
@@ -216,7 +233,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
- auto rand = random::RandomArrayGenerator(0x5416447);
+ auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
auto lhs =
@@ -230,6 +247,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
&selection));
ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(lhs), selection, &filtered));
auto filtered_array = filtered.make_array();
+ ASSERT_OK(ValidateArray(*filtered_array));
auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
rhs->raw_values(), op);
ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
@@ -242,7 +260,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
using CType = typename TypeTraits<TypeParam>::CType;
- auto rand = random::RandomArrayGenerator(0x5416447);
+ auto rand = random::RandomArrayGenerator(kSeed);
for (size_t i = 3; i < 13; i++) {
const int64_t length = static_cast<int64_t>(1ULL << i);
auto array =
@@ -259,6 +277,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
&selection));
ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered));
auto filtered_array = filtered.make_array();
+ ASSERT_OK(ValidateArray(*filtered_array));
auto expected = CompareAndFilter<TypeParam>(
array->raw_values(), array->length(),
[&](CType e) { return (e > c_fifty) && (e < c_hundred); });
@@ -313,6 +332,32 @@ TEST_F(TestFilterKernelWithList, FilterListInt32) {
this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]");
}
+TEST_F(TestFilterKernelWithList, FilterListListInt32) {
+ std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+ auto type = list(list(int32()));
+ this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(type, list_json, "[0, 1, 1, null]", R"([
+ [[1], [2, null, 2], []],
+ null,
+ null
+ ])");
+ this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
+ this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([
+ [],
+ [[3, null], null]
+ ])");
+ this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
+ this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([
+ [[1], [2, null, 2], []],
+ [[3, null], null]
+ ])");
+}
+
class TestFilterKernelWithFixedSizeList : public TestFilterKernel<FixedSizeListType> {};
TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) {
diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc
index 654ec61..8a07663 100644
--- a/cpp/src/arrow/compute/kernels/filter.cc
+++ b/cpp/src/arrow/compute/kernels/filter.cc
@@ -15,19 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-#include <algorithm>
+#include "arrow/compute/kernels/filter.h"
+
+#include <limits>
#include <memory>
#include <utility>
-#include <vector>
#include "arrow/builder.h"
#include "arrow/compute/context.h"
-#include "arrow/compute/kernels/filter.h"
-#include "arrow/util/bit-util.h"
+#include "arrow/compute/kernels/take-internal.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
-#include "arrow/util/stl.h"
-#include "arrow/visitor_inline.h"
namespace arrow {
namespace compute {
@@ -35,32 +33,36 @@ namespace compute {
using internal::checked_cast;
using internal::checked_pointer_cast;
-template <typename Builder>
-Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
- std::unique_ptr<Builder>* out) {
- std::unique_ptr<ArrayBuilder> builder;
- RETURN_NOT_OK(MakeBuilder(pool, type, &builder));
- out->reset(checked_cast<Builder*>(builder.release()));
- return Status::OK();
-}
+// IndexSequence which yields the indices of positions in a BooleanArray
+// which are either null or true
+class FilterIndexSequence {
+ public:
+ // constexpr so we'll never instantiate bounds checking
+ constexpr bool never_out_of_bounds() const { return true; }
+ void set_never_out_of_bounds() {}
-template <typename Builder, typename Scalar>
-static Status UnsafeAppend(Builder* builder, Scalar&& value) {
- builder->UnsafeAppend(std::forward<Scalar>(value));
- return Status::OK();
-}
+ constexpr FilterIndexSequence() = default;
-static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) {
- RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
- builder->UnsafeAppend(value);
- return Status::OK();
-}
+ FilterIndexSequence(const BooleanArray& filter, int64_t out_length)
+ : filter_(&filter), out_length_(out_length) {}
-static Status UnsafeAppend(StringBuilder* builder, util::string_view value) {
- RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
- builder->UnsafeAppend(value);
- return Status::OK();
-}
+ std::pair<int64_t, bool> Next() {
+ // skip until an index is found at which the filter is either null or true
+ while (filter_->IsValid(index_) && !filter_->Value(index_)) {
+ ++index_;
+ }
+ bool is_valid = filter_->IsValid(index_);
+ return std::make_pair(index_++, is_valid);
+ }
+
+ int64_t length() const { return out_length_; }
+
+ int64_t null_count() const { return filter_->null_count(); }
+
+ private:
+ const BooleanArray* filter_ = nullptr;
+ int64_t index_ = 0, out_length_ = -1;
+};
// TODO(bkietz) this can be optimized
static int64_t OutputSize(const BooleanArray& filter) {
@@ -75,358 +77,32 @@ static int64_t OutputSize(const BooleanArray& filter) {
return size;
}
-template <typename ValueType>
-class FilterImpl;
-
-template <>
-class FilterImpl<NullType> : public FilterKernel {
- public:
- using FilterKernel::FilterKernel;
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- out->reset(new NullArray(length));
- return Status::OK();
- }
-};
-
-template <typename ValueType>
-class FilterImpl : public FilterKernel {
- public:
- using ValueArray = typename TypeTraits<ValueType>::ArrayType;
- using OutBuilder = typename TypeTraits<ValueType>::BuilderType;
-
- using FilterKernel::FilterKernel;
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- std::unique_ptr<OutBuilder> builder;
- RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type_, &builder));
- RETURN_NOT_OK(builder->Resize(OutputSize(filter)));
- RETURN_NOT_OK(UnpackValuesNullCount(checked_cast<const ValueArray&>(values), filter,
- builder.get()));
- return builder->Finish(out);
- }
-
- private:
- Status UnpackValuesNullCount(const ValueArray& values, const BooleanArray& filter,
- OutBuilder* builder) {
- if (values.null_count() == 0) {
- return UnpackIndicesNullCount<true>(values, filter, builder);
- }
- return UnpackIndicesNullCount<false>(values, filter, builder);
- }
-
- template <bool AllValuesValid>
- Status UnpackIndicesNullCount(const ValueArray& values, const BooleanArray& filter,
- OutBuilder* builder) {
- if (filter.null_count() == 0) {
- return Filter<AllValuesValid, true>(values, filter, builder);
- }
- return Filter<AllValuesValid, false>(values, filter, builder);
- }
-
- template <bool AllValuesValid, bool AllIndicesValid>
- Status Filter(const ValueArray& values, const BooleanArray& filter,
- OutBuilder* builder) {
- for (int64_t i = 0; i < filter.length(); ++i) {
- if (!AllIndicesValid && filter.IsNull(i)) {
- builder->UnsafeAppendNull();
- continue;
- }
- if (!filter.Value(i)) {
- continue;
- }
- if (!AllValuesValid && values.IsNull(i)) {
- builder->UnsafeAppendNull();
- continue;
- }
- RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(i)));
- }
- return Status::OK();
- }
-};
-
-template <>
-class FilterImpl<StructType> : public FilterKernel {
+class FilterKernelImpl : public FilterKernel {
public:
- FilterImpl(const std::shared_ptr<DataType>& type,
- std::vector<std::unique_ptr<FilterKernel>> child_kernels)
- : FilterKernel(type), child_kernels_(std::move(child_kernels)) {}
+ FilterKernelImpl(const std::shared_ptr<DataType>& type,
+ std::unique_ptr<Taker<FilterIndexSequence>> taker)
+ : FilterKernel(type), taker_(std::move(taker)) {}
Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
int64_t length, std::shared_ptr<Array>* out) override {
- const auto& struct_array = checked_cast<const StructArray&>(values);
-
- TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool());
- RETURN_NOT_OK(null_bitmap_builder.Resize(length));
-
- ArrayVector fields(type_->num_children());
- for (int i = 0; i < type_->num_children(); ++i) {
- RETURN_NOT_OK(child_kernels_[i]->Filter(ctx, *struct_array.field(i), filter, length,
- &fields[i]));
- }
-
- for (int64_t i = 0; i < filter.length(); ++i) {
- if (filter.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- continue;
- }
- if (!filter.Value(i)) {
- continue;
- }
- if (struct_array.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- continue;
- }
- null_bitmap_builder.UnsafeAppend(true);
+ if (values.length() != filter.length()) {
+ return Status::Invalid("filter and value array must have identical lengths");
}
-
- auto null_count = null_bitmap_builder.false_count();
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap));
-
- out->reset(new StructArray(type_, length, fields, null_bitmap, null_count));
- return Status::OK();
+ RETURN_NOT_OK(taker_->Init(ctx->memory_pool()));
+ RETURN_NOT_OK(taker_->Take(values, FilterIndexSequence(filter, length)));
+ return taker_->Finish(out);
}
- private:
- std::vector<std::unique_ptr<FilterKernel>> child_kernels_;
-};
-
-template <>
-class FilterImpl<FixedSizeListType> : public FilterKernel {
- public:
- using FilterKernel::FilterKernel;
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- const auto& list_array = checked_cast<const FixedSizeListArray&>(values);
-
- TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool());
- RETURN_NOT_OK(null_bitmap_builder.Resize(length));
-
- BooleanBuilder value_filter_builder(ctx->memory_pool());
- auto list_size = list_array.list_type()->list_size();
- RETURN_NOT_OK(value_filter_builder.Resize(list_size * length));
-
- for (int64_t i = 0; i < filter.length(); ++i) {
- if (filter.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- for (int64_t j = 0; j < list_size; ++j) {
- value_filter_builder.UnsafeAppendNull();
- }
- continue;
- }
- if (!filter.Value(i)) {
- for (int64_t j = 0; j < list_size; ++j) {
- value_filter_builder.UnsafeAppend(false);
- }
- continue;
- }
- if (values.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- for (int64_t j = 0; j < list_size; ++j) {
- value_filter_builder.UnsafeAppendNull();
- }
- continue;
- }
- for (int64_t j = 0; j < list_size; ++j) {
- value_filter_builder.UnsafeAppend(true);
- }
- null_bitmap_builder.UnsafeAppend(true);
- }
-
- std::shared_ptr<BooleanArray> value_filter;
- RETURN_NOT_OK(value_filter_builder.Finish(&value_filter));
- std::shared_ptr<Array> out_values;
- RETURN_NOT_OK(
- arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values));
-
- auto null_count = null_bitmap_builder.false_count();
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap));
-
- out->reset(
- new FixedSizeListArray(type_, length, out_values, null_bitmap, null_count));
- return Status::OK();
- }
-};
-
-template <>
-class FilterImpl<ListType> : public FilterKernel {
- public:
- using FilterKernel::FilterKernel;
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- const auto& list_array = checked_cast<const ListArray&>(values);
-
- TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool());
- RETURN_NOT_OK(null_bitmap_builder.Resize(length));
-
- BooleanBuilder value_filter_builder(ctx->memory_pool());
-
- TypedBufferBuilder<int32_t> offset_builder(ctx->memory_pool());
- RETURN_NOT_OK(offset_builder.Resize(length + 1));
- int32_t offset = 0;
- offset_builder.UnsafeAppend(offset);
-
- for (int64_t i = 0; i < filter.length(); ++i) {
- if (filter.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- offset_builder.UnsafeAppend(offset);
- RETURN_NOT_OK(
- value_filter_builder.AppendValues(list_array.value_length(i), false));
- continue;
- }
- if (!filter.Value(i)) {
- RETURN_NOT_OK(
- value_filter_builder.AppendValues(list_array.value_length(i), false));
- continue;
- }
- if (values.IsNull(i)) {
- null_bitmap_builder.UnsafeAppend(false);
- offset_builder.UnsafeAppend(offset);
- RETURN_NOT_OK(
- value_filter_builder.AppendValues(list_array.value_length(i), false));
- continue;
- }
- null_bitmap_builder.UnsafeAppend(true);
- offset += list_array.value_length(i);
- offset_builder.UnsafeAppend(offset);
- RETURN_NOT_OK(value_filter_builder.AppendValues(list_array.value_length(i), true));
- }
-
- std::shared_ptr<BooleanArray> value_filter;
- RETURN_NOT_OK(value_filter_builder.Finish(&value_filter));
- std::shared_ptr<Array> out_values;
- RETURN_NOT_OK(
- arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values));
-
- auto null_count = null_bitmap_builder.false_count();
- std::shared_ptr<Buffer> offsets, null_bitmap;
- RETURN_NOT_OK(offset_builder.Finish(&offsets));
- RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap));
-
- *out = MakeArray(ArrayData::Make(type_, length, {null_bitmap, offsets},
- {out_values->data()}, null_count));
- return Status::OK();
- }
-};
-
-template <>
-class FilterImpl<MapType> : public FilterImpl<ListType> {
- using FilterImpl<ListType>::FilterImpl;
-};
-
-template <>
-class FilterImpl<DictionaryType> : public FilterKernel {
- public:
- FilterImpl(const std::shared_ptr<DataType>& type, std::unique_ptr<FilterKernel> impl)
- : FilterKernel(type), impl_(std::move(impl)) {}
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- auto dict_array = checked_cast<const DictionaryArray*>(&values);
- // To filter a dictionary, apply the current kernel to the dictionary's indices.
- std::shared_ptr<Array> taken_indices;
- RETURN_NOT_OK(
- impl_->Filter(ctx, *dict_array->indices(), filter, length, &taken_indices));
- return DictionaryArray::FromArrays(values.type(), taken_indices,
- dict_array->dictionary(), out);
- }
-
- private:
- std::unique_ptr<FilterKernel> impl_;
-};
-
-template <>
-class FilterImpl<ExtensionType> : public FilterKernel {
- public:
- FilterImpl(const std::shared_ptr<DataType>& type, std::unique_ptr<FilterKernel> impl)
- : FilterKernel(type), impl_(std::move(impl)) {}
-
- Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter,
- int64_t length, std::shared_ptr<Array>* out) override {
- auto ext_array = checked_cast<const ExtensionArray*>(&values);
- // To take from an extension array, apply the current kernel to storage.
- std::shared_ptr<Array> taken_storage;
- RETURN_NOT_OK(
- impl_->Filter(ctx, *ext_array->storage(), filter, length, &taken_storage));
- *out = ext_array->extension_type()->MakeArray(taken_storage->data());
- return Status::OK();
- }
-
- private:
- std::unique_ptr<FilterKernel> impl_;
+ std::unique_ptr<Taker<FilterIndexSequence>> taker_;
};
Status FilterKernel::Make(const std::shared_ptr<DataType>& value_type,
std::unique_ptr<FilterKernel>* out) {
- switch (value_type->id()) {
-#define NO_CHILD_CASE(T) \
- case T##Type::type_id: \
- *out = internal::make_unique<FilterImpl<T##Type>>(value_type); \
- return Status::OK()
-
-#define SINGLE_CHILD_CASE(T, CHILD_TYPE) \
- case T##Type::type_id: { \
- auto t = checked_pointer_cast<T##Type>(value_type); \
- std::unique_ptr<FilterKernel> child_filter_impl; \
- RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \
- *out = internal::make_unique<FilterImpl<T##Type>>(t, std::move(child_filter_impl)); \
- return Status::OK(); \
- }
-
- NO_CHILD_CASE(Null);
- NO_CHILD_CASE(Boolean);
- NO_CHILD_CASE(Int8);
- NO_CHILD_CASE(Int16);
- NO_CHILD_CASE(Int32);
- NO_CHILD_CASE(Int64);
- NO_CHILD_CASE(UInt8);
- NO_CHILD_CASE(UInt16);
- NO_CHILD_CASE(UInt32);
- NO_CHILD_CASE(UInt64);
- NO_CHILD_CASE(Date32);
- NO_CHILD_CASE(Date64);
- NO_CHILD_CASE(Time32);
- NO_CHILD_CASE(Time64);
- NO_CHILD_CASE(Timestamp);
- NO_CHILD_CASE(Duration);
- NO_CHILD_CASE(HalfFloat);
- NO_CHILD_CASE(Float);
- NO_CHILD_CASE(Double);
- NO_CHILD_CASE(String);
- NO_CHILD_CASE(Binary);
- NO_CHILD_CASE(FixedSizeBinary);
- NO_CHILD_CASE(Decimal128);
-
- SINGLE_CHILD_CASE(Dictionary, index_type);
- SINGLE_CHILD_CASE(Extension, storage_type);
+ std::unique_ptr<Taker<FilterIndexSequence>> taker;
+ RETURN_NOT_OK(Taker<FilterIndexSequence>::Make(value_type, &taker));
- NO_CHILD_CASE(List);
- NO_CHILD_CASE(FixedSizeList);
- NO_CHILD_CASE(Map);
-
- case Type::STRUCT: {
- std::vector<std::unique_ptr<FilterKernel>> child_kernels;
- for (auto child : value_type->children()) {
- child_kernels.emplace_back();
- RETURN_NOT_OK(FilterKernel::Make(child->type(), &child_kernels.back()));
- }
- *out = internal::make_unique<FilterImpl<StructType>>(value_type,
- std::move(child_kernels));
- return Status::OK();
- }
-
-#undef NO_CHILD_CASE
-#undef SINGLE_CHILD_CASE
-
- default:
- return Status::NotImplemented("gathering values of type ", *value_type);
- }
+ out->reset(new FilterKernelImpl(value_type, std::move(taker)));
+ return Status::OK();
}
Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& filter,
@@ -436,26 +112,26 @@ Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum
}
auto values_array = values.make_array();
auto filter_array = checked_pointer_cast<BooleanArray>(filter.make_array());
- const auto length = OutputSize(*filter_array);
std::shared_ptr<Array> out_array;
- RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, length, &out_array));
+ RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, OutputSize(*filter_array),
+ &out_array));
*out = out_array;
return Status::OK();
}
+// Array overload: wraps the ArrayData of values and filter in Datums, runs the
+// Datum overload of Filter, and converts the resulting Datum back into an Array.
-Status Filter(FunctionContext* context, const Array& values, const Array& filter,
+Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
std::shared_ptr<Array>* out) {
Datum out_datum;
- RETURN_NOT_OK(Filter(context, Datum(values.data()), Datum(filter.data()), &out_datum));
+ RETURN_NOT_OK(Filter(ctx, Datum(values.data()), Datum(filter.data()), &out_datum));
*out = out_datum.make_array();
return Status::OK();
}
+// Datum overload: builds a FilterKernel for values.type() on every call and
+// delegates the actual filtering to FilterKernel::Call.
-Status Filter(FunctionContext* context, const Datum& values, const Datum& filter,
+Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter,
Datum* out) {
std::unique_ptr<FilterKernel> kernel;
RETURN_NOT_OK(FilterKernel::Make(values.type(), &kernel));
- return kernel->Call(context, values, filter, out);
+ return kernel->Call(ctx, values, filter, out);
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h
index 46ad3d4..401daa8 100644
--- a/cpp/src/arrow/compute/kernels/filter.h
+++ b/cpp/src/arrow/compute/kernels/filter.h
@@ -41,23 +41,22 @@ class FunctionContext;
/// filter = [0, 1, 1, 0, null, 1], the output will be
/// = ["b", "c", null, "f"]
///
-/// \param[in] context the FunctionContext
+/// \param[in] ctx the FunctionContext
/// \param[in] values array to filter
/// \param[in] filter selects values to keep: true keeps a value, false drops it, and null emits a null
/// \param[out] out resulting array
ARROW_EXPORT
-Status Filter(FunctionContext* context, const Array& values, const Array& filter,
+Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
std::shared_ptr<Array>* out);
/// \brief Filter an array with a boolean selection filter
///
-/// \param[in] context the FunctionContext
+/// \param[in] ctx the FunctionContext
/// \param[in] values datum to filter
/// \param[in] filter selects values to keep: true keeps a value, false drops it, and null emits a null
/// \param[out] out resulting datum
ARROW_EXPORT
-Status Filter(FunctionContext* context, const Datum& values, const Datum& filter,
- Datum* out);
+Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out);
/// \brief BinaryKernel implementing Filter operation
class ARROW_EXPORT FilterKernel : public BinaryKernel {
diff --git a/cpp/src/arrow/compute/kernels/take-benchmark.cc b/cpp/src/arrow/compute/kernels/take-benchmark.cc
new file mode 100644
index 0000000..139e183
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/take-benchmark.cc
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/compute/kernels/take.h"
+
+#include "arrow/compute/benchmark-util.h"
+#include "arrow/compute/test-util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+namespace arrow {
+namespace compute {
+
+constexpr auto kSeed = 0x0ff1ce;
+
+// Times Take() with default TakeOptions, gathering `values` at `indices` once per
+// benchmark iteration. Any failing Take() aborts the benchmark.
+static void TakeBenchmark(benchmark::State& state, const std::shared_ptr<Array>& values,
+                          const std::shared_ptr<Array>& indices) {
+  FunctionContext ctx;
+  TakeOptions options;
+  for (auto _ : state) {
+    // Datums are rebuilt inside the loop, so their construction stays in the timing
+    const Datum values_datum(values);
+    const Datum indices_datum(indices);
+    Datum taken;
+    ABORT_NOT_OK(Take(&ctx, values_datum, indices_datum, options, &taken));
+    benchmark::DoNotOptimize(taken);
+  }
+}
+
+// Gathers from a random Int64 array (values in [-100, 100]) using an equally long
+// random Int32 index array; both are generated with args.null_proportion.
+static void TakeInt64(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int64_t length = args.size / sizeof(int64_t);
+  auto generator = random::RandomArrayGenerator(kSeed);
+
+  auto values = generator.Int64(length, -100, 100, args.null_proportion);
+  auto indices = generator.Int32(length, 0, length - 1, args.null_proportion);
+
+  TakeBenchmark(state, values, indices);
+}
+
+// Gathers from a fixed_size_list(int64(), 1) array wrapping a random Int64 array.
+// The list array reuses the child's validity bitmap, so a list is null exactly
+// when its single child value is null.
+static void TakeFixedSizeList1Int64(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int64_t length = args.size / sizeof(int64_t);
+  auto generator = random::RandomArrayGenerator(kSeed);
+
+  auto child = generator.Int64(length, -100, 100, args.null_proportion);
+  auto list_type = fixed_size_list(int64(), 1);
+  auto values = std::make_shared<FixedSizeListArray>(
+      list_type, length, child, child->null_bitmap(), child->null_count());
+
+  auto indices = generator.Int32(length, 0, length - 1, args.null_proportion);
+
+  TakeBenchmark(state, values, indices);
+}
+
+// Runs Take() with exactly the selection a Filter() call would make, so the two
+// kernels can be compared: a true filter slot contributes its position as an
+// index, a null slot contributes a null index, and a false slot is skipped.
+static void TakeInt64VsFilter(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int64_t length = args.size / sizeof(int64_t);
+  auto generator = random::RandomArrayGenerator(kSeed);
+
+  auto values = generator.Int64(length, -100, 100, args.null_proportion);
+  auto selection = std::static_pointer_cast<BooleanArray>(
+      generator.Boolean(length, 0.75, args.null_proportion));
+
+  // at most one index per filter slot, so `length` slots are always enough
+  Int32Builder builder;
+  ABORT_NOT_OK(builder.Resize(length));
+  for (int64_t slot = 0; slot < length; ++slot) {
+    if (selection->IsNull(slot)) {
+      builder.UnsafeAppendNull();
+      continue;
+    }
+    if (selection->Value(slot)) {
+      builder.UnsafeAppend(static_cast<int32_t>(slot));
+    }
+  }
+
+  std::shared_ptr<Array> indices;
+  ABORT_NOT_OK(builder.Finish(&indices));
+
+  TakeBenchmark(state, values, indices);
+}
+
+// Gathers from random strings of 0 to 128 characters. The array length is scaled
+// so the non-null strings hold roughly args.size characters in total.
+static void TakeString(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int32_t min_length = 0;
+  const int32_t max_length = 128;
+  const int32_t mean_length = (max_length + min_length) / 2;
+  // null strings hold no characters, so generate 1 / (1 - null_proportion) times
+  // as many strings (e.g. twice as many at 50% nulls)
+  const auto length =
+      static_cast<int64_t>(args.size / mean_length / (1 - args.null_proportion));
+
+  auto generator = random::RandomArrayGenerator(kSeed);
+  auto values = std::static_pointer_cast<StringArray>(
+      generator.String(length, min_length, max_length, args.null_proportion));
+  auto indices = generator.Int32(length, 0, length - 1, args.null_proportion);
+
+  TakeBenchmark(state, values, indices);
+}
+
+// Configuration shared by every Take benchmark: the regression argument grid,
+// two additional (size, null parameter) pairs, a 1 second minimum run time and
+// nanosecond reporting.
+static void TakeSetArgs(benchmark::internal::Benchmark* bench) {
+  bench->Apply(RegressionSetArgs)
+      ->Args({1 << 20, 1})
+      ->Args({1 << 23, 1})
+      ->MinTime(1.0)
+      ->Unit(benchmark::TimeUnit::kNanosecond);
+}
+
+BENCHMARK(TakeInt64)->Apply(TakeSetArgs);
+BENCHMARK(TakeFixedSizeList1Int64)->Apply(TakeSetArgs);
+BENCHMARK(TakeInt64VsFilter)->Apply(TakeSetArgs);
+BENCHMARK(TakeString)->Apply(TakeSetArgs);
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/take-internal.h b/cpp/src/arrow/compute/kernels/take-internal.h
new file mode 100644
index 0000000..bacd71b
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/take-internal.h
@@ -0,0 +1,553 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/builder.h"
+#include "arrow/compute/context.h"
+#include "arrow/util/bit-util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/stl.h"
+#include "arrow/visitor_inline.h"
+
+namespace arrow {
+namespace compute {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+// Append `value` to `builder` without a slot-capacity check; callers must already
+// have reserved a slot. Generic case: the builder's own UnsafeAppend suffices.
+// (Functions defined in this header are `inline`/templates rather than `static`,
+// so every translation unit shares one definition instead of a private copy.)
+template <typename Builder, typename Scalar>
+Status UnsafeAppend(Builder* builder, Scalar&& value) {
+  builder->UnsafeAppend(std::forward<Scalar>(value));
+  return Status::OK();
+}
+
+// Use BinaryBuilder::UnsafeAppend, but reserve byte storage for value first.
+inline Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) {
+  RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
+  builder->UnsafeAppend(value);
+  return Status::OK();
+}
+
+// Use StringBuilder::UnsafeAppend, but reserve character storage first. This
+// overload is still needed alongside the BinaryBuilder one: for a StringBuilder*
+// argument the generic template is an exact match and would beat the
+// derived-to-base conversion, silently skipping ReserveData.
+inline Status UnsafeAppend(StringBuilder* builder, util::string_view value) {
+  RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
+  builder->UnsafeAppend(value);
+  return Status::OK();
+}
+
+/// \brief visit indices from an IndexSequence while bounds checking
+///
+/// Every element of `indices` is reported to `vis` exactly once. A null index is
+/// reported as (0, false) without touching `values`; any other index is bounds
+/// checked (unless NeverOutOfBounds) and reported with the validity of the value
+/// it addresses.
+///
+/// \tparam SomeIndicesNull false when `indices` is known to contain no nulls
+/// \tparam SomeValuesNull false when `values` is known to contain no nulls
+/// \tparam NeverOutOfBounds true when bounds checks may be skipped
+/// \param[in] indices IndexSequence to visit (a copy; the caller's is not advanced)
+/// \param[in] values array to bounds check against, if necessary
+/// \param[in] vis index visitor, signature must be Status(int64_t index, bool is_valid)
+/// \return IndexError for the first out-of-bounds index, else the first error from vis
+template <bool SomeIndicesNull, bool SomeValuesNull, bool NeverOutOfBounds,
+          typename IndexSequence, typename Visitor>
+Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
+  int64_t visited = 0;
+  while (visited < indices.length()) {
+    ++visited;
+    const auto next = indices.Next();
+
+    const bool index_is_null = SomeIndicesNull && !next.second;
+    if (index_is_null) {
+      RETURN_NOT_OK(vis(0, false));
+      continue;
+    }
+
+    const auto index = next.first;
+    const bool out_of_bounds =
+        !NeverOutOfBounds && (index < 0 || index >= values.length());
+    if (out_of_bounds) {
+      return Status::IndexError("take index out of bounds");
+    }
+
+    RETURN_NOT_OK(vis(index, !SomeValuesNull || values.IsValid(index)));
+  }
+  return Status::OK();
+}
+
+// Dispatch on the sequence's bounds guarantee: when never_out_of_bounds() holds,
+// the per-index bounds check is compiled out of the visiting loop.
+template <bool SomeIndicesNull, bool SomeValuesNull, typename IndexSequence,
+          typename Visitor>
+Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
+  if (!indices.never_out_of_bounds()) {
+    return VisitIndices<SomeIndicesNull, SomeValuesNull, false>(
+        indices, values, std::forward<Visitor>(vis));
+  }
+  return VisitIndices<SomeIndicesNull, SomeValuesNull, true>(indices, values,
+                                                             std::forward<Visitor>(vis));
+}
+
+// Dispatch on whether `values` has nulls: without nulls every visited value is
+// reported valid and values.IsValid() is never consulted.
+template <bool SomeIndicesNull, typename IndexSequence, typename Visitor>
+Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
+  const bool values_have_nulls = values.null_count() != 0;
+  if (values_have_nulls) {
+    return VisitIndices<SomeIndicesNull, true>(indices, values, std::forward<Visitor>(vis));
+  }
+  return VisitIndices<SomeIndicesNull, false>(indices, values,
+                                              std::forward<Visitor>(vis));
+}
+
+// Entry point: choose SomeIndicesNull from indices.null_count(); the overloads
+// above then resolve the values-null and bounds-check parameters.
+template <typename IndexSequence, typename Visitor>
+Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
+  return indices.null_count() != 0
+             ? VisitIndices<true>(indices, values, std::forward<Visitor>(vis))
+             : VisitIndices<false>(indices, values, std::forward<Visitor>(vis));
+}
+
+// Helper class for gathering values from an array
+//
+// Call sequence, per the member comments below: Make() (which also runs
+// MakeChildren()), then Init(pool), then one or more Take() calls, then Finish().
+// The static_asserts spell out the members every IndexSequence must provide.
+template <typename IndexSequence>
+class Taker {
+ public:
+ explicit Taker(const std::shared_ptr<DataType>& type) : type_(type) {}
+
+ virtual ~Taker() = default;
+
+ // construct any children, must be called once after construction
+ virtual Status MakeChildren() { return Status::OK(); }
+
+ // reset this Taker, prepare to gather into an array allocated from pool
+ // must be called each time the output pool may have changed
+ virtual Status Init(MemoryPool* pool) = 0;
+
+ // gather elements from an array at the provided indices
+ virtual Status Take(const Array& values, IndexSequence indices) = 0;
+
+ // assemble an array of all gathered values
+ virtual Status Finish(std::shared_ptr<Array>*) = 0;
+
+ // factory; the output Taker will support gathering values of the given type
+ static Status Make(const std::shared_ptr<DataType>& type, std::unique_ptr<Taker>* out);
+
+ // NOTE(review): std::is_literal_type is deprecated in C++17 and removed in C++20;
+ // this assertion must be replaced if the project's language standard is raised.
+ static_assert(std::is_literal_type<IndexSequence>::value,
+ "Index sequences must be literal type");
+
+ // sequences are passed and stored by value (see Take())
+ static_assert(std::is_copy_constructible<IndexSequence>::value,
+ "Index sequences must be copy constructible");
+
+ static_assert(std::is_same<decltype(std::declval<IndexSequence>().Next()),
+ std::pair<int64_t, bool>>::value,
+ "An index sequence must yield pairs of indices:int64_t, validity:bool.");
+
+ static_assert(std::is_same<decltype(std::declval<const IndexSequence>().length()),
+ int64_t>::value,
+ "An index sequence must provide its length.");
+
+ static_assert(std::is_same<decltype(std::declval<const IndexSequence>().null_count()),
+ int64_t>::value,
+ "An index sequence must provide the number of nulls it will take.");
+
+ static_assert(
+ std::is_same<decltype(std::declval<const IndexSequence>().never_out_of_bounds()),
+ bool>::value,
+ "Index sequences must declare whether bounds checking is necessary");
+
+ static_assert(
+ std::is_same<decltype(std::declval<IndexSequence>().set_never_out_of_bounds()),
+ void>::value,
+ "An index sequence must support ignoring bounds checking.");
+
+ protected:
+ // Create a builder for type_ allocating from pool and downcast it to Builder.
+ // Builder must be the concrete class arrow::MakeBuilder returns for type_
+ // (checked_cast presumably verifies this only in debug builds).
+ template <typename Builder>
+ Status MakeBuilder(MemoryPool* pool, std::unique_ptr<Builder>* out) {
+ std::unique_ptr<ArrayBuilder> builder;
+ RETURN_NOT_OK(arrow::MakeBuilder(pool, type_, &builder));
+ out->reset(checked_cast<Builder*>(builder.release()));
+ return Status::OK();
+ }
+
+ // the type of the values this Taker gathers
+ std::shared_ptr<DataType> type_;
+};
+
+// An IndexSequence yielding the consecutive indices offset, offset + 1, ...,
+// offset + length - 1, each paired with the same validity flag; when that flag
+// is false every yielded element is null. Indices are assumed in bounds, so
+// never_out_of_bounds() is always true and set_never_out_of_bounds() does nothing.
+class RangeIndexSequence {
+ public:
+  constexpr RangeIndexSequence() = default;
+
+  RangeIndexSequence(bool is_valid, int64_t offset, int64_t length)
+      : is_valid_(is_valid), index_(offset), length_(length) {}
+
+  constexpr bool never_out_of_bounds() const { return true; }
+  void set_never_out_of_bounds() {}
+
+  int64_t length() const { return length_; }
+  int64_t null_count() const { return is_valid_ ? 0 : length_; }
+
+  std::pair<int64_t, bool> Next() {
+    const int64_t current = index_;
+    ++index_;
+    return std::make_pair(current, is_valid_);
+  }
+
+ private:
+  bool is_valid_ = true;
+  int64_t index_ = 0;
+  int64_t length_ = -1;
+};
+
+// Default implementation for "simple" types: the array must expose GetView(i)
+// and the matching builder must accept that view via UnsafeAppend. One builder
+// slot is reserved per index; each visited index appends either a null or the
+// viewed value.
+template <typename IndexSequence, typename T>
+class TakerImpl : public Taker<IndexSequence> {
+ public:
+  using ArrayType = typename TypeTraits<T>::ArrayType;
+  using BuilderType = typename TypeTraits<T>::BuilderType;
+
+  using Taker<IndexSequence>::Taker;
+
+  Status Init(MemoryPool* pool) override { return this->MakeBuilder(pool, &builder_); }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+    const auto& typed_values = checked_cast<const ArrayType&>(values);
+
+    RETURN_NOT_OK(builder_->Reserve(indices.length()));
+    auto append_one = [&](int64_t index, bool is_valid) -> Status {
+      if (is_valid) {
+        auto value = typed_values.GetView(index);
+        return UnsafeAppend(builder_.get(), value);
+      }
+      builder_->UnsafeAppendNull();
+      return Status::OK();
+    };
+    return VisitIndices(indices, values, append_one);
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override { return builder_->Finish(out); }
+
+ private:
+  std::unique_ptr<BuilderType> builder_;
+};
+
+// Gathering from NullArrays is trivial: no builder is needed, only a running
+// count of requested elements. Indices are visited solely for bounds checking,
+// and only when the sequence does not already guarantee they are in bounds.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, NullType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  // Init() must reset this Taker (see Taker::Init), so discard the length
+  // accumulated by any previous Take()/Finish() round; otherwise a reused taker
+  // would produce ever-longer NullArrays.
+  Status Init(MemoryPool*) override {
+    length_ = 0;
+    return Status::OK();
+  }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+
+    length_ += indices.length();
+
+    if (indices.never_out_of_bounds()) {
+      return Status::OK();
+    }
+
+    return VisitIndices(indices, values, [](int64_t, bool) { return Status::OK(); });
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    out->reset(new NullArray(length_));
+    return Status::OK();
+  }
+
+ private:
+  // total number of elements requested by Take() since the last Init()
+  int64_t length_ = 0;
+};
+
+// Gathers variable-size lists (and, through the MapType subclass, maps). For each
+// visited index it appends the list's validity bit and new end offset, and
+// forwards the list's child range to value_taker_. Null lists contribute no
+// child values, so their end offset repeats the previous one.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, ListType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  Status MakeChildren() override {
+    const auto& list_type = checked_cast<const ListType&>(*this->type_);
+    return Taker<RangeIndexSequence>::Make(list_type.value_type(), &value_taker_);
+  }
+
+  Status Init(MemoryPool* pool) override {
+    null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool));
+    offset_builder_.reset(new TypedBufferBuilder<int32_t>(pool));
+    // the offsets buffer always starts with a leading 0
+    RETURN_NOT_OK(offset_builder_->Append(0));
+    return value_taker_->Init(pool);
+  }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+
+    const auto& list_array = checked_cast<const ListArray&>(values);
+    const auto list_values = list_array.values();
+
+    RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
+    RETURN_NOT_OK(offset_builder_->Reserve(indices.length()));
+
+    // continue from the last offset appended by Init() or an earlier Take()
+    int32_t offset = offset_builder_->data()[offset_builder_->length() - 1];
+    return VisitIndices(indices, values, [&](int64_t index, bool is_valid) -> Status {
+      null_bitmap_builder_->UnsafeAppend(is_valid);
+
+      if (is_valid) {
+        const int32_t list_length = list_array.value_length(index);
+        // Offsets are int32_t: refuse to gather more child values than they can
+        // address instead of silently overflowing the running offset.
+        if (list_length > std::numeric_limits<int32_t>::max() - offset) {
+          return Status::CapacityError(
+              "taken list values would exceed the int32_t offset limit of ",
+              std::numeric_limits<int32_t>::max());
+        }
+        offset += list_length;
+        RangeIndexSequence value_indices(true, list_array.value_offset(index),
+                                         list_length);
+        RETURN_NOT_OK(value_taker_->Take(*list_values, value_indices));
+      }
+
+      offset_builder_->UnsafeAppend(offset);
+      return Status::OK();
+    });
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override { return FinishAs<ListArray>(out); }
+
+ protected:
+  // this added method is provided for use by TakerImpl<IndexSequence, MapType>,
+  // which needs to construct a MapArray rather than a ListArray
+  template <typename T>
+  Status FinishAs(std::shared_ptr<Array>* out) {
+    // read counts before Finish() hands the bitmap buffer off
+    auto null_count = null_bitmap_builder_->false_count();
+    auto length = null_bitmap_builder_->length();
+
+    std::shared_ptr<Buffer> offsets, null_bitmap;
+    RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
+    RETURN_NOT_OK(offset_builder_->Finish(&offsets));
+
+    std::shared_ptr<Array> taken_values;
+    RETURN_NOT_OK(value_taker_->Finish(&taken_values));
+
+    out->reset(
+        new T(this->type_, length, offsets, taken_values, null_bitmap, null_count));
+    return Status::OK();
+  }
+
+  std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
+  std::unique_ptr<TypedBufferBuilder<int32_t>> offset_builder_;
+  std::unique_ptr<Taker<RangeIndexSequence>> value_taker_;
+};
+
+// Maps are gathered exactly like lists (this class inherits all of the list
+// taker's gathering); only the array class produced by Finish() differs.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, MapType> : public TakerImpl<IndexSequence, ListType> {
+  using ListTaker = TakerImpl<IndexSequence, ListType>;
+
+ public:
+  using TakerImpl<IndexSequence, ListType>::TakerImpl;
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    return ListTaker::template FinishAs<MapArray>(out);
+  }
+};
+
+// Gathers fixed-size lists. A null fixed-size list still spans list_size slots of
+// child data, so every visited index, null or not, forwards a run of list_size
+// child indices to value_taker_, flagged null when the list itself is null.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, FixedSizeListType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  Status MakeChildren() override {
+    auto value_type = checked_cast<const FixedSizeListType&>(*this->type_).value_type();
+    return Taker<RangeIndexSequence>::Make(value_type, &value_taker_);
+  }
+
+  Status Init(MemoryPool* pool) override {
+    null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool));
+    return value_taker_->Init(pool);
+  }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+
+    const auto& lists = checked_cast<const FixedSizeListArray&>(values);
+    const auto list_size = lists.list_type()->list_size();
+    const auto child_values = lists.values();
+
+    RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
+    return VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
+      null_bitmap_builder_->UnsafeAppend(is_valid);
+      const RangeIndexSequence child_indices(is_valid, lists.value_offset(index),
+                                             list_size);
+      return value_taker_->Take(*child_values, child_indices);
+    });
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    // read counts before Finish() hands the bitmap buffer off
+    const auto null_count = null_bitmap_builder_->false_count();
+    const auto length = null_bitmap_builder_->length();
+
+    std::shared_ptr<Buffer> null_bitmap;
+    RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
+
+    std::shared_ptr<Array> taken_values;
+    RETURN_NOT_OK(value_taker_->Finish(&taken_values));
+
+    out->reset(new FixedSizeListArray(this->type_, length, taken_values, null_bitmap,
+                                      null_count));
+    return Status::OK();
+  }
+
+ protected:
+  std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
+  std::unique_ptr<Taker<RangeIndexSequence>> value_taker_;
+};
+
+// Gathers structs. The struct-level validity bitmap is built by visiting the
+// indices (which also bounds-checks them); the same index sequence is then
+// replayed, without bounds checks, against every child field.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, StructType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  Status MakeChildren() override {
+    const int num_fields = this->type_->num_children();
+    children_.resize(num_fields);
+    for (int field_index = 0; field_index < num_fields; ++field_index) {
+      auto field_type = this->type_->child(field_index)->type();
+      RETURN_NOT_OK(Taker<IndexSequence>::Make(field_type, &children_[field_index]));
+    }
+    return Status::OK();
+  }
+
+  Status Init(MemoryPool* pool) override {
+    null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool));
+    for (auto& child : children_) {
+      RETURN_NOT_OK(child->Init(pool));
+    }
+    return Status::OK();
+  }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+
+    RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
+    auto append_validity = [&](int64_t, bool is_valid) {
+      null_bitmap_builder_->UnsafeAppend(is_valid);
+      return Status::OK();
+    };
+    RETURN_NOT_OK(VisitIndices(indices, values, append_validity));
+
+    // bounds checking was done while appending to the null bitmap
+    indices.set_never_out_of_bounds();
+
+    const auto& struct_array = checked_cast<const StructArray&>(values);
+    const int num_fields = this->type_->num_children();
+    for (int field_index = 0; field_index < num_fields; ++field_index) {
+      RETURN_NOT_OK(
+          children_[field_index]->Take(*struct_array.field(field_index), indices));
+    }
+    return Status::OK();
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    // read counts before Finish() hands the bitmap buffer off
+    const auto null_count = null_bitmap_builder_->false_count();
+    const auto length = null_bitmap_builder_->length();
+    std::shared_ptr<Buffer> null_bitmap;
+    RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
+
+    const int num_fields = this->type_->num_children();
+    ArrayVector fields(num_fields);
+    for (int field_index = 0; field_index < num_fields; ++field_index) {
+      RETURN_NOT_OK(children_[field_index]->Finish(&fields[field_index]));
+    }
+
+    out->reset(
+        new StructArray(this->type_, length, std::move(fields), null_bitmap, null_count));
+    return Status::OK();
+  }
+
+ protected:
+  std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
+  std::vector<std::unique_ptr<Taker<IndexSequence>>> children_;
+};
+
+// DictionaryArrays are gathered by gathering their indices; the dictionary is
+// carried through unchanged. Between Init() and Finish(), every array passed to
+// Take() must hold the same dictionary object (compared by pointer identity).
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, DictionaryType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  Status MakeChildren() override {
+    auto index_type = checked_cast<const DictionaryType&>(*this->type_).index_type();
+    return Taker<IndexSequence>::Make(index_type, &index_taker_);
+  }
+
+  Status Init(MemoryPool* pool) override {
+    // forget the dictionary seen in any previous round
+    dictionary_.reset();
+    return index_taker_->Init(pool);
+  }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+    const auto& dict_array = checked_cast<const DictionaryArray&>(values);
+
+    auto incoming = dict_array.dictionary();
+    if (dictionary_ == nullptr) {
+      dictionary_ = std::move(incoming);
+    } else if (dictionary_ != incoming) {
+      return Status::NotImplemented(
+          "taking from DictionaryArrays with different dictionaries");
+    }
+    return index_taker_->Take(*dict_array.indices(), indices);
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    std::shared_ptr<Array> taken_indices;
+    RETURN_NOT_OK(index_taker_->Finish(&taken_indices));
+    out->reset(new DictionaryArray(this->type_, taken_indices, dictionary_));
+    return Status::OK();
+  }
+
+ protected:
+  std::shared_ptr<Array> dictionary_;
+  std::unique_ptr<Taker<IndexSequence>> index_taker_;
+};
+
+// ExtensionArrays are gathered through their storage array; Finish() re-wraps
+// the gathered storage as an ExtensionArray of the original extension type.
+template <typename IndexSequence>
+class TakerImpl<IndexSequence, ExtensionType> : public Taker<IndexSequence> {
+ public:
+  using Taker<IndexSequence>::Taker;
+
+  Status MakeChildren() override {
+    auto storage_type = checked_cast<const ExtensionType&>(*this->type_).storage_type();
+    return Taker<IndexSequence>::Make(storage_type, &storage_taker_);
+  }
+
+  Status Init(MemoryPool* pool) override { return storage_taker_->Init(pool); }
+
+  Status Take(const Array& values, IndexSequence indices) override {
+    DCHECK(this->type_->Equals(values.type()));
+    const auto& storage = *checked_cast<const ExtensionArray&>(values).storage();
+    return storage_taker_->Take(storage, indices);
+  }
+
+  Status Finish(std::shared_ptr<Array>* out) override {
+    std::shared_ptr<Array> gathered_storage;
+    RETURN_NOT_OK(storage_taker_->Finish(&gathered_storage));
+    out->reset(new ExtensionArray(this->type_, gathered_storage));
+    return Status::OK();
+  }
+
+ protected:
+  std::unique_ptr<Taker<IndexSequence>> storage_taker_;
+};
+
+// Type visitor used by Taker::Make: instantiates TakerImpl<IndexSequence, T> for
+// the visited type T, stores it in *out_, and lets it build its child takers.
+// Union types are rejected as not implemented.
+template <typename IndexSequence>
+struct TakerMakeImpl {
+  Status Visit(const UnionType& t) {
+    return Status::NotImplemented("gathering values of type ", t);
+  }
+
+  template <typename T>
+  Status Visit(const T&) {
+    auto* taker = new TakerImpl<IndexSequence, T>(type_);
+    out_->reset(taker);
+    return taker->MakeChildren();
+  }
+
+  std::shared_ptr<DataType> type_;
+  std::unique_ptr<Taker<IndexSequence>>* out_;
+};
+
+// Dispatch on the runtime type of `type` to the matching TakerImpl specialization.
+template <typename IndexSequence>
+Status Taker<IndexSequence>::Make(const std::shared_ptr<DataType>& type,
+                                  std::unique_ptr<Taker>* out) {
+  TakerMakeImpl<IndexSequence> make_visitor;
+  make_visitor.type_ = type;
+  make_visitor.out_ = out;
+  return VisitTypeInline(*type, &make_visitor);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc
index c61aeda..da5e0c0 100644
--- a/cpp/src/arrow/compute/kernels/take-test.cc
+++ b/cpp/src/arrow/compute/kernels/take-test.cc
@@ -29,31 +29,40 @@
namespace arrow {
namespace compute {
+using internal::checked_cast;
+using internal::checked_pointer_cast;
using util::string_view;
+constexpr auto kSeed = 0x0ff1ce;
+
template <typename ArrowType>
class TestTakeKernel : public ComputeFixture, public TestBase {
protected:
void AssertTakeArrays(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices, TakeOptions options,
+ const std::shared_ptr<Array>& indices,
const std::shared_ptr<Array>& expected) {
std::shared_ptr<Array> actual;
+ TakeOptions options;
ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &actual));
+ ASSERT_OK(ValidateArray(*actual));
AssertArraysEqual(*expected, *actual);
}
+
void AssertTake(const std::shared_ptr<DataType>& type, const std::string& values,
- const std::string& indices, TakeOptions options,
- const std::string& expected) {
+ const std::string& indices, const std::string& expected) {
std::shared_ptr<Array> actual;
for (auto index_type : {int8(), uint32()}) {
- ASSERT_OK(this->Take(type, values, index_type, indices, options, &actual));
+ ASSERT_OK(this->Take(type, values, index_type, indices, &actual));
+ ASSERT_OK(ValidateArray(*actual));
AssertArraysEqual(*ArrayFromJSON(type, expected), *actual);
}
}
+
Status Take(const std::shared_ptr<DataType>& type, const std::string& values,
const std::shared_ptr<DataType>& index_type, const std::string& indices,
- TakeOptions options, std::shared_ptr<Array>* out) {
+ std::shared_ptr<Array>* out) {
+ TakeOptions options;
return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values),
*ArrayFromJSON(index_type, indices), options, out);
}
@@ -62,82 +71,123 @@ class TestTakeKernel : public ComputeFixture, public TestBase {
+// Fixture for taking from NullArrays: AssertTake() compares Take() against the
+// expected JSON using the null() type.
class TestTakeKernelWithNull : public TestTakeKernel<NullType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
- TakeOptions options, const std::string& expected) {
- TestTakeKernel<NullType>::AssertTake(utf8(), values, indices, options, expected);
+ const std::string& expected) {
+ TestTakeKernel<NullType>::AssertTake(null(), values, indices, expected);
}
};
TEST_F(TestTakeKernelWithNull, TakeNull) {
- TakeOptions options;
- this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]");
+ this->AssertTake("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]",
- options, &arr));
+ ASSERT_RAISES(IndexError,
+ this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr));
+ // negative indices are out of bounds too; exercise the null taker, not boolean
+ ASSERT_RAISES(IndexError,
+ this->Take(null(), "[null, null, null]", int8(), "[0, -1, 0]", &arr));
}
+// Take() rejects non-integer (here float32) indices with TypeError.
TEST_F(TestTakeKernelWithNull, InvalidIndexType) {
- TakeOptions options;
std::shared_ptr<Array> arr;
ASSERT_RAISES(TypeError, this->Take(null(), "[null, null, null]", float32(),
- "[0.0, 1.0, 0.1]", options, &arr));
+ "[0.0, 1.0, 0.1]", &arr));
}
+// Fixture for taking from BooleanArrays: AssertTake() compares Take() against the
+// expected JSON using the boolean() type.
class TestTakeKernelWithBoolean : public TestTakeKernel<BooleanType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
- TakeOptions options, const std::string& expected) {
- TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, options,
- expected);
+ const std::string& expected) {
+ TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, expected);
}
};
TEST_F(TestTakeKernelWithBoolean, TakeBoolean) {
- TakeOptions options;
- this->AssertTake("[true, false, true]", "[0, 1, 0]", options, "[true, false, true]");
- this->AssertTake("[null, false, true]", "[0, 1, 0]", options, "[null, false, null]");
- this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]");
+ // empty indices yield an empty result (values use real booleans, not integers)
+ this->AssertTake("[true, false, true]", "[]", "[]");
+ this->AssertTake("[true, false, true]", "[0, 1, 0]", "[true, false, true]");
+ this->AssertTake("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
+ this->AssertTake("[true, false, true]", "[null, 1, 0]", "[null, false, true]");
std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError, this->Take(boolean(), "[true, false, true]", int8(),
- "[0, 9, 0]", options, &arr));
+ ASSERT_RAISES(IndexError,
+ this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError,
+ this->Take(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr));
}
template <typename ArrowType>
class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
- TakeOptions options, const std::string& expected) {
- TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, options,
- expected);
+ const std::string& expected) {  // JSON-formatted inputs; value type comes from type_singleton()
+ TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, expected);
}
+
std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
}
+
+ void ValidateTake(const std::shared_ptr<Array>& values,  // runs Take() and checks every output slot
+ const std::shared_ptr<Array>& indices_boxed) {  // must be an Int32Array (asserted below)
+ std::shared_ptr<Array> taken;
+ TakeOptions options;
+ ASSERT_OK(
+ arrow::compute::Take(&this->ctx_, *values, *indices_boxed, options, &taken));
+ ASSERT_OK(ValidateArray(*taken));  // structural validity of the output array
+ ASSERT_EQ(indices_boxed->length(), taken->length());
+
+ ASSERT_EQ(indices_boxed->type_id(), Type::INT32);
+ auto indices = checked_pointer_cast<Int32Array>(indices_boxed);
+ for (int64_t i = 0; i < indices->length(); ++i) {
+ if (indices->IsNull(i)) {  // a null index must produce a null output slot
+ ASSERT_TRUE(taken->IsNull(i));
+ continue;
+ }
+ int32_t taken_index = indices->Value(i);
+ ASSERT_TRUE(values->RangeEquals(taken_index, taken_index + 1, i, taken));  // taken[i] == values[indices[i]]
+ }
+ }
};
TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes);
TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
- TakeOptions options;
- this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]");
- this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]");
- this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]");
- this->AssertTake("[null, 8, 9]", "[]", options, "[]");
+ this->AssertTake("[7, 8, 9]", "[]", "[]");  // empty indices -> empty output
+ this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
+ this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");  // nulls in values propagate
+ this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");  // null indices yield nulls
+ this->AssertTake("[null, 8, 9]", "[]", "[]");
+ this->AssertTake("[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]");  // output longer than input
std::shared_ptr<Array> arr;
ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(),
- "[0, 9, 0]", options, &arr));
+ "[0, 9, 0]", &arr));  // index past the end
+ ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(),
+ "[0, -1, 0]", &arr));  // negative index
+}
+
+// Randomized check: for value lengths 8..128 and index lengths 1..4096 (powers of
+// two), take in-bounds int32 indices from random values at several null rates and
+// verify every output slot with ValidateTake().
+TYPED_TEST(TestTakeKernelWithNumeric, TakeRandomNumeric) {
+ auto rand = random::RandomArrayGenerator(kSeed);
+ for (size_t i = 3; i < 8; i++) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ for (size_t j = 0; j < 13; j++) {
+ const int64_t indices_length = static_cast<int64_t>(1ULL << j);
+ for (auto null_probability : {0.0, 0.01, 0.25, 1.0}) {
+ auto values = rand.Numeric<TypeParam>(length, 0, 127, null_probability);
+ // indices are drawn from [0, length - 1], so every Take() is in bounds
+ auto max_index = static_cast<int32_t>(length - 1);
+ auto indices = rand.Int32(indices_length, 0, max_index, null_probability);
+ this->ValidateTake(values, indices);
+ }
+ }
+ }
}
class TestTakeKernelWithString : public TestTakeKernel<StringType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
- TakeOptions options, const std::string& expected) {
- TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, options, expected);
+ const std::string& expected) {
+ TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, expected);  // value type is always utf8()
}
void AssertTakeDictionary(const std::string& dictionary_values,
const std::string& dictionary_indices,
- const std::string& indices, TakeOptions options,
+ const std::string& indices,  // take indices, parsed as int8 below
const std::string& expected_indices) {
auto dict = ArrayFromJSON(utf8(), dictionary_values);
auto type = dictionary(int8(), utf8());
@@ -147,28 +197,272 @@ class TestTakeKernelWithString : public TestTakeKernel<StringType> {
ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices),
dict, &expected));
auto take_indices = ArrayFromJSON(int8(), indices);
- this->AssertTakeArrays(values, take_indices, options, expected);
+ this->AssertTakeArrays(values, take_indices, expected);  // expected is built over the same `dict`
}
};
TEST_F(TestTakeKernelWithString, TakeString) {
- TakeOptions options;
- this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["a", "b", "a"])");
- this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", options, "[null, \"b\", null]");
- this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])");
+ this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
+ this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");  // nulls in values propagate
+ this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])");  // null indices yield nulls
std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]",
- options, &arr));
+ ASSERT_RAISES(IndexError,
+ this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr));
+ ASSERT_RAISES(IndexError, this->Take(utf8(), R"(["a", "b", null, "ddd", "ee"])",
+ int64(), "[2, 5]", &arr));  // 5 == length: one past the last valid index
}
TEST_F(TestTakeKernelWithString, TakeDictionary) {
- TakeOptions options;
auto dict = R"(["a", "b", "c", "d", "e"])";
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[3, 4, 3]");
- this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options,
- "[null, 4, null]");
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]");
+ this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");  // (dictionary, dictionary indices, take indices, expected dictionary indices)
+ this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]");
+ this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]");
+}
+
+// Fixture for taking from list<...> arrays.
+class TestTakeKernelWithList : public TestTakeKernel<ListType> {};
+
+TEST_F(TestTakeKernelWithList, TakeListInt32) {
+ // Values cover an empty list, a non-empty list, a null slot and a singleton list.
+ const auto type = list(int32());
+ const std::string values_json = "[[], [1,2], null, [3]]";
+
+ this->AssertTake(type, values_json, "[]", "[]");
+ this->AssertTake(type, values_json, "[3, 2, 1]", "[[3], null, [1,2]]");
+ this->AssertTake(type, values_json, "[null, 3, 0]", "[null, [3], []]");
+ this->AssertTake(type, values_json, "[null, null]", "[null, null]");
+ this->AssertTake(type, values_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
+ // taking every slot in order reproduces the input
+ this->AssertTake(type, values_json, "[0, 1, 2, 3]", values_json);
+ // the output may be longer than the input
+ this->AssertTake(type, values_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [1, 2]]");
+}
+
+TEST_F(TestTakeKernelWithList, TakeListListInt32) {
+ // list<list<int32>> values: an empty outer list, nested lists containing null
+ // items, a null outer slot, and an outer list holding a null inner list.
+ const auto type = list(list(int32()));
+ const std::string values_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+
+ this->AssertTake(type, values_json, "[]", "[]");
+ this->AssertTake(type, values_json, "[3, 2, 1]", R"([
+ [[3, null], null],
+ null,
+ [[1], [2, null, 2], []]
+ ])");
+ this->AssertTake(type, values_json, "[null, 3, 0]", R"([
+ null,
+ [[3, null], null],
+ []
+ ])");
+ this->AssertTake(type, values_json, "[null, null]", "[null, null]");
+ this->AssertTake(type, values_json, "[3, 0, 0, 3]",
+ "[[[3, null], null], [], [], [[3, null], null]]");
+ // taking every slot in order reproduces the input
+ this->AssertTake(type, values_json, "[0, 1, 2, 3]", values_json);
+ this->AssertTake(type, values_json, "[0, 0, 0, 0, 0, 0, 1]",
+ "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
+}
+
+// Fixture for taking from fixed_size_list<...> arrays.
+class TestTakeKernelWithFixedSizeList : public TestTakeKernel<FixedSizeListType> {};
+
+TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
+ // Every non-null slot holds exactly three int32 items (some of them null).
+ const auto type = fixed_size_list(int32(), 3);
+ const std::string values_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+
+ this->AssertTake(type, values_json, "[]", "[]");
+ this->AssertTake(type, values_json, "[3, 2, 1]",
+ "[[7, 8, null], [4, 5, 6], [1, null, 3]]");
+ this->AssertTake(type, values_json, "[null, 2, 0]", "[null, [4, 5, 6], null]");
+ this->AssertTake(type, values_json, "[null, null]", "[null, null]");
+ this->AssertTake(type, values_json, "[3, 0, 0, 3]",
+ "[[7, 8, null], null, null, [7, 8, null]]");
+ this->AssertTake(type, values_json, "[0, 1, 2, 3]", values_json);
+ this->AssertTake(
+ type, values_json, "[2, 2, 2, 2, 2, 2, 1]",
+ "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]");
+}
+
+// Fixture for taking from map<...> arrays.
+class TestTakeKernelWithMap : public TestTakeKernel<MapType> {};
+
+TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
+ // map<utf8, int32> values: two entries (one with a null item), a null slot,
+ // a single entry, and an empty map.
+ const auto type = map(utf8(), int32());
+ const std::string values_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+
+ this->AssertTake(type, values_json, "[]", "[]");
+ this->AssertTake(type, values_json, "[3, 1, 3, 1, 3]", "[[], null, [], null, []]");
+ this->AssertTake(type, values_json, "[2, 1, null]", R"([
+ [["cap", 8]],
+ null,
+ null
+ ])");
+ this->AssertTake(type, values_json, "[2, 1, 0]", R"([
+ [["cap", 8]],
+ null,
+ [["joe", 0], ["mark", null]]
+ ])");
+ this->AssertTake(type, values_json, "[0, 1, 2, 3]", values_json);
+ this->AssertTake(type, values_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ []
+ ])");
+}
+
+// Fixture for taking from struct arrays.
+class TestTakeKernelWithStruct : public TestTakeKernel<StructType> {};
+
+TEST_F(TestTakeKernelWithStruct, TakeStruct) {
+ // struct<a: int32, b: utf8> values with a leading null slot.
+ const auto type = struct_({field("a", int32()), field("b", utf8())});
+ const std::string values_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+
+ this->AssertTake(type, values_json, "[]", "[]");
+ this->AssertTake(type, values_json, "[3, 1, 3, 1, 3]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"}
+ ])");
+ this->AssertTake(type, values_json, "[3, 1, 0]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ null
+ ])");
+ this->AssertTake(type, values_json, "[0, 1, 2, 3]", values_json);
+ this->AssertTake(type, values_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ null,
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+class TestPermutationsWithTake : public ComputeFixture, public TestBase {  // exercises Take() by composing Int16 permutations
+ protected:
+ void Take(const Int16Array& values, const Int16Array& indices,  // out[i] = values[indices[i]]; also validates the result
+ std::shared_ptr<Int16Array>* out) {
+ TakeOptions options;
+ std::shared_ptr<Array> boxed_out;
+ ASSERT_OK(arrow::compute::Take(&this->ctx_, values, indices, options, &boxed_out));
+ ASSERT_OK(ValidateArray(*boxed_out));
+ *out = checked_pointer_cast<Int16Array>(std::move(boxed_out));
+ }
+
+ std::shared_ptr<Int16Array> Take(const Int16Array& values, const Int16Array& indices) {  // value-returning wrapper over the out-param overload
+ std::shared_ptr<Int16Array> out;
+ Take(values, indices, &out);
+ return out;
+ }
+
+ std::shared_ptr<Int16Array> TakeN(uint64_t n, std::shared_ptr<Int16Array> array) {  // `array` composed with itself n times (identity for n == 0), by repeated squaring
+ auto power_of_2 = array;  // holds array^(2^k) after k halvings of n
+ array = Identity(array->length());  // accumulator, starts at the identity
+ while (n != 0) {
+ if (n & 1) {
+ array = Take(*array, *power_of_2);
+ }
+ power_of_2 = Take(*power_of_2, *power_of_2);
+ n >>= 1;
+ }
+ return array;
+ }
+
+ template <typename Rng>
+ void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>* shuffled) {  // NOTE(review): copies from byte 0 of values() and drops the null bitmap, so assumes a null-free array with zero offset
+ auto byte_length = array.length() * sizeof(int16_t);
+ std::shared_ptr<Buffer> data;
+ ASSERT_OK(array.values()->Copy(0, byte_length, &data));
+ auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data());
+ std::shuffle(mutable_data, mutable_data + array.length(), gen);
+ shuffled->reset(new Int16Array(array.length(), data));
+ }
+
+ template <typename Rng>
+ std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) {
+ std::shared_ptr<Int16Array> out;
+ Shuffle(array, gen, &out);
+ return out;
+ }
+
+ void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) {  // builds [0, 1, ..., length - 1]
+ Int16Builder identity_builder;
+ ASSERT_OK(identity_builder.Resize(length));
+ for (int16_t i = 0; i < length; ++i) {  // int16_t counter: length must not exceed INT16_MAX
+ identity_builder.UnsafeAppend(i);
+ }
+ ASSERT_OK(identity_builder.Finish(identity));
+ }
+
+ std::shared_ptr<Int16Array> Identity(int64_t length) {
+ std::shared_ptr<Int16Array> out;
+ Identity(length, &out);
+ return out;
+ }
+
+ std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>& permutation) {  // returns permutation^(m - 1) for a multiple m of its order, or nullptr if m would overflow uint64_t
+ auto length = static_cast<int16_t>(permutation->length());
+
+ std::vector<bool> cycle_lengths(length + 1, false);  // [k] is true iff permutation^k fixes at least one element
+ auto permutation_to_the_i = permutation;  // permutation^cycle_length at the top of each iteration below
+ for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) {
+ cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i);
+ permutation_to_the_i = Take(*permutation, *permutation_to_the_i);
+ }
+
+ uint64_t cycle_to_identity_length = 1;  // becomes a common multiple (not necessarily the least) of every marked k, hence of every cycle length
+ for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) {
+ if (!cycle_lengths[cycle_length]) {
+ continue;
+ }
+ if (cycle_to_identity_length % cycle_length == 0) {
+ continue;
+ }
+ if (cycle_to_identity_length >
+ std::numeric_limits<uint64_t>::max() / cycle_length) {
+ // overflow, can't compute Inverse
+ return nullptr;
+ }
+ cycle_to_identity_length *= cycle_length;
+ }
+
+ return TakeN(cycle_to_identity_length - 1, permutation);  // permutation^m is the identity, so permutation^(m - 1) is the inverse
+ }
+
+ bool HasTrivialCycle(const Int16Array& permutation) {  // true iff some i satisfies permutation[i] == i
+ for (int64_t i = 0; i < permutation.length(); ++i) {
+ if (permutation.Value(i) == static_cast<int16_t>(i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(TestPermutationsWithTake, InvertPermutation) {  // Take(inverse, p) must be the identity for shuffled permutations of growing length
+ for (int seed : {0, kSeed, kSeed * 2 - 1}) {
+ std::default_random_engine gen(seed);
+ for (int16_t length = 0; length < 1 << 10; ++length) {
+ auto identity = Identity(length);
+ auto permutation = Shuffle(*identity, gen);
+ auto inverse = Inverse(permutation);
+ if (inverse == nullptr) {  // Inverse() gave up (uint64_t overflow): skip the remaining lengths for this seed
+ break;
+ }
+ ASSERT_TRUE(Take(*inverse, *permutation)->Equals(identity));  // out[i] = inverse[permutation[i]]
+ }
+ }
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc
index 17b0540..6ed9111 100644
--- a/cpp/src/arrow/compute/kernels/take.cc
+++ b/cpp/src/arrow/compute/kernels/take.cc
@@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+#include <limits>
#include <memory>
#include <utility>
-#include "arrow/builder.h"
#include "arrow/compute/context.h"
+#include "arrow/compute/kernels/take-internal.h"
#include "arrow/compute/kernels/take.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
@@ -30,200 +31,107 @@ namespace compute {
using internal::checked_cast;
-Status Take(FunctionContext* context, const Array& values, const Array& indices,
- const TakeOptions& options, std::shared_ptr<Array>* out) {
- Datum out_datum;
- RETURN_NOT_OK(
- Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum));
- *out = out_datum.make_array();
- return Status::OK();
-}
-
-Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
- const TakeOptions& options, Datum* out) {
- TakeKernel kernel(values.type(), options);
- RETURN_NOT_OK(kernel.Call(context, values, indices, out));
- return Status::OK();
-}
-
-struct TakeParameters {
- FunctionContext* context;
- std::shared_ptr<Array> values, indices;
- TakeOptions options;
- std::shared_ptr<Array>* out;
-};
-
-template <typename Builder, typename Scalar>
-Status UnsafeAppend(Builder* builder, Scalar&& value) {
- builder->UnsafeAppend(std::forward<Scalar>(value));
- return Status::OK();
-}
-
-Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) {
- RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
- builder->UnsafeAppend(value);
- return Status::OK();
-}
-
-Status UnsafeAppend(StringBuilder* builder, util::string_view value) {
- RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
- builder->UnsafeAppend(value);
- return Status::OK();
-}
-
-template <bool AllValuesValid, bool AllIndicesValid, typename ValueArray,
- typename IndexArray, typename OutBuilder>
-Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices,
- OutBuilder* builder) {
- auto raw_indices = indices.raw_values();
- for (int64_t i = 0; i < indices.length(); ++i) {
- if (!AllIndicesValid && indices.IsNull(i)) {
- builder->UnsafeAppendNull();
- continue;
- }
- auto index = static_cast<int64_t>(raw_indices[i]);
- if (index < 0 || index >= values.length()) {
- return Status::IndexError("take index out of bounds");
- }
- if (!AllValuesValid && values.IsNull(index)) {
- builder->UnsafeAppendNull();
- continue;
- }
- RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(index)));
- }
- return Status::OK();
-}
-
-template <bool AllValuesValid, typename ValueArray, typename IndexArray,
- typename OutBuilder>
-Status UnpackIndicesNullCount(FunctionContext* context, const ValueArray& values,
- const IndexArray& indices, OutBuilder* builder) {
- if (indices.null_count() == 0) {
- return TakeImpl<AllValuesValid, true>(context, values, indices, builder);
- }
- return TakeImpl<AllValuesValid, false>(context, values, indices, builder);
-}
-
-template <typename ValueArray, typename IndexArray, typename OutBuilder>
-Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values,
- const IndexArray& indices, OutBuilder* builder) {
- if (values.null_count() == 0) {
- return UnpackIndicesNullCount<true>(context, values, indices, builder);
- }
- return UnpackIndicesNullCount<false>(context, values, indices, builder);
-}
-
+// an IndexSequence which yields the values of an Array of integers (a null index is yielded as {-1, false})
template <typename IndexType>
-struct UnpackValues {
- using IndexArrayRef = const typename TypeTraits<IndexType>::ArrayType&;
-
- template <typename ValueType>
- Status Visit(const ValueType&) {
- using ValueArrayRef = const typename TypeTraits<ValueType>::ArrayType&;
- using OutBuilder = typename TypeTraits<ValueType>::BuilderType;
- IndexArrayRef indices = checked_cast<IndexArrayRef>(*params_.indices);
- ValueArrayRef values = checked_cast<ValueArrayRef>(*params_.values);
- std::unique_ptr<ArrayBuilder> builder;
- RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder));
- RETURN_NOT_OK(builder->Reserve(indices.length()));
- RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, indices,
- checked_cast<OutBuilder*>(builder.get())));
- return builder->Finish(params_.out);
- }
+class ArrayIndexSequence {
+ public:
+ bool never_out_of_bounds() const { return never_out_of_bounds_; }  // false unless set_never_out_of_bounds() was called
+ void set_never_out_of_bounds() { never_out_of_bounds_ = true; }
- Status Visit(const NullType& t) {
- auto indices_length = params_.indices->length();
- if (indices_length != 0) {
- auto indices = checked_cast<IndexArrayRef>(*params_.indices).raw_values();
- auto minmax = std::minmax_element(indices, indices + indices_length);
- auto min = static_cast<int64_t>(*minmax.first);
- auto max = static_cast<int64_t>(*minmax.second);
- if (min < 0 || max >= params_.values->length()) {
- return Status::IndexError("take index out of bounds");
- }
- }
- params_.out->reset(new NullArray(indices_length));
- return Status::OK();
- }
+ constexpr ArrayIndexSequence() = default;
- Status Visit(const DictionaryType& t) {
- std::shared_ptr<Array> taken_indices;
- const auto& values = internal::checked_cast<const DictionaryArray&>(*params_.values);
- {
- // To take from a dictionary, apply the current kernel to the dictionary's
- // indices. (Use UnpackValues<IndexType> since IndexType is already unpacked)
- auto indices = values.indices();
- TakeParameters params = params_;
- params.values = indices;
- params.out = &taken_indices;
- UnpackValues<IndexType> unpack = {params};
- RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack));
+ explicit ArrayIndexSequence(const Array& indices)  // NOTE(review): stores a raw pointer, so `indices` must outlive this object
+ : indices_(&checked_cast<const NumericArray<IndexType>&>(indices)) {}
+
+ std::pair<int64_t, bool> Next() {  // {index, true} for a valid slot, {-1, false} for a null; no cursor bounds check (callers presumably stop after length() calls)
+ if (indices_->IsNull(index_)) {
+ ++index_;
+ return std::make_pair(-1, false);
}
- // create output dictionary from taken indices
- *params_.out = std::make_shared<DictionaryArray>(values.type(), taken_indices,
- values.dictionary());
- return Status::OK();
+ return std::make_pair(indices_->Value(index_++), true);
}
- Status Visit(const ExtensionType& t) {
- // XXX can we just take from its storage?
- return Status::NotImplemented("gathering values of type ", t);
- }
+ int64_t length() const { return indices_->length(); }
- Status Visit(const UnionType& t) {
- return Status::NotImplemented("gathering values of type ", t);
- }
+ int64_t null_count() const { return indices_->null_count(); }
- Status Visit(const ListType& t) {
- return Status::NotImplemented("gathering values of type ", t);
- }
+ private:
+ const NumericArray<IndexType>* indices_ = nullptr;  // non-owning view of the index array
+ int64_t index_ = 0;  // position Next() reads from
+ bool never_out_of_bounds_ = false;
+};
- Status Visit(const MapType& t) {
- return Status::NotImplemented("gathering values of type ", t);
- }
+template <typename IndexType>  // TakeKernel specialized for a single integer index type
+class TakeKernelImpl : public TakeKernel {  // the actual gathering is delegated to a Taker over ArrayIndexSequence
+ public:
+ explicit TakeKernelImpl(const std::shared_ptr<DataType>& value_type)
+ : TakeKernel(value_type) {}
- Status Visit(const FixedSizeListType& t) {
- return Status::NotImplemented("gathering values of type ", t);
+ Status Init() {  // builds taker_ for type_; Take() dereferences taker_, so this must succeed first
+ return Taker<ArrayIndexSequence<IndexType>>::Make(this->type_, &taker_);
}
- Status Visit(const StructType& t) {
- return Status::NotImplemented("gathering values of type ", t);
+ Status Take(FunctionContext* ctx, const Array& values, const Array& indices_array,
+ std::shared_ptr<Array>* out) override {
+ RETURN_NOT_OK(taker_->Init(ctx->memory_pool()));  // the Taker is (re)initialized on every call
+ RETURN_NOT_OK(taker_->Take(values, ArrayIndexSequence<IndexType>(indices_array)));
+ return taker_->Finish(out);
}
- const TakeParameters& params_;
+ std::unique_ptr<Taker<ArrayIndexSequence<IndexType>>> taker_;  // null before Init()
};
struct UnpackIndices {
template <typename IndexType>
enable_if_integer<IndexType, Status> Visit(const IndexType&) {
- UnpackValues<IndexType> unpack = {params_};
- return VisitTypeInline(*params_.values->type(), &unpack);
+ auto out = new TakeKernelImpl<IndexType>(value_type_);  // ownership passes to *out_ on the next line
+ out_->reset(out);
+ return out->Init();  // NOTE(review): on failure *out_ still holds the uninitialized kernel
}
Status Visit(const DataType& other) {
return Status::TypeError("index type not supported: ", other);
}
- const TakeParameters& params_;
+ std::shared_ptr<DataType> value_type_;  // value type of the kernel to build
+ std::unique_ptr<TakeKernel>* out_;  // receives the constructed kernel
};
+// Build the kernel by dispatching on the index type: integer index types get a
+// TakeKernelImpl for that type, and anything else is rejected by UnpackIndices.
+Status TakeKernel::Make(const std::shared_ptr<DataType>& value_type,
+ const std::shared_ptr<DataType>& index_type,
+ std::unique_ptr<TakeKernel>* out) {
+ UnpackIndices unpack_indices{value_type, out};
+ const DataType& index_type_ref = *index_type;
+ return VisitTypeInline(index_type_ref, &unpack_indices);
+}
+
Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& indices,
Datum* out) {
if (!values.is_array() || !indices.is_array()) {
return Status::Invalid("TakeKernel expects array values and indices");
}
+ auto values_array = values.make_array();  // both Datums are known to hold arrays here
+ auto indices_array = indices.make_array();
std::shared_ptr<Array> out_array;
- TakeParameters params;
- params.context = ctx;
- params.values = values.make_array();
- params.indices = indices.make_array();
- params.options = options_;
- params.out = &out_array;
- UnpackIndices unpack = {params};
- RETURN_NOT_OK(VisitTypeInline(*indices.type(), &unpack));
+ RETURN_NOT_OK(Take(ctx, *values_array, *indices_array, &out_array));  // virtual single-array implementation
*out = Datum(out_array);
return Status::OK();
}
+// Array overload: wraps both inputs as array Datums and forwards to the Datum
+// overload, then unboxes the resulting array.
+Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
+ const TakeOptions& options, std::shared_ptr<Array>* out) {
+ Datum out_datum;
+ RETURN_NOT_OK(
+ Take(ctx, Datum(values.data()), Datum(indices.data()), options, &out_datum));
+ *out = out_datum.make_array();
+ return Status::OK();
+}
+
+// Datum overload: builds a TakeKernel for (values.type(), indices.type()) and runs it.
+// `options` carries no settings yet and is not forwarded.
+Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices,
+ const TakeOptions& options, Datum* out) {
+ // Reject non-array inputs before building a kernel. TakeKernel::Make dereferences
+ // indices.type(), which may be null for non-array Datum kinds. This is the same
+ // check and message TakeKernel::Call applies.
+ if (!values.is_array() || !indices.is_array()) {
+ return Status::Invalid("TakeKernel expects array values and indices");
+ }
+ std::unique_ptr<TakeKernel> kernel;
+ RETURN_NOT_OK(TakeKernel::Make(values.type(), indices.type(), &kernel));
+ return kernel->Call(ctx, values, indices, out);
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h
index 3aa5ed5..f064b72 100644
--- a/cpp/src/arrow/compute/kernels/take.h
+++ b/cpp/src/arrow/compute/kernels/take.h
@@ -44,40 +44,58 @@ struct ARROW_EXPORT TakeOptions {};
/// = [values[2], values[1], null, values[3]]
/// = ["c", "b", null, null]
///
-/// \param[in] context the FunctionContext
+/// \param[in] ctx the FunctionContext
/// \param[in] values array from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting array
ARROW_EXPORT
-Status Take(FunctionContext* context, const Array& values, const Array& indices,
+Status Take(FunctionContext* ctx, const Array& values, const Array& indices,  // wraps the arrays as Datums and forwards to the Datum overload
const TakeOptions& options, std::shared_ptr<Array>* out);
/// \brief Take from an array of values at indices in another array
///
-/// \param[in] context the FunctionContext
+/// \param[in] ctx the FunctionContext (its memory pool is handed to the kernel's Taker)
/// \param[in] values datum from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting datum
ARROW_EXPORT
-Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
+Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices,  // runs a TakeKernel built for (values.type(), indices.type())
const TakeOptions& options, Datum* out);
/// \brief BinaryKernel implementing Take operation
class ARROW_EXPORT TakeKernel : public BinaryKernel {
public:
explicit TakeKernel(const std::shared_ptr<DataType>& type, TakeOptions options = {})
- : type_(type), options_(options) {}
+ : type_(type) {}  // NOTE(review): `options` is accepted but ignored
+ /// \brief BinaryKernel interface
+ ///
+ /// delegates to subclasses via Take(); returns Status::Invalid unless both inputs are arrays
Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices,
Datum* out) override;
+ /// \brief output type of this kernel (identical to type of values taken)
std::shared_ptr<DataType> out_type() const override { return type_; }
- private:
+ /// \brief factory for TakeKernels
+ ///
+ /// \param[in] value_type constructed TakeKernel will support taking
+ /// values of this type
+ /// \param[in] index_type constructed TakeKernel will support taking
+ /// with indices of this type; non-integer index types yield Status::TypeError
+ /// \param[out] out created kernel
+ static Status Make(const std::shared_ptr<DataType>& value_type,
+ const std::shared_ptr<DataType>& index_type,
+ std::unique_ptr<TakeKernel>* out);
+
+ /// \brief single-array implementation, invoked by Call() once inputs are unboxed
+ virtual Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
+ std::shared_ptr<Array>* out) = 0;
+
+ protected:
std::shared_ptr<DataType> type_;
- TakeOptions options_;
};
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h
index efd990f..c832583 100644
--- a/cpp/src/arrow/compute/kernels/util-internal.h
+++ b/cpp/src/arrow/compute/kernels/util-internal.h
@@ -131,7 +131,7 @@ class ARROW_EXPORT PrimitiveAllocatingUnaryKernel : public UnaryKernel {
/// \brief Kernel used to preallocate outputs for primitive types.
class ARROW_EXPORT PrimitiveAllocatingBinaryKernel : public BinaryKernel {
public:
- // \brief Construct with a kernel to delegate operatoions to.
+ // \brief Construct with a kernel to delegate operations to.
//
// Ownership is not taken of the delegate kernel, it must outlive
// the life time of this object.
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 655dd38..37da62c 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -51,19 +51,24 @@ def test_sum(arrow_type):
('double', np.arange(0, 0.5, 0.1)),
('string', ['a', 'b', None, 'ddd', 'ee']),
('binary', [b'a', b'b', b'c', b'ddd', b'ee']),
- (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg'])
+ (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']),
+ (pa.list_(pa.int8()), [[1, 2], [3, 4], [5, 6], None, [9, 16]]),
+ (pa.struct([('a', pa.int8()), ('b', pa.int8())]), [
+ {'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}, None, {'a': 5, 'b': 6}]),
])
def test_take(ty, values):
arr = pa.array(values, type=ty)
for indices_type in [pa.uint8(), pa.int64()]:
indices = pa.array([0, 4, 2, None], type=indices_type)
result = arr.take(indices)
+ result.validate()
expected = pa.array([values[0], values[4], values[2], None], type=ty)
assert result.equals(expected)
# empty indices
indices = pa.array([], type=indices_type)
result = arr.take(indices)
+ result.validate()
expected = pa.array([], type=ty)
assert result.equals(expected)
@@ -83,6 +88,7 @@ def test_take_indices_types():
'uint32', 'int32', 'uint64', 'int64']:
indices = pa.array([0, 4, 2, None], type=indices_type)
result = arr.take(indices)
+ result.validate()
expected = pa.array([0, 4, 2, None])
assert result.equals(expected)
@@ -97,17 +103,7 @@ def test_take_dictionary(ordered):
arr = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1, 2], ['a', 'b', 'c'],
ordered=ordered)
result = arr.take(pa.array([0, 1, 3]))
+ result.validate()
assert result.to_pylist() == ['a', 'b', 'a']
assert result.dictionary.to_pylist() == ['a', 'b', 'c']
assert result.type.ordered is ordered
-
-
-@pytest.mark.parametrize('array', [
- [[1, 2], [3, 4], [5, 6]],
- [{'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}],
-], ids=['listarray', 'structarray'])
-def test_take_notimplemented(array):
- array = pa.array(array)
- indices = pa.array([0, 2])
- with pytest.raises(NotImplementedError):
- array.take(indices)