You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2020/06/17 20:14:58 UTC
[arrow] branch master updated: ARROW-9075: [C++] Optimized Filter
implementation: faster performance + compilation, smaller code size
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new d0f3b5f ARROW-9075: [C++] Optimized Filter implementation: faster performance + compilation, smaller code size
d0f3b5f is described below
commit d0f3b5f3c74f67c7ca941c98bd60148e9a9e94f0
Author: Wes McKinney <we...@apache.org>
AuthorDate: Wed Jun 17 15:14:30 2020 -0500
ARROW-9075: [C++] Optimized Filter implementation: faster performance + compilation, smaller code size
NOTE: the diff is artificially larger than the functional change because some code was rearranged; the rearrangement was necessary because some data selection code is shared between the Take and Filter implementations.
Summary:
* Filter is now 1.5-10+x faster across the board, most notably on primitive types with filters of very high or very low selectivity. The BitBlockCounters do a lot of the heavy lifting in those cases, but even in the worst-case scenario, where the block counters never encounter a "full" block, the new implementation is still consistently faster.
* Total -O3 code size for **both** Take and Filter is now about 600KB, down from about 8MB total prior to this patch and ARROW-5760.
Some incidental changes:
* Implemented a fast conversion from a boolean filter to take indices (a.k.a. a "selection vector"), `compute::internal::GetTakeIndices`. I have also changed the record batch filtering implementation to use this conversion, which should be faster (it would be good to have some benchmarks to confirm this).
* Various expansions to the BitBlockCounter classes that were needed to support this work.
* Fixed bug ARROW-9142 in RandomArrayGenerator::Boolean: the probability parameter was being interpreted as the probability of a false value rather than the probability of a true value. IIUC, for a Bernoulli distribution the specified probability is P(X = 1), not P(X = 0). Could someone please confirm this?
Closes #7442 from wesm/ARROW-9075
Authored-by: Wes McKinney <we...@apache.org>
Signed-off-by: Wes McKinney <we...@apache.org>
---
cpp/src/arrow/CMakeLists.txt | 6 +-
cpp/src/arrow/array/array_binary.h | 2 +
cpp/src/arrow/compute/api_vector.cc | 5 +-
cpp/src/arrow/compute/api_vector.h | 23 +-
cpp/src/arrow/compute/benchmark_util.h | 4 +-
cpp/src/arrow/compute/kernels/CMakeLists.txt | 3 +-
.../compute/kernels/scalar_arithmetic_test.cc | 2 +-
cpp/src/arrow/compute/kernels/scalar_cast_test.cc | 3 +-
.../arrow/compute/kernels/scalar_compare_test.cc | 11 +-
cpp/src/arrow/compute/kernels/test_util.h | 11 +-
.../util_internal.cc} | 55 +-
cpp/src/arrow/compute/kernels/util_internal.h | 55 +
cpp/src/arrow/compute/kernels/vector_filter.cc | 248 ---
.../arrow/compute/kernels/vector_filter_test.cc | 721 --------
cpp/src/arrow/compute/kernels/vector_selection.cc | 1826 ++++++++++++++++++++
.../compute/kernels/vector_selection_benchmark.cc | 24 +-
.../compute/kernels/vector_selection_internal.h | 819 ---------
.../arrow/compute/kernels/vector_selection_test.cc | 1638 ++++++++++++++++++
cpp/src/arrow/compute/kernels/vector_take.cc | 989 -----------
cpp/src/arrow/compute/kernels/vector_take_test.cc | 844 ---------
cpp/src/arrow/compute/registry.cc | 3 +-
cpp/src/arrow/compute/registry_internal.h | 3 +-
cpp/src/arrow/dataset/filter.cc | 3 +-
cpp/src/arrow/testing/random.cc | 11 +-
cpp/src/arrow/testing/random.h | 4 +-
cpp/src/arrow/util/bit_block_counter.cc | 25 +-
cpp/src/arrow/util/bit_block_counter.h | 64 +
cpp/src/arrow/util/bit_block_counter_test.cc | 70 +-
python/pyarrow/includes/libarrow.pxd | 2 +
python/pyarrow/tests/test_compute.py | 5 +-
30 files changed, 3791 insertions(+), 3688 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 715373f..ac1b570 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -356,10 +356,10 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_set_lookup.cc
compute/kernels/scalar_string.cc
compute/kernels/scalar_validity.cc
- compute/kernels/vector_filter.cc
+ compute/kernels/util_internal.cc
compute/kernels/vector_hash.cc
- compute/kernels/vector_sort.cc
- compute/kernels/vector_take.cc)
+ compute/kernels/vector_selection.cc
+ compute/kernels/vector_sort.cc)
endif()
if(ARROW_FILESYSTEM)
diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h
index b291de3..c54e504 100644
--- a/cpp/src/arrow/array/array_binary.h
+++ b/cpp/src/arrow/array/array_binary.h
@@ -85,6 +85,8 @@ class BaseBinaryArray : public FlatArray {
return raw_value_offsets_ + data_->offset;
}
+ const uint8_t* raw_data() const { return raw_data_; }
+
/// \brief Return the data buffer absolute offset of the data for the value
/// at the passed index.
///
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc
index d0d67e7..dd9c439 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -21,8 +21,8 @@
#include <utility>
#include <vector>
+#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec.h"
-#include "arrow/compute/kernels/vector_selection_internal.h"
#include "arrow/datum.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
@@ -65,6 +65,9 @@ Result<std::shared_ptr<Array>> ValueCounts(const Datum& value, ExecContext* ctx)
return result.make_array();
}
+// ----------------------------------------------------------------------
+// Filter- and take-related selection functions
+
Result<Datum> Filter(const Datum& values, const Datum& filter,
const FilterOptions& options, ExecContext* ctx) {
// Invoke metafunction which deals with Datum kinds other than just Array,
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 48f9102..166bc10 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -38,7 +38,10 @@ struct FilterOptions : public FunctionOptions {
EMIT_NULL,
};
- static FilterOptions Defaults() { return FilterOptions{}; }
+ explicit FilterOptions(NullSelectionBehavior null_selection = DROP)
+ : null_selection_behavior(null_selection) {}
+
+ static FilterOptions Defaults() { return FilterOptions(); }
NullSelectionBehavior null_selection_behavior = DROP;
};
@@ -64,6 +67,24 @@ Result<Datum> Filter(const Datum& values, const Datum& filter,
const FilterOptions& options = FilterOptions::Defaults(),
ExecContext* ctx = NULLPTR);
+namespace internal {
+
+// These internal functions are implemented in kernels/vector_selection.cc
+
+/// \brief Return the number of selected indices in the boolean filter
+ARROW_EXPORT
+int64_t GetFilterOutputSize(const ArrayData& filter,
+ FilterOptions::NullSelectionBehavior null_selection);
+
+/// \brief Compute uint64 selection indices for use with Take given a boolean
+/// filter
+ARROW_EXPORT
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+ const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+ MemoryPool* memory_pool = default_memory_pool());
+
+} // namespace internal
+
struct ARROW_EXPORT TakeOptions : public FunctionOptions {
explicit TakeOptions(bool boundscheck = true) : boundscheck(boundscheck) {}
diff --git a/cpp/src/arrow/compute/benchmark_util.h b/cpp/src/arrow/compute/benchmark_util.h
index 1259d1b..edd2007 100644
--- a/cpp/src/arrow/compute/benchmark_util.h
+++ b/cpp/src/arrow/compute/benchmark_util.h
@@ -24,9 +24,11 @@
#include "arrow/util/cpu_info.h"
namespace arrow {
-namespace compute {
using internal::CpuInfo;
+
+namespace compute {
+
static CpuInfo* cpu_info = CpuInfo::GetInstance();
static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE);
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 9ff0d09..0082799 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -38,9 +38,8 @@ add_arrow_benchmark(scalar_string_benchmark PREFIX "arrow-compute")
add_arrow_compute_test(vector_test
SOURCES
- vector_filter_test.cc
vector_hash_test.cc
- vector_take_test.cc
+ vector_selection_test.cc
vector_sort_test.cc
test_util.cc)
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index 2f2159e..4b64244 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -97,7 +97,7 @@ std::string MakeArray(Elements... elements) {
std::copy(elements_as_strings.begin(), elements_as_strings.end(),
elements_as_views.begin());
- return "[" + internal::JoinStrings(elements_as_views, ",") + "]";
+ return "[" + ::arrow::internal::JoinStrings(elements_as_views, ",") + "]";
}
template <typename T>
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
index ace8759..4970c83 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc
@@ -46,10 +46,11 @@
#include "arrow/compute/kernels/test_util.h"
namespace arrow {
-namespace compute {
using internal::checked_cast;
+namespace compute {
+
static constexpr const char* kInvalidUtf8 = "\xa0\xa1";
static std::vector<std::shared_ptr<DataType>> kNumericTypes = {
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index 758e10b..8bedb96 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -36,6 +36,9 @@
#include "arrow/util/checked_cast.h"
namespace arrow {
+
+using internal::BitmapReader;
+
namespace compute {
using util::string_view;
@@ -115,8 +118,8 @@ Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs,
ArrayFromVector<BooleanType>(bitmap, &result);
} else {
std::vector<bool> null_bitmap(array->length());
- auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(),
- array->length());
+ auto reader =
+ BitmapReader(array->null_bitmap_data(), array->offset(), array->length());
for (int64_t i = 0; i < array->length(); i++, reader.Next()) {
null_bitmap[i] = reader.IsSet();
}
@@ -146,8 +149,8 @@ Datum SimpleScalarArrayCompare<StringType>(CompareOptions options, const Datum&
ArrayFromVector<BooleanType>(bitmap, &result);
} else {
std::vector<bool> null_bitmap(array->length());
- auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(),
- array->length());
+ auto reader =
+ BitmapReader(array->null_bitmap_data(), array->offset(), array->length());
for (int64_t i = 0; i < array->length(); i++, reader.Next()) {
null_bitmap[i] = reader.IsSet();
}
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index c4e1f07..f01f5c7 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -39,6 +39,9 @@
// IWYU pragma: end_exports
namespace arrow {
+
+using internal::checked_cast;
+
namespace compute {
template <typename Type, typename T>
@@ -65,8 +68,8 @@ struct DatumEqual<Type, enable_if_floating_point<Type>> {
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
ASSERT_EQ(lhs.kind(), rhs.kind());
if (lhs.kind() == Datum::SCALAR) {
- auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
- auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+ auto left = checked_cast<const ScalarType*>(lhs.scalar().get());
+ auto right = checked_cast<const ScalarType*>(rhs.scalar().get());
ASSERT_EQ(left->is_valid, right->is_valid);
ASSERT_EQ(left->type->id(), right->type->id());
ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
@@ -80,8 +83,8 @@ struct DatumEqual<Type, enable_if_integer<Type>> {
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
ASSERT_EQ(lhs.kind(), rhs.kind());
if (lhs.kind() == Datum::SCALAR) {
- auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
- auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
+ auto left = checked_cast<const ScalarType*>(lhs.scalar().get());
+ auto right = checked_cast<const ScalarType*>(rhs.scalar().get());
ASSERT_EQ(*left, *right);
}
}
diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/kernels/util_internal.cc
similarity index 51%
copy from cpp/src/arrow/compute/registry_internal.h
copy to cpp/src/arrow/compute/kernels/util_internal.cc
index 515b17b..32c6317 100644
--- a/cpp/src/arrow/compute/registry_internal.h
+++ b/cpp/src/arrow/compute/kernels/util_internal.cc
@@ -15,32 +15,47 @@
// specific language governing permissions and limitations
// under the License.
-#pragma once
+#include "arrow/compute/kernels/util_internal.h"
+
+#include <cstdint>
+
+#include "arrow/array/data.h"
+#include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
namespace arrow {
-namespace compute {
-class FunctionRegistry;
+using internal::checked_cast;
+namespace compute {
namespace internal {
-// Built-in scalar / elementwise functions
-void RegisterScalarArithmetic(FunctionRegistry* registry);
-void RegisterScalarBoolean(FunctionRegistry* registry);
-void RegisterScalarCast(FunctionRegistry* registry);
-void RegisterScalarComparison(FunctionRegistry* registry);
-void RegisterScalarSetLookup(FunctionRegistry* registry);
-void RegisterScalarStringAscii(FunctionRegistry* registry);
-void RegisterScalarValidity(FunctionRegistry* registry);
-
-// Vector functions
-void RegisterVectorFilter(FunctionRegistry* registry);
-void RegisterVectorHash(FunctionRegistry* registry);
-void RegisterVectorSort(FunctionRegistry* registry);
-void RegisterVectorTake(FunctionRegistry* registry);
-
-// Aggregate functions
-void RegisterScalarAggregateBasic(FunctionRegistry* registry);
+const uint8_t* GetValidityBitmap(const ArrayData& data) {
+ const uint8_t* bitmap = nullptr;
+ if (data.buffers[0]) {
+ bitmap = data.buffers[0]->data();
+ }
+ return bitmap;
+}
+
+int GetBitWidth(const DataType& type) {
+ return checked_cast<const FixedWidthType&>(type).bit_width();
+}
+
+PrimitiveArg GetPrimitiveArg(const ArrayData& arr) {
+ PrimitiveArg arg;
+ arg.is_valid = GetValidityBitmap(arr);
+ arg.data = arr.buffers[1]->data();
+ arg.bit_width = GetBitWidth(*arr.type);
+ arg.offset = arr.offset;
+ arg.length = arr.length;
+ if (arg.bit_width > 1) {
+ arg.data += arr.offset * arg.bit_width / 8;
+ }
+ // This may be kUnknownNullCount
+ arg.null_count = arr.null_count.load();
+ return arg;
+}
} // namespace internal
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h
new file mode 100644
index 0000000..7ab5996
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/util_internal.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "arrow/buffer.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+// An internal data structure for unpacking a primitive argument to pass to a
+// kernel implementation
+struct PrimitiveArg {
+ const uint8_t* is_valid;
+ // If the bit_width is a multiple of 8 (i.e. not boolean), then "data" should
+ // be shifted by offset * (bit_width / 8). For bit-packed data, the offset
+ // must be used when indexing.
+ const uint8_t* data;
+ int bit_width;
+ int64_t length;
+ int64_t offset;
+ // This may be kUnknownNullCount if the null_count has not yet been computed,
+ // so use null_count != 0 to determine "may have nulls".
+ int64_t null_count;
+};
+
+// Get validity bitmap data or return nullptr if there is no validity buffer
+const uint8_t* GetValidityBitmap(const ArrayData& data);
+
+int GetBitWidth(const DataType& type);
+
+// Reduce code size by dealing with the unboxing of the kernel inputs once
+// rather than duplicating compiled code to do all these in each kernel.
+PrimitiveArg GetPrimitiveArg(const ArrayData& arr);
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_filter.cc b/cpp/src/arrow/compute/kernels/vector_filter.cc
deleted file mode 100644
index db21d40..0000000
--- a/cpp/src/arrow/compute/kernels/vector_filter.cc
+++ /dev/null
@@ -1,248 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include "arrow/array/array_base.h"
-#include "arrow/array/array_primitive.h"
-#include "arrow/compute/api_vector.h"
-#include "arrow/compute/kernels/common.h"
-#include "arrow/compute/kernels/vector_selection_internal.h"
-#include "arrow/record_batch.h"
-#include "arrow/result.h"
-#include "arrow/visitor_inline.h"
-
-namespace arrow {
-namespace compute {
-namespace internal {
-
-// IndexSequence which yields the indices of positions in a BooleanArray
-// which are either null or true
-template <FilterOptions::NullSelectionBehavior NullSelectionBehavior>
-class FilterIndexSequence {
- public:
- // constexpr so we'll never instantiate bounds checking
- constexpr bool never_out_of_bounds() const { return true; }
- void set_never_out_of_bounds() {}
-
- constexpr FilterIndexSequence() = default;
-
- FilterIndexSequence(const BooleanArray& filter, int64_t out_length)
- : filter_(&filter), out_length_(out_length) {}
-
- std::pair<int64_t, bool> Next() {
- if (NullSelectionBehavior == FilterOptions::DROP) {
- // skip until an index is found at which the filter is true
- while (filter_->IsNull(index_) || !filter_->Value(index_)) {
- ++index_;
- }
- return std::make_pair(index_++, true);
- }
-
- // skip until an index is found at which the filter is either null or true
- while (filter_->IsValid(index_) && !filter_->Value(index_)) {
- ++index_;
- }
- bool is_valid = filter_->IsValid(index_);
- return std::make_pair(index_++, is_valid);
- }
-
- int64_t length() const { return out_length_; }
-
- int64_t null_count() const {
- if (NullSelectionBehavior == FilterOptions::DROP) {
- return 0;
- }
- return filter_->null_count();
- }
-
- private:
- const BooleanArray* filter_ = nullptr;
- int64_t index_ = 0, out_length_ = -1;
-};
-
-int64_t FilterOutputSize(FilterOptions::NullSelectionBehavior null_selection,
- const Array& arr) {
- const auto& filter = checked_cast<const BooleanArray&>(arr);
- // TODO(bkietz) this can be optimized. Use Bitmap::VisitWords
- int64_t size = 0;
- if (null_selection == FilterOptions::EMIT_NULL) {
- for (auto i = 0; i < filter.length(); ++i) {
- if (filter.IsNull(i) || filter.Value(i)) {
- ++size;
- }
- }
- } else {
- for (auto i = 0; i < filter.length(); ++i) {
- if (filter.IsValid(i) && filter.Value(i)) {
- ++size;
- }
- }
- }
- return size;
-}
-
-struct FilterState : public KernelState {
- explicit FilterState(FilterOptions options) : options(std::move(options)) {}
- FilterOptions options;
-};
-
-std::unique_ptr<KernelState> InitFilter(KernelContext*, const KernelInitArgs& args) {
- FilterOptions options;
- if (args.options == nullptr) {
- options = FilterOptions::Defaults();
- } else {
- options = *static_cast<const FilterOptions*>(args.options);
- }
- return std::unique_ptr<KernelState>(new FilterState(std::move(options)));
-}
-
-template <typename ValueType>
-struct FilterFunctor {
- using ArrayType = typename TypeTraits<ValueType>::ArrayType;
-
- template <FilterOptions::NullSelectionBehavior NullSelection>
- static void ExecImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- using IS = FilterIndexSequence<NullSelection>;
- ArrayType values(batch[0].array());
- BooleanArray filter(batch[1].array());
- const int64_t output_size = FilterOutputSize(NullSelection, filter);
- std::shared_ptr<Array> result;
- KERNEL_RETURN_IF_ERROR(ctx, Select(ctx, values, IS(filter, output_size), &result));
- out->value = result->data();
- }
-
- static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& state = checked_cast<const FilterState&>(*ctx->state());
- if (state.options.null_selection_behavior == FilterOptions::EMIT_NULL) {
- ExecImpl<FilterOptions::EMIT_NULL>(ctx, batch, out);
- } else {
- ExecImpl<FilterOptions::DROP>(ctx, batch, out);
- }
- }
-};
-
-struct FilterKernelVisitor {
- template <typename Type>
- Status Visit(const Type&) {
- this->result = FilterFunctor<Type>::Exec;
- return Status::OK();
- }
-
- Status Create(const DataType& type) { return VisitTypeInline(type, this); }
- ArrayKernelExec result;
-};
-
-Status GetFilterKernel(const DataType& type, ArrayKernelExec* exec) {
- FilterKernelVisitor visitor;
- RETURN_NOT_OK(visitor.Create(type));
- *exec = visitor.result;
- return Status::OK();
-}
-
-Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch,
- const Datum& filter,
- const FunctionOptions* options,
- ExecContext* ctx) {
- if (!filter.is_array()) {
- return Status::Invalid("Cannot filter a RecordBatch with a filter of kind ",
- filter.kind());
- }
-
- const auto& filter_opts = *static_cast<const FilterOptions*>(options);
- // TODO: Rewrite this to convert to selection vector and use Take
- std::vector<std::shared_ptr<Array>> columns(batch.num_columns());
- for (int i = 0; i < batch.num_columns(); ++i) {
- ARROW_ASSIGN_OR_RAISE(Datum out,
- Filter(batch.column(i)->data(), filter, filter_opts, ctx));
- columns[i] = out.make_array();
- }
-
- int64_t out_length;
- if (columns.size() == 0) {
- out_length =
- FilterOutputSize(filter_opts.null_selection_behavior, *filter.make_array());
- } else {
- out_length = columns[0]->length();
- }
- return RecordBatch::Make(batch.schema(), out_length, columns);
-}
-
-Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter,
- const FunctionOptions* options,
- ExecContext* ctx) {
- auto new_columns = table.columns();
- for (auto& column : new_columns) {
- ARROW_ASSIGN_OR_RAISE(
- Datum out_column,
- Filter(column, filter, *static_cast<const FilterOptions*>(options), ctx));
- column = out_column.chunked_array();
- }
- return Table::Make(table.schema(), std::move(new_columns));
-}
-
-class FilterMetaFunction : public MetaFunction {
- public:
- FilterMetaFunction() : MetaFunction("filter", Arity::Binary()) {}
-
- Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
- const FunctionOptions* options,
- ExecContext* ctx) const override {
- if (args[0].kind() == Datum::RECORD_BATCH) {
- auto values_batch = args[0].record_batch();
- ARROW_ASSIGN_OR_RAISE(
- std::shared_ptr<RecordBatch> out_batch,
- FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx));
- return Datum(out_batch);
- } else if (args[0].kind() == Datum::TABLE) {
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table,
- FilterTable(*args[0].table(), args[1], options, ctx));
- return Datum(out_table);
- } else {
- return CallFunction("array_filter", args, options, ctx);
- }
- }
-};
-
-void RegisterVectorFilter(FunctionRegistry* registry) {
- VectorKernel base;
- base.init = InitFilter;
-
- auto filter = std::make_shared<VectorFunction>("array_filter", Arity::Binary());
- InputType filter_ty = InputType::Array(boolean());
- OutputType out_ty(FirstType);
-
- auto AddKernel = [&](InputType in_ty, const DataType& example_type) {
- base.signature = KernelSignature::Make({in_ty, filter_ty}, out_ty);
- DCHECK_OK(GetFilterKernel(example_type, &base.exec));
- DCHECK_OK(filter->AddKernel(base));
- };
-
- for (const auto& value_ty : PrimitiveTypes()) {
- AddKernel(InputType::Array(value_ty), *value_ty);
- }
- // Other types where we may only on the DataType::id
- for (const auto& value_ty : ExampleParametricTypes()) {
- AddKernel(InputType::Array(value_ty->id()), *value_ty);
- }
- DCHECK_OK(registry->AddFunction(std::move(filter)));
-
- // Add filter metafunction
- DCHECK_OK(registry->AddFunction(std::make_shared<FilterMetaFunction>()));
-}
-
-} // namespace internal
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_filter_test.cc b/cpp/src/arrow/compute/kernels/vector_filter_test.cc
deleted file mode 100644
index 3277891..0000000
--- a/cpp/src/arrow/compute/kernels/vector_filter_test.cc
+++ /dev/null
@@ -1,721 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "arrow/compute/api.h"
-#include "arrow/compute/kernels/test_util.h"
-#include "arrow/table.h"
-#include "arrow/testing/gtest_common.h"
-#include "arrow/testing/gtest_util.h"
-#include "arrow/testing/random.h"
-#include "arrow/testing/util.h"
-
-namespace arrow {
-namespace compute {
-
-using internal::checked_pointer_cast;
-using util::string_view;
-
- // Returns a null-free boolean filter in which each slot is true only if the
- // corresponding slot of `filter` is both valid and true (so EMIT_NULL on the
- // result matches DROP on the input). A filter without nulls is returned as-is.
- std::shared_ptr<Array> CoalesceNullToFalse(std::shared_ptr<Array> filter) {
-   if (filter->null_count() == 0) {
-     return filter;
-   }
-   const auto& data = *filter->data();
-   // View the values bitmap and the validity bitmap as standalone boolean
-   // arrays. Both views must carry the filter's offset; otherwise a sliced
-   // filter would be read from the start of its buffers.
-   auto is_true = std::make_shared<BooleanArray>(data.length, data.buffers[1],
-                                                 /*null_bitmap=*/nullptr,
-                                                 /*null_count=*/0, data.offset);
-   auto is_valid = std::make_shared<BooleanArray>(data.length, data.buffers[0],
-                                                  /*null_bitmap=*/nullptr,
-                                                  /*null_count=*/0, data.offset);
-   EXPECT_OK_AND_ASSIGN(Datum out_datum, arrow::compute::And(is_true, is_valid));
-   return out_datum.make_array();
- }
-
- // Base fixture for the Filter kernel tests. Holds one FilterOptions per null
- // selection behavior (EMIT_NULL and DROP). Provides helpers that compare
- // Filter's output with an expected array, or check it element by element
- // against the filter.
- template <typename ArrowType>
- class TestFilterKernel : public ::testing::Test {
-  protected:
-   TestFilterKernel() {
-     emit_null_.null_selection_behavior = FilterOptions::EMIT_NULL;
-     drop_.null_selection_behavior = FilterOptions::DROP;
-   }
-
-   // Asserts that Filter(values, filter) equals `expected` under EMIT_NULL.
-   // Then asserts that DROP gives the same result as EMIT_NULL applied to a
-   // filter whose nulls were coalesced to false. Both outputs are validated.
-   void AssertFilter(std::shared_ptr<Array> values, std::shared_ptr<Array> filter,
-                     std::shared_ptr<Array> expected) {
-     // test with EMIT_NULL
-     ASSERT_OK_AND_ASSIGN(Datum out_datum,
-                          arrow::compute::Filter(values, filter, emit_null_));
-     auto actual = out_datum.make_array();
-     ASSERT_OK(actual->ValidateFull());
-     AssertArraysEqual(*expected, *actual);
-
-     // test with DROP using EMIT_NULL and a coalesced filter
-     auto coalesced_filter = CoalesceNullToFalse(filter);
-     ASSERT_OK_AND_ASSIGN(out_datum,
-                          arrow::compute::Filter(values, coalesced_filter, emit_null_));
-     expected = out_datum.make_array();
-     ASSERT_OK_AND_ASSIGN(out_datum, arrow::compute::Filter(values, filter, drop_));
-     actual = out_datum.make_array();
-     // Validate the DROP output as well, matching the EMIT_NULL check above.
-     ASSERT_OK(actual->ValidateFull());
-     AssertArraysEqual(*expected, *actual);
-   }
-
-   // JSON convenience overload: `values` and `expected` are parsed as `type`,
-   // and `filter` is parsed as boolean.
-   void AssertFilter(std::shared_ptr<DataType> type, const std::string& values,
-                     const std::string& filter, const std::string& expected) {
-     AssertFilter(ArrayFromJSON(type, values), ArrayFromJSON(boolean(), filter),
-                  ArrayFromJSON(type, expected));
-   }
-
-   // Filters `values` under both EMIT_NULL and DROP and validates both
-   // outputs. Then walks the inputs in lockstep with the two outputs:
-   //  - null filter slot: EMIT_NULL emits a null, DROP emits nothing;
-   //  - false filter slot: neither output emits anything;
-   //  - true filter slot: both outputs emit values[values_i].
-   // Finally checks that both outputs were consumed exactly.
-   void ValidateFilter(const std::shared_ptr<Array>& values,
-                       const std::shared_ptr<Array>& filter_boxed) {
-     ASSERT_OK_AND_ASSIGN(Datum out_datum,
-                          arrow::compute::Filter(values, filter_boxed, emit_null_));
-     auto filtered_emit_null = out_datum.make_array();
-     ASSERT_OK(filtered_emit_null->ValidateFull());
-
-     ASSERT_OK_AND_ASSIGN(out_datum, arrow::compute::Filter(values, filter_boxed, drop_));
-     auto filtered_drop = out_datum.make_array();
-     ASSERT_OK(filtered_drop->ValidateFull());
-
-     auto filter = checked_pointer_cast<BooleanArray>(filter_boxed);
-     int64_t values_i = 0, emit_null_i = 0, drop_i = 0;
-     for (; values_i < values->length(); ++values_i, ++emit_null_i, ++drop_i) {
-       if (filter->IsNull(values_i)) {
-         ASSERT_LT(emit_null_i, filtered_emit_null->length());
-         ASSERT_TRUE(filtered_emit_null->IsNull(emit_null_i));
-         // this element was (null) filtered out; don't examine filtered_drop
-         --drop_i;
-         continue;
-       }
-       if (!filter->Value(values_i)) {
-         // this element was filtered out; don't examine filtered_emit_null
-         --emit_null_i;
-         --drop_i;
-         continue;
-       }
-       ASSERT_LT(emit_null_i, filtered_emit_null->length());
-       ASSERT_LT(drop_i, filtered_drop->length());
-       ASSERT_TRUE(
-           values->RangeEquals(values_i, values_i + 1, emit_null_i, filtered_emit_null));
-       ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, drop_i, filtered_drop));
-     }
-     ASSERT_EQ(emit_null_i, filtered_emit_null->length());
-     ASSERT_EQ(drop_i, filtered_drop->length());
-   }
-
-   FilterOptions emit_null_, drop_;
- };
-
- // Fixture for NullType values; its JSON overload parses values as null().
- class TestFilterKernelWithNull : public TestFilterKernel<NullType> {
-  protected:
-   void AssertFilter(const std::string& values, const std::string& filter,
-                     const std::string& expected) {
-     auto values_array = ArrayFromJSON(null(), values);
-     auto filter_array = ArrayFromJSON(boolean(), filter);
-     auto expected_array = ArrayFromJSON(null(), expected);
-     TestFilterKernel<NullType>::AssertFilter(std::move(values_array),
-                                              std::move(filter_array),
-                                              std::move(expected_array));
-   }
- };
-
- TEST_F(TestFilterKernelWithNull, FilterNull) {
-   // Empty input.
-   AssertFilter("[]", "[]", "[]");
-
-   // Each selected slot of a null-typed array comes through as null.
-   AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]");
-   AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]");
- }
-
- // Fixture for BooleanType values; its JSON overload parses all three
- // arguments as boolean().
- class TestFilterKernelWithBoolean : public TestFilterKernel<BooleanType> {
-  protected:
-   void AssertFilter(const std::string& values, const std::string& filter,
-                     const std::string& expected) {
-     const auto type = boolean();
-     TestFilterKernel<BooleanType>::AssertFilter(ArrayFromJSON(type, values),
-                                                 ArrayFromJSON(type, filter),
-                                                 ArrayFromJSON(type, expected));
-   }
- };
-
- TEST_F(TestFilterKernelWithBoolean, FilterBoolean) {
-   AssertFilter("[]", "[]", "[]");
-
-   // Unselected values (null or not) are dropped; a null filter slot yields a
-   // null output slot under EMIT_NULL.
-   AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]");
-   AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]");
-   AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]");
- }
-
- // Typed fixture instantiated for every type in NumericArrowTypes.
- template <typename ArrowType>
- class TestFilterKernelWithNumeric : public TestFilterKernel<ArrowType> {
-  protected:
-   // The DataType instance for ArrowType.
-   std::shared_ptr<DataType> type_singleton() {
-     using Traits = TypeTraits<ArrowType>;
-     return Traits::type_singleton();
-   }
- };
-
- TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes);
- TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) {
-   const auto type = this->type_singleton();
-
-   struct Case {
-     const char* values;
-     const char* filter;
-     const char* expected;
-   };
-   const std::vector<Case> cases = {
-       {"[]", "[]", "[]"},
-       // single element: dropped, kept, null-selected
-       {"[9]", "[0]", "[]"},
-       {"[9]", "[1]", "[9]"},
-       {"[9]", "[null]", "[null]"},
-       {"[null]", "[0]", "[]"},
-       {"[null]", "[1]", "[null]"},
-       {"[null]", "[null]", "[null]"},
-       // three elements
-       {"[7, 8, 9]", "[0, 1, 0]", "[8]"},
-       {"[7, 8, 9]", "[1, 0, 1]", "[7, 9]"},
-       {"[null, 8, 9]", "[0, 1, 0]", "[8]"},
-       {"[7, 8, 9]", "[null, 1, 0]", "[null, 8]"},
-       {"[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"},
-   };
-   for (const Case& c : cases) {
-     this->AssertFilter(type, c.values, c.filter, c.expected);
-   }
-
-   // A sliced filter must be read from its offset.
-   auto sliced_filter = ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3);
-   this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), sliced_filter,
-                      ArrayFromJSON(type, "[7, 9]"));
-
-   // A filter whose length differs from the values is rejected under either
-   // null selection behavior.
-   auto values = ArrayFromJSON(type, "[7, 8, 9]");
-   auto empty_filter = ArrayFromJSON(boolean(), "[]");
-   for (const auto& options : {this->emit_null_, this->drop_}) {
-     ASSERT_RAISES(Invalid, arrow::compute::Filter(values, empty_filter, options));
-   }
- }
-
- TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) {
-   auto rand = random::RandomArrayGenerator(kRandomSeed);
-   // Lengths 8, 16, ..., 512 (2^3 through 2^9).
-   for (int64_t length = 8; length <= 512; length *= 2) {
-     for (const double null_probability : {0.0, 0.01, 0.25, 1.0}) {
-       for (const double filter_probability : {0.0, 0.1, 0.5, 1.0}) {
-         auto values = rand.Numeric<TypeParam>(length, 0, 127, null_probability);
-         auto filter = rand.Boolean(length, filter_probability, null_probability);
-         this->ValidateFilter(values, filter);
-       }
-     }
-   }
- }
-
- // Function type of a binary predicate over two CType values.
- template <typename CType>
- using Comparator = auto(CType, CType) -> bool;
-
-template <typename CType>
-Comparator<CType>* GetComparator(CompareOperator op) {
- static Comparator<CType>* cmp[] = {
- // EQUAL
- [](CType l, CType r) { return l == r; },
- // NOT_EQUAL
- [](CType l, CType r) { return l != r; },
- // GREATER
- [](CType l, CType r) { return l > r; },
- // GREATER_EQUAL
- [](CType l, CType r) { return l >= r; },
- // LESS
- [](CType l, CType r) { return l < r; },
- // LESS_EQUAL
- [](CType l, CType r) { return l <= r; },
- };
- return cmp[op];
-}
-
- // Reference filter: builds an array of type T holding, in order, the elements
- // of data[0, length) for which `fn` returns true.
- template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
- std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, Fn&& fn) {
-   std::vector<CType> kept;
-   kept.reserve(length);
-   for (int64_t i = 0; i < length; ++i) {
-     if (fn(data[i])) {
-       kept.push_back(data[i]);
-     }
-   }
-   std::shared_ptr<Array> out;
-   ArrayFromVector<T, CType>(kept, &out);
-   return out;
- }
-
- // Reference filter keeping the elements e of data[0, length) for which
- // `e <op> val` holds.
- template <typename T, typename CType = typename TypeTraits<T>::CType>
- std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, CType val,
-                                         CompareOperator op) {
-   Comparator<CType>* const compare = GetComparator<CType>(op);
-   auto predicate = [compare, val](CType element) { return compare(element, val); };
-   return CompareAndFilter<T>(data, length, predicate);
- }
-
- // Reference filter keeping data[i] wherever `data[i] <op> other[i]` holds,
- // for i in [0, length).
- template <typename T, typename CType = typename TypeTraits<T>::CType>
- std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length,
-                                         const CType* other, CompareOperator op) {
-   Comparator<CType>* const compare = GetComparator<CType>(op);
-   // Relies on the predicate being invoked once per element, in order, so the
-   // running position pairs data[i] with other[i].
-   int64_t position = 0;
-   return CompareAndFilter<T>(data, length, [&](CType element) {
-     return compare(element, other[position++]);
-   });
- }
-
- TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
-   using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
-   using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
-   using CType = typename TypeTraits<TypeParam>::CType;
-
-   // The selection comes from comparing each element against the scalar 50.
-   const CType c_fifty = 50;
-   const auto fifty = std::make_shared<ScalarType>(c_fifty);
-
-   auto rand = random::RandomArrayGenerator(kRandomSeed);
-   for (int64_t length = 8; length <= 512; length *= 2) {
-     // TODO(bkietz) rewrite with some nulls
-     auto array = checked_pointer_cast<ArrayType>(
-         rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0));
-     for (const auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
-       ASSERT_OK_AND_ASSIGN(
-           Datum selection,
-           arrow::compute::Compare(array, Datum(fifty), CompareOptions(op)));
-       ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(array, selection, {}));
-       auto filtered_array = filtered.make_array();
-       ASSERT_OK(filtered_array->ValidateFull());
-       auto expected =
-           CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
-       ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
-     }
-   }
- }
-
- TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
-   using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
-
-   auto rand = random::RandomArrayGenerator(kRandomSeed);
-   // Both sides are generated without nulls; the reference below reads
-   // raw_values() and ignores validity.
-   auto random_array = [&](int64_t length) {
-     return checked_pointer_cast<ArrayType>(
-         rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
-   };
-
-   for (int64_t length = 8; length <= 512; length *= 2) {
-     auto lhs = random_array(length);
-     auto rhs = random_array(length);
-     for (const auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
-       ASSERT_OK_AND_ASSIGN(Datum selection,
-                            arrow::compute::Compare(lhs, rhs, CompareOptions(op)));
-       ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(lhs, selection, {}));
-       auto filtered_array = filtered.make_array();
-       ASSERT_OK(filtered_array->ValidateFull());
-       auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
-                                                   rhs->raw_values(), op);
-       ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
-     }
-   }
- }
-
- TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
-   using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
-   using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
-   using CType = typename TypeTraits<TypeParam>::CType;
-
-   // Selects elements strictly between 50 and 100. The selection is built from
-   // two scalar comparisons combined with And.
-   const CType low = 50;
-   const CType high = 100;
-   const auto low_scalar = std::make_shared<ScalarType>(low);
-   const auto high_scalar = std::make_shared<ScalarType>(high);
-
-   auto rand = random::RandomArrayGenerator(kRandomSeed);
-   for (int64_t length = 8; length <= 512; length *= 2) {
-     auto array = checked_pointer_cast<ArrayType>(
-         rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
-     ASSERT_OK_AND_ASSIGN(
-         Datum above_low,
-         arrow::compute::Compare(array, Datum(low_scalar), CompareOptions(GREATER)));
-     ASSERT_OK_AND_ASSIGN(
-         Datum below_high,
-         arrow::compute::Compare(array, Datum(high_scalar), CompareOptions(LESS)));
-     ASSERT_OK_AND_ASSIGN(Datum selection, arrow::compute::And(above_low, below_high));
-     ASSERT_OK_AND_ASSIGN(Datum filtered, arrow::compute::Filter(array, selection, {}));
-     auto filtered_array = filtered.make_array();
-     ASSERT_OK(filtered_array->ValidateFull());
-     auto expected = CompareAndFilter<TypeParam>(
-         array->raw_values(), array->length(),
-         [&](CType e) { return low < e && e < high; });
-     ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
-   }
- }
-
- using StringTypes =
-     ::testing::Types<BinaryType, StringType, LargeBinaryType, LargeStringType>;
-
- // Typed fixture over the binary/string types (regular and large variants).
- template <typename TypeClass>
- class TestFilterKernelWithString : public TestFilterKernel<TypeClass> {
-  protected:
-   std::shared_ptr<DataType> value_type() {
-     return TypeTraits<TypeClass>::type_singleton();
-   }
-
-   // JSON convenience overload: values and expected are parsed as value_type().
-   void AssertFilter(const std::string& values, const std::string& filter,
-                     const std::string& expected) {
-     const auto type = value_type();
-     TestFilterKernel<TypeClass>::AssertFilter(ArrayFromJSON(type, values),
-                                               ArrayFromJSON(boolean(), filter),
-                                               ArrayFromJSON(type, expected));
-   }
-
-   // Filters a dictionary<int8, value_type()> array whose indices are
-   // `dictionary_filter`. Expects a result with indices `expected_filter`
-   // over the same dictionary.
-   void AssertFilterDictionary(const std::string& dictionary_values,
-                               const std::string& dictionary_filter,
-                               const std::string& filter,
-                               const std::string& expected_filter) {
-     auto dict = ArrayFromJSON(value_type(), dictionary_values);
-     auto dict_type = dictionary(int8(), value_type());
-     ASSERT_OK_AND_ASSIGN(auto values,
-                          DictionaryArray::FromArrays(
-                              dict_type, ArrayFromJSON(int8(), dictionary_filter), dict));
-     ASSERT_OK_AND_ASSIGN(auto expected,
-                          DictionaryArray::FromArrays(
-                              dict_type, ArrayFromJSON(int8(), expected_filter), dict));
-     TestFilterKernel<TypeClass>::AssertFilter(values, ArrayFromJSON(boolean(), filter),
-                                               expected);
-   }
- };
-
- TYPED_TEST_SUITE(TestFilterKernelWithString, StringTypes);
-
- TYPED_TEST(TestFilterKernelWithString, FilterString) {
-   // {values, filter, expected}
-   const std::vector<std::vector<std::string>> cases = {
-       {R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])"},
-       {R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])"},
-       {R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])"},
-   };
-   for (const auto& c : cases) {
-     this->AssertFilter(c[0], c[1], c[2]);
-   }
- }
-
- TYPED_TEST(TestFilterKernelWithString, FilterDictionary) {
-   const std::string dict = R"(["a", "b", "c", "d", "e"])";
-   // {dictionary indices, filter, expected indices}
-   const std::vector<std::vector<std::string>> cases = {
-       {"[3, 4, 2]", "[0, 1, 0]", "[4]"},
-       {"[null, 4, 2]", "[0, 1, 0]", "[4]"},
-       {"[3, 4, 2]", "[null, 1, 0]", "[null, 4]"},
-   };
-   for (const auto& c : cases) {
-     this->AssertFilterDictionary(dict, c[0], c[1], c[2]);
-   }
- }
-
- class TestFilterKernelWithList : public TestFilterKernel<ListType> {};
-
- TEST_F(TestFilterKernelWithList, FilterListInt32) {
-   const auto type = list(int32());
-   const std::string list_json = "[[], [1,2], null, [3]]";
-   AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
-   AssertFilter(type, list_json, "[0, 1, 1, null]", "[[1,2], null, null]");
-   AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
-   AssertFilter(type, list_json, "[1, 0, 0, 1]", "[[], [3]]");
-   // Selecting every row reproduces the input.
-   AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
-   AssertFilter(type, list_json, "[0, 1, 0, 1]", "[[1,2], [3]]");
- }
-
- TEST_F(TestFilterKernelWithList, FilterListListInt32) {
-   const auto type = list(list(int32()));
-   const std::string list_json = R"([
- [],
- [[1], [2, null, 2], []],
- null,
- [[3, null], null]
- ])";
-   // No rows selected.
-   AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
-   // A null filter slot produces a null row under EMIT_NULL.
-   AssertFilter(type, list_json, "[0, 1, 1, null]", R"([
- [[1], [2, null, 2], []],
- null,
- null
- ])");
-   AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
-   AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([
- [],
- [[3, null], null]
- ])");
-   // Selecting every row reproduces the input.
-   AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
-   AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([
- [[1], [2, null, 2], []],
- [[3, null], null]
- ])");
- }
-
- class TestFilterKernelWithLargeList : public TestFilterKernel<LargeListType> {};
-
- TEST_F(TestFilterKernelWithLargeList, FilterListInt32) {
-   const auto type = large_list(int32());
-   const std::string list_json = "[[], [1,2], null, [3]]";
-   AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
-   AssertFilter(type, list_json, "[0, 1, 1, null]", "[[1,2], null, null]");
- }
-
- class TestFilterKernelWithFixedSizeList : public TestFilterKernel<FixedSizeListType> {};
-
- TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) {
-   const auto type = fixed_size_list(int32(), 3);
-   const std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
-   AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
-   AssertFilter(type, list_json, "[0, 1, 1, null]", "[[1, null, 3], [4, 5, 6], null]");
-   AssertFilter(type, list_json, "[0, 0, 1, null]", "[[4, 5, 6], null]");
-   AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
-   AssertFilter(type, list_json, "[0, 1, 0, 1]", "[[1, null, 3], [7, 8, null]]");
- }
-
- class TestFilterKernelWithMap : public TestFilterKernel<MapType> {};
-
- TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) {
-   const auto type = map(utf8(), int32());
-   const std::string map_json = R"([
- [["joe", 0], ["mark", null]],
- null,
- [["cap", 8]],
- []
- ])";
-   AssertFilter(type, map_json, "[0, 0, 0, 0]", "[]");
-   AssertFilter(type, map_json, "[0, 1, 1, null]", R"([
- null,
- [["cap", 8]],
- null
- ])");
-   AssertFilter(type, map_json, "[1, 1, 1, 1]", map_json);
-   AssertFilter(type, map_json, "[0, 1, 0, 1]", "[null, []]");
- }
-
- class TestFilterKernelWithStruct : public TestFilterKernel<StructType> {};
-
- TEST_F(TestFilterKernelWithStruct, FilterStruct) {
-   const auto struct_type = struct_({field("a", int32()), field("b", utf8())});
-   const std::string struct_json = R"([
- null,
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- {"a": 4, "b": "eh"}
- ])";
-   AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]");
-   AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- null
- ])");
-   AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json);
-   AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([
- null,
- {"a": 2, "b": "hello"}
- ])");
- }
-
- class TestFilterKernelWithUnion : public TestFilterKernel<UnionType> {};
-
- TEST_F(TestFilterKernelWithUnion, FilterUnion) {
-   // Each non-null element is written as [type_code, value], where type code 2
-   // is the int32 child "a" and type code 5 is the utf8 child "b".
-   const auto union_json = R"([
- null,
- [2, 222],
- [5, "hello"],
- [5, "eh"],
- null,
- [2, 111]
- ])";
-   for (const auto& make_union : UnionTypeFactories()) {
-     auto union_type = make_union({field("a", int32()), field("b", utf8())}, {2, 5});
-     AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0]", "[]");
-     AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1]", R"([
- [2, 222],
- [5, "hello"],
- null,
- [2, 111]
- ])");
-     AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0]", R"([
- null,
- [5, "hello"],
- null
- ])");
-     AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1]", union_json);
-   }
- }
-
- // Fixture for filtering whole record batches built from JSON.
- class TestFilterKernelWithRecordBatch : public TestFilterKernel<RecordBatch> {
-  public:
-   // Filters the batch described by `batch_json` and compares the result with
-   // the batch described by `expected_batch`.
-   void AssertFilter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
-                     const std::string& selection, FilterOptions options,
-                     const std::string& expected_batch) {
-     std::shared_ptr<RecordBatch> actual;
-     ASSERT_OK(this->Filter(schm, batch_json, selection, options, &actual));
-     ASSERT_OK(actual->ValidateFull());
-     const auto expected = RecordBatchFromJSON(schm, expected_batch);
-     ASSERT_BATCHES_EQUAL(*expected, *actual);
-   }
-
-   // Builds the batch and boolean selection from JSON and runs compute::Filter.
-   Status Filter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
-                 const std::string& selection, FilterOptions options,
-                 std::shared_ptr<RecordBatch>* out) {
-     auto batch = RecordBatchFromJSON(schm, batch_json);
-     auto selection_array = ArrayFromJSON(boolean(), selection);
-     ARROW_ASSIGN_OR_RAISE(Datum out_datum,
-                           arrow::compute::Filter(batch, selection_array, options));
-     *out = out_datum.record_batch();
-     return Status::OK();
-   }
- };
-
- TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) {
-   std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
-   const auto schm = schema(fields);
-
-   const std::string batch_json = R"([
- {"a": null, "b": "yo"},
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- {"a": 4, "b": "eh"}
- ])";
-
-   // Filters without nulls give the same result under EMIT_NULL and DROP.
-   for (const auto& options : {emit_null_, drop_}) {
-     AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]");
-     AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json);
-     AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([
- {"a": null, "b": "yo"},
- {"a": 2, "b": "hello"}
- ])");
-   }
-
-   // DROP omits the row selected by a null.
-   AssertFilter(schm, batch_json, "[0, 1, 1, null]", drop_, R"([
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"}
- ])");
-
-   // EMIT_NULL turns it into an all-null row.
-   AssertFilter(schm, batch_json, "[0, 1, 1, null]", emit_null_, R"([
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- {"a": null, "b": null}
- ])");
- }
-
- // Fixture for filtering chunked arrays with either a plain or a chunked filter.
- class TestFilterKernelWithChunkedArray : public TestFilterKernel<ChunkedArray> {
-  public:
-   // Filters chunked `values` with a single boolean array.
-   void AssertFilter(const std::shared_ptr<DataType>& type,
-                     const std::vector<std::string>& values, const std::string& filter,
-                     const std::vector<std::string>& expected) {
-     std::shared_ptr<ChunkedArray> actual;
-     ASSERT_OK(this->FilterWithArray(type, values, filter, &actual));
-     CheckResult(type, expected, *actual);
-   }
-
-   // Filters chunked `values` with a chunked boolean filter.
-   void AssertChunkedFilter(const std::shared_ptr<DataType>& type,
-                            const std::vector<std::string>& values,
-                            const std::vector<std::string>& filter,
-                            const std::vector<std::string>& expected) {
-     std::shared_ptr<ChunkedArray> actual;
-     ASSERT_OK(this->FilterWithChunkedArray(type, values, filter, &actual));
-     CheckResult(type, expected, *actual);
-   }
-
-   Status FilterWithArray(const std::shared_ptr<DataType>& type,
-                          const std::vector<std::string>& values,
-                          const std::string& filter, std::shared_ptr<ChunkedArray>* out) {
-     auto chunked_values = ChunkedArrayFromJSON(type, values);
-     auto selection = ArrayFromJSON(boolean(), filter);
-     ARROW_ASSIGN_OR_RAISE(Datum out_datum,
-                           arrow::compute::Filter(chunked_values, selection, {}));
-     *out = out_datum.chunked_array();
-     return Status::OK();
-   }
-
-   Status FilterWithChunkedArray(const std::shared_ptr<DataType>& type,
-                                 const std::vector<std::string>& values,
-                                 const std::vector<std::string>& filter,
-                                 std::shared_ptr<ChunkedArray>* out) {
-     auto chunked_values = ChunkedArrayFromJSON(type, values);
-     auto selection = ChunkedArrayFromJSON(boolean(), filter);
-     ARROW_ASSIGN_OR_RAISE(Datum out_datum,
-                           arrow::compute::Filter(chunked_values, selection, {}));
-     *out = out_datum.chunked_array();
-     return Status::OK();
-   }
-
-  private:
-   // Shared tail of both Assert* helpers: validate, then compare with expected.
-   static void CheckResult(const std::shared_ptr<DataType>& type,
-                           const std::vector<std::string>& expected,
-                           const ChunkedArray& actual) {
-     ASSERT_OK(actual.ValidateFull());
-     AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), actual);
-   }
- };
-
- TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) {
-   // Empty input, with a plain and with a chunked filter.
-   AssertFilter(int8(), {"[]"}, "[]", {});
-   AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {});
-
-   // The filter's chunk boundaries need not line up with those of the values.
-   const std::vector<std::string> values = {"[7]", "[8, 9]"};
-   AssertFilter(int8(), values, "[0, 1, 0]", {"[8]"});
-   AssertChunkedFilter(int8(), values, {"[0]", "[1, 0]"}, {"[8]"});
-   AssertChunkedFilter(int8(), values, {"[0, 1]", "[0]"}, {"[8]"});
-
-   // A filter longer than the values is rejected.
-   std::shared_ptr<ChunkedArray> arr;
-   ASSERT_RAISES(Invalid, FilterWithArray(int8(), values, "[0, 1, 0, 1, 1]", &arr));
-   ASSERT_RAISES(Invalid,
-                 FilterWithChunkedArray(int8(), values, {"[0, 1, 0]", "[1, 1]"}, &arr));
- }
-
- // Fixture for filtering tables with either a plain or a chunked filter.
- class TestFilterKernelWithTable : public TestFilterKernel<Table> {
-  public:
-   // Filters with a single boolean array and compares to `expected_table`.
-   void AssertFilter(const std::shared_ptr<Schema>& schm,
-                     const std::vector<std::string>& table_json, const std::string& filter,
-                     FilterOptions options,
-                     const std::vector<std::string>& expected_table) {
-     std::shared_ptr<Table> actual;
-     ASSERT_OK(this->FilterWithArray(schm, table_json, filter, options, &actual));
-     ASSERT_OK(actual->ValidateFull());
-     const auto expected = TableFromJSON(schm, expected_table);
-     ASSERT_TABLES_EQUAL(*expected, *actual);
-   }
-
-   // Filters with a chunked boolean filter and compares to `expected_table`
-   // without requiring the same chunk layout.
-   void AssertChunkedFilter(const std::shared_ptr<Schema>& schm,
-                            const std::vector<std::string>& table_json,
-                            const std::vector<std::string>& filter, FilterOptions options,
-                            const std::vector<std::string>& expected_table) {
-     std::shared_ptr<Table> actual;
-     ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, options, &actual));
-     ASSERT_OK(actual->ValidateFull());
-     const auto expected = TableFromJSON(schm, expected_table);
-     AssertTablesEqual(*expected, *actual, /*same_chunk_layout=*/false);
-   }
-
-   Status FilterWithArray(const std::shared_ptr<Schema>& schm,
-                          const std::vector<std::string>& values,
-                          const std::string& filter, FilterOptions options,
-                          std::shared_ptr<Table>* out) {
-     auto table = TableFromJSON(schm, values);
-     auto selection = ArrayFromJSON(boolean(), filter);
-     ARROW_ASSIGN_OR_RAISE(Datum out_datum,
-                           arrow::compute::Filter(table, selection, options));
-     *out = out_datum.table();
-     return Status::OK();
-   }
-
-   Status FilterWithChunkedArray(const std::shared_ptr<Schema>& schm,
-                                 const std::vector<std::string>& values,
-                                 const std::vector<std::string>& filter,
-                                 FilterOptions options, std::shared_ptr<Table>* out) {
-     auto table = TableFromJSON(schm, values);
-     auto selection = ChunkedArrayFromJSON(boolean(), filter);
-     ARROW_ASSIGN_OR_RAISE(Datum out_datum,
-                           arrow::compute::Filter(table, selection, options));
-     *out = out_datum.table();
-     return Status::OK();
-   }
- };
-
- TEST_F(TestFilterKernelWithTable, FilterTable) {
-   std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
-   const auto schm = schema(fields);
-
-   // Two chunks of two rows each.
-   const std::vector<std::string> table_json = {R"([
- {"a": null, "b": "yo"},
- {"a": 1, "b": ""}
- ])",
-                                                R"([
- {"a": 2, "b": "hello"},
- {"a": 4, "b": "eh"}
- ])"};
-
-   // All-false and all-true filters behave the same under EMIT_NULL and DROP.
-   for (const auto& options : {emit_null_, drop_}) {
-     AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {});
-     AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options, {});
-     AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json);
-     AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options, table_json);
-   }
-
-   // EMIT_NULL turns the row selected by a null into an all-null row.
-   const std::vector<std::string> expected_emit_null = {R"([
- {"a": 1, "b": ""}
- ])",
-                                                        R"([
- {"a": 2, "b": "hello"},
- {"a": null, "b": null}
- ])"};
-   AssertFilter(schm, table_json, "[0, 1, 1, null]", emit_null_, expected_emit_null);
-   AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, emit_null_,
-                       expected_emit_null);
-
-   // DROP omits that row.
-   const std::vector<std::string> expected_drop = {R"([{"a": 1, "b": ""}])",
-                                                   R"([{"a": 2, "b": "hello"}])"};
-   AssertFilter(schm, table_json, "[0, 1, 1, null]", drop_, expected_drop);
-   AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, drop_, expected_drop);
- }
-
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection.cc b/cpp/src/arrow/compute/kernels/vector_selection.cc
new file mode 100644
index 0000000..77ec028
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_selection.cc
@@ -0,0 +1,1826 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <cstring>
+#include <limits>
+
+#include "arrow/array/array_base.h"
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_dict.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/concatenate.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
+#include "arrow/extension_type.h"
+#include "arrow/record_batch.h"
+#include "arrow/result.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/int_util.h"
+
+namespace arrow {
+
+using internal::BinaryBitBlockCounter;
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::BitmapReader;
+using internal::CopyBitmap;
+using internal::CountSetBits;
+using internal::GetArrayView;
+using internal::IndexBoundsCheck;
+using internal::OptionalBitBlockCounter;
+using internal::OptionalBitIndexer;
+
+namespace compute {
+namespace internal {
+
+/// \brief Compute how many output slots filtering with the boolean `filter`
+/// will produce.
+///
+/// With FilterOptions::DROP a slot is counted only when its filter value is
+/// valid AND true. With EMIT_NULL, null filter slots are counted as well since
+/// each one produces a null in the output.
+int64_t GetFilterOutputSize(const ArrayData& filter,
+                            FilterOptions::NullSelectionBehavior null_selection) {
+  int64_t output_size = 0;
+  // The null count alone is not a safe signal for taking the null-aware path:
+  // it may be kUnknownNullCount (-1) even when no validity buffer exists, and
+  // dereferencing a missing buffers[0] would crash.
+  if (filter.buffers[0] != nullptr && filter.null_count.load() != 0) {
+    const uint8_t* filter_is_valid = filter.buffers[0]->data();
+    BinaryBitBlockCounter bit_counter(filter.buffers[1]->data(), filter.offset,
+                                      filter_is_valid, filter.offset, filter.length);
+    int64_t position = 0;
+    if (null_selection == FilterOptions::EMIT_NULL) {
+      // Count slots where the filter is true OR null
+      while (position < filter.length) {
+        BitBlockCount block = bit_counter.NextOrNotWord();
+        output_size += block.popcount;
+        position += block.length;
+      }
+    } else {
+      // Count slots where the filter is true AND non-null
+      while (position < filter.length) {
+        BitBlockCount block = bit_counter.NextAndWord();
+        output_size += block.popcount;
+        position += block.length;
+      }
+    }
+  } else {
+    // The filter has no nulls, so we can use CountSetBits
+    output_size = CountSetBits(filter.buffers[1]->data(), filter.offset, filter.length);
+  }
+  return output_size;
+}
+
+/// \brief Convert a boolean filter into "take indices" (a selection vector):
+/// the positions, relative to the start of the filter, that the filter
+/// selects, stored with the unsigned integer type `IndexType`.
+///
+/// With FilterOptions::DROP only positions whose filter slot is valid and true
+/// are emitted. With EMIT_NULL a null filter slot emits a null index. The
+/// caller must choose an IndexType whose c_type can represent filter.length.
+template <typename IndexType>
+Result<std::shared_ptr<ArrayData>> GetTakeIndicesImpl(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+    MemoryPool* memory_pool) {
+  using T = typename IndexType::c_type;
+  typename TypeTraits<IndexType>::BuilderType builder(memory_pool);
+
+  const uint8_t* filter_data = filter.buffers[1]->data();
+
+  // The position relative to the start of the filter
+  T position = 0;
+
+  // The current position taking the filter offset into account
+  int64_t position_with_offset = filter.offset;
+
+  // Only take the null-aware paths when a validity bitmap exists: the null
+  // count alone may be kUnknownNullCount (-1), and dereferencing a missing
+  // buffers[0] would crash.
+  if (filter.buffers[0] != nullptr && filter.null_count != 0) {
+    // The filter may have nulls, so we scan the validity bitmap and the filter
+    // data bitmap together, branching on the null selection type.
+    const uint8_t* filter_is_valid = filter.buffers[0]->data();
+
+    // Yields popcounts of (filter_data AND filter_is_valid) for DROP, or of
+    // (filter_data OR NOT filter_is_valid) for EMIT_NULL
+    BinaryBitBlockCounter filter_counter(filter_data, filter.offset, filter_is_valid,
+                                         filter.offset, filter.length);
+    if (null_selection == FilterOptions::DROP) {
+      while (position < filter.length) {
+        BitBlockCount and_block = filter_counter.NextAndWord();
+        RETURN_NOT_OK(builder.Reserve(and_block.popcount));
+        if (and_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += and_block.length;
+        } else if (!and_block.NoneSet()) {
+          // Some of the values are false or null
+          for (int64_t i = 0; i < and_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset) &&
+                BitUtil::GetBit(filter_data, position_with_offset)) {
+              builder.UnsafeAppend(position);
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        } else {
+          // Nothing selected in this block
+          position += and_block.length;
+          position_with_offset += and_block.length;
+        }
+      }
+    } else {
+      BitBlockCounter is_valid_counter(filter_is_valid, filter.offset, filter.length);
+      while (position < filter.length) {
+        // true OR NOT valid
+        BitBlockCount or_not_block = filter_counter.NextOrNotWord();
+        RETURN_NOT_OK(builder.Reserve(or_not_block.popcount));
+
+        // If the values are all valid and the or_not_block is full, then we
+        // can infer that all the values are true and skip the bit checking
+        BitBlockCount is_valid_block = is_valid_counter.NextWord();
+
+        if (or_not_block.AllSet() && is_valid_block.AllSet()) {
+          // All the values are selected and non-null
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            builder.UnsafeAppend(position++);
+          }
+          position_with_offset += or_not_block.length;
+        } else {
+          // Some of the values are false or null
+          for (int64_t i = 0; i < or_not_block.length; ++i) {
+            if (BitUtil::GetBit(filter_is_valid, position_with_offset)) {
+              if (BitUtil::GetBit(filter_data, position_with_offset)) {
+                builder.UnsafeAppend(position);
+              }
+            } else {
+              // Null slot, so append a null
+              builder.UnsafeAppendNull();
+            }
+            ++position;
+            ++position_with_offset;
+          }
+        }
+      }
+    }
+  } else {
+    // The filter has no nulls, so we need only look for true values
+    BitBlockCounter data_counter(filter_data, filter.offset, filter.length);
+    BitBlockCount current_block = data_counter.NextWord();
+    while (position < filter.length) {
+      if (current_block.AllSet()) {
+        int64_t run_length = 0;
+
+        // If we've found a all-true block, then we scan forward until we find
+        // a block that has some false values (or we reach the end). The final
+        // zero-length block is "all set", hence the length check.
+        while (current_block.length > 0 && current_block.AllSet()) {
+          run_length += current_block.length;
+          current_block = data_counter.NextWord();
+        }
+
+        // Append the consecutive run of indices
+        RETURN_NOT_OK(builder.Reserve(run_length));
+        for (int64_t i = 0; i < run_length; ++i) {
+          builder.UnsafeAppend(position++);
+        }
+        position_with_offset += run_length;
+        // The current_block already computed, so advance to next loop
+        // iteration.
+        continue;
+      } else if (!current_block.NoneSet()) {
+        // Must do bitchecking on the current block
+        RETURN_NOT_OK(builder.Reserve(current_block.popcount));
+        for (int64_t i = 0; i < current_block.length; ++i) {
+          if (BitUtil::GetBit(filter_data, position_with_offset)) {
+            builder.UnsafeAppend(position);
+          }
+          ++position;
+          ++position_with_offset;
+        }
+      } else {
+        position += current_block.length;
+        position_with_offset += current_block.length;
+      }
+      current_block = data_counter.NextWord();
+    }
+  }
+  std::shared_ptr<ArrayData> result;
+  RETURN_NOT_OK(builder.FinishInternal(&result));
+  return result;
+}
+
+/// \brief Build take indices from a boolean filter, using the narrowest
+/// unsigned index type (uint16 or uint32) able to address every filter slot.
+///
+/// Returns NotImplemented when the filter is longer than UINT32_MAX.
+Result<std::shared_ptr<ArrayData>> GetTakeIndices(
+    const ArrayData& filter, FilterOptions::NullSelectionBehavior null_selection,
+    MemoryPool* memory_pool) {
+  DCHECK_EQ(filter.type->id(), Type::BOOL);
+  constexpr int64_t kMaxUInt16Length = std::numeric_limits<uint16_t>::max();
+  constexpr int64_t kMaxUInt32Length = std::numeric_limits<uint32_t>::max();
+
+  if (filter.length > kMaxUInt32Length) {
+    // Arrays over 4 billion elements, not especially likely.
+    return Status::NotImplemented(
+        "Filter length exceeds UINT32_MAX, "
+        "consider a different strategy for selecting elements");
+  }
+  if (filter.length > kMaxUInt16Length) {
+    return GetTakeIndicesImpl<UInt32Type>(filter, null_selection, memory_pool);
+  }
+  return GetTakeIndicesImpl<UInt16Type>(filter, null_selection, memory_pool);
+}
+
+namespace {
+
+using FilterState = OptionsWrapper<FilterOptions>;
+using TakeState = OptionsWrapper<TakeOptions>;
+
+// Size the output ArrayData for `length` fixed-width slots of `bit_width` bits
+// each. The data buffer (buffers[1]) is always allocated; the validity bitmap
+// (buffers[0]) is only allocated when `allocate_validity` is true.
+Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width,
+                       bool allocate_validity, Datum* out) {
+  ArrayData* out_arr = out->mutable_array();
+  out_arr->length = length;
+  out_arr->buffers.resize(2);
+
+  if (allocate_validity) {
+    ARROW_ASSIGN_OR_RAISE(out_arr->buffers[0], ctx->AllocateBitmap(length));
+  }
+
+  // Booleans are bit-packed; other widths are assumed to be multiples of 8.
+  const bool is_bit_packed = (bit_width == 1);
+  if (is_bit_packed) {
+    ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->AllocateBitmap(length));
+  } else {
+    const int64_t data_bytes = length * bit_width / 8;
+    ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->Allocate(data_bytes));
+  }
+  return Status::OK();
+}
+
+// ----------------------------------------------------------------------
+// Implement optimized take for primitive types from boolean to 1/2/4/8-byte
+// C-type based types. Use common implementation for every byte width and only
+// generate code for unsigned integer indices, since after boundschecking to
+// check for negative numbers in the indices we can safely reinterpret_cast
+// signed integers as unsigned.
+
+/// \brief The Take implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we
+/// only generate one take function for each byte width.
+///
+/// This function assumes that the indices have been boundschecked.
+///
+/// Exec writes into the preallocated output of `out_datum`: the validity
+/// bitmap in buffers[0] (dereferenced unconditionally, so it must be
+/// allocated) and the values in buffers[1]. Null output slots receive a
+/// zero-initialized value, and the output null_count is set at the end.
+template <typename IndexCType, typename ValueCType>
+struct PrimitiveTakeImpl {
+ static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+ Datum* out_datum) {
+ // NOTE(review): values_data/indices_data are indexed by position without
+ // adding the array offset (only the bitmaps use the offsets), so
+ // PrimitiveArg::data is presumably already offset-adjusted; confirm in
+ // GetPrimitiveArg.
+ auto values_data = reinterpret_cast<const ValueCType*>(values.data);
+ auto values_is_valid = values.is_valid;
+ auto values_offset = values.offset;
+
+ auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+ auto indices_is_valid = indices.is_valid;
+ auto indices_offset = indices.offset;
+
+ ArrayData* out_arr = out_datum->mutable_array();
+ auto out = out_arr->GetMutableValues<ValueCType>(1);
+ auto out_is_valid = out_arr->buffers[0]->mutable_data();
+ auto out_offset = out_arr->offset;
+
+ // If either the values or indices have nulls, we preemptively zero out the
+ // out validity bitmap so that we don't have to use ClearBit in each
+ // iteration for nulls.
+ if (values.null_count != 0 || indices.null_count != 0) {
+ BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+ }
+
+ OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset,
+ indices.length);
+ int64_t position = 0;
+ // Number of non-null output slots; used to derive null_count at the end
+ int64_t valid_count = 0;
+ while (position < indices.length) {
+ // block.popcount is the number of non-null indices in this block
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ if (values.null_count == 0) {
+ // Values are never null, so things are easier
+ valid_count += block.popcount;
+ if (block.popcount == block.length) {
+ // Fastest path: neither values nor index nulls
+ BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ out[position] = values_data[indices_data[position]];
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some indices but not all are null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ out[position] = values_data[indices_data[position]];
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else {
+ // Every index in this block is null: zero the output values
+ memset(out + position, 0, sizeof(ValueCType) * block.length);
+ position += block.length;
+ }
+ } else {
+ // Values have nulls, so we must do random access into the values bitmap
+ if (block.popcount == block.length) {
+ // Faster path: indices are not null but values may be
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ out[position] = values_data[indices_data[position]];
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null. Since we are doing
+ // random access in general we have to check the value nullness one by
+ // one.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position) &&
+ BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // index is not null && value is not null
+ out[position] = values_data[indices_data[position]];
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ } else {
+ out[position] = ValueCType{};
+ }
+ ++position;
+ }
+ } else {
+ // Every index in this block is null: zero the output values
+ memset(out + position, 0, sizeof(ValueCType) * block.length);
+ position += block.length;
+ }
+ }
+ }
+ out_arr->null_count = out_arr->length - valid_count;
+ }
+};
+
+/// \brief Take implementation for boolean (bit-packed) values. It follows the
+/// same block structure as PrimitiveTakeImpl but reads and writes individual
+/// bits. The output data bits are cleared up front, so output slots that are
+/// null read as false. Assumes the indices have been boundschecked.
+template <typename IndexCType>
+struct BooleanTakeImpl {
+ static void Exec(const PrimitiveArg& values, const PrimitiveArg& indices,
+ Datum* out_datum) {
+ const uint8_t* values_data = values.data;
+ auto values_is_valid = values.is_valid;
+ auto values_offset = values.offset;
+
+ auto indices_data = reinterpret_cast<const IndexCType*>(indices.data);
+ auto indices_is_valid = indices.is_valid;
+ auto indices_offset = indices.offset;
+
+ ArrayData* out_arr = out_datum->mutable_array();
+ auto out = out_arr->buffers[1]->mutable_data();
+ auto out_is_valid = out_arr->buffers[0]->mutable_data();
+ auto out_offset = out_arr->offset;
+
+ // If either the values or indices have nulls, we preemptively zero out the
+ // out validity bitmap so that we don't have to use ClearBit in each
+ // iteration for nulls.
+ if (values.null_count != 0 || indices.null_count != 0) {
+ BitUtil::SetBitsTo(out_is_valid, out_offset, indices.length, false);
+ }
+ // Avoid uninitialized data in values array
+ BitUtil::SetBitsTo(out, out_offset, indices.length, false);
+
+ // Copy the value bit at `index` into output slot `loc`
+ auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
+ BitUtil::SetBitTo(out, out_offset + loc,
+ BitUtil::GetBit(values_data, values_offset + index));
+ };
+
+ OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset,
+ indices.length);
+ int64_t position = 0;
+ // Number of non-null output slots; used to derive null_count at the end
+ int64_t valid_count = 0;
+ while (position < indices.length) {
+ // block.popcount is the number of non-null indices in this block
+ BitBlockCount block = indices_bit_counter.NextBlock();
+ if (values.null_count == 0) {
+ // Values are never null, so things are easier
+ valid_count += block.popcount;
+ if (block.popcount == block.length) {
+ // Fastest path: neither values nor index nulls
+ BitUtil::SetBitsTo(out_is_valid, out_offset + position, block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ PlaceDataBit(position, indices_data[position]);
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ PlaceDataBit(position, indices_data[position]);
+ }
+ ++position;
+ }
+ } else {
+ // Every index in this block is null; the data bits were already zeroed
+ // above, so just skip the block
+ position += block.length;
+ }
+ } else {
+ // Values have nulls, so we must do random access into the values bitmap
+ if (block.popcount == block.length) {
+ // Faster path: indices are not null but values may be
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ PlaceDataBit(position, indices_data[position]);
+ ++valid_count;
+ }
+ ++position;
+ }
+ } else if (block.popcount > 0) {
+ // Slow path: some but not all indices are null. Since we are doing
+ // random access in general we have to check the value nullness one by
+ // one.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if (BitUtil::GetBit(indices_is_valid, indices_offset + position)) {
+ // index is not null
+ if (BitUtil::GetBit(values_is_valid,
+ values_offset + indices_data[position])) {
+ // value is not null
+ PlaceDataBit(position, indices_data[position]);
+ BitUtil::SetBit(out_is_valid, out_offset + position);
+ ++valid_count;
+ }
+ }
+ ++position;
+ }
+ } else {
+ // Every index in this block is null; the data bits were already zeroed
+ // above, so just skip the block
+ position += block.length;
+ }
+ }
+ }
+ out_arr->null_count = out_arr->length - valid_count;
+ }
+};
+
+/// \brief Run `TakeImpl<IndexCType, Args...>::Exec`, choosing IndexCType from
+/// the physical bit width of the indices.
+///
+/// Boundschecking is assumed to have happened already at a higher level, so
+/// every index is non-negative. Signed indices can therefore be reinterpreted
+/// as the unsigned type of the same width, which avoids generating separate
+/// code for signed and unsigned index types.
+template <template <typename...> class TakeImpl, typename... Args>
+void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices,
+                       Datum* out) {
+  const auto index_width = indices.bit_width;
+  if (index_width == 8) {
+    TakeImpl<uint8_t, Args...>::Exec(values, indices, out);
+  } else if (index_width == 16) {
+    TakeImpl<uint16_t, Args...>::Exec(values, indices, out);
+  } else if (index_width == 32) {
+    TakeImpl<uint32_t, Args...>::Exec(values, indices, out);
+  } else if (index_width == 64) {
+    TakeImpl<uint64_t, Args...>::Exec(values, indices, out);
+  } else {
+    DCHECK(false) << "Invalid indices byte width";
+  }
+}
+
+// Take kernel for fixed-width types (including boolean). It boundschecks the
+// indices when requested, preallocates the output including a validity
+// bitmap, then dispatches on the physical bit width of the values.
+void PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  if (state.options.boundscheck) {
+    KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
+  }
+
+  const PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+  const PrimitiveArg indices = GetPrimitiveArg(*batch[1].array());
+
+  // TODO: When neither values nor indices contain nulls, we can skip
+  // allocating the validity bitmap altogether and save time and space. A
+  // streamlined PrimitiveTakeImpl would need to be written that skips all
+  // interactions with the output validity bitmap, though.
+  KERNEL_RETURN_IF_ERROR(ctx, PreallocateData(ctx, indices.length, values.bit_width,
+                                              /*allocate_validity=*/true, out));
+
+  const auto value_width = values.bit_width;
+  if (value_width == 1) {
+    TakeIndexDispatch<BooleanTakeImpl>(values, indices, out);
+  } else if (value_width == 8) {
+    TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(values, indices, out);
+  } else if (value_width == 16) {
+    TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(values, indices, out);
+  } else if (value_width == 32) {
+    TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(values, indices, out);
+  } else if (value_width == 64) {
+    TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(values, indices, out);
+  } else {
+    DCHECK(false) << "Invalid values byte width";
+  }
+}
+
+// ----------------------------------------------------------------------
+// Optimized and streamlined filter for primitive types
+
+// Word-at-a-time scanner over a boolean filter for the DROP null-selection
+// behavior. With a validity bitmap it reports popcounts of (data AND valid);
+// without one it reports popcounts of the data bits alone.
+class DropNullCounter {
+ public:
+  // `validity` may be null, in which case every filter slot counts as valid
+  DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset,
+                  int64_t length)
+      : data_counter_(data, offset, length),
+        data_and_validity_counter_(data, offset, validity, offset, length),
+        has_validity_(validity != nullptr) {}
+
+  BitBlockCount NextBlock() {
+    if (!has_validity_) {
+      return data_counter_.NextWord();
+    }
+    // filter is true AND not null
+    return data_and_validity_counter_.NextAndWord();
+  }
+
+ private:
+  // Used when the filter has no validity bitmap
+  BitBlockCounter data_counter_;
+  // Used when the filter has a validity bitmap
+  BinaryBitBlockCounter data_and_validity_counter_;
+  const bool has_validity_;
+};
+
+/// \brief The Filter implementation for primitive (fixed-width) types does not
+/// use the logical Arrow type but rather the physical C type. This way we only
+/// generate one filter function for each byte width. We use the same
+/// implementation here for boolean and fixed-byte-size inputs with some
+/// template specialization.
+///
+/// The output ArrayData must already be preallocated for the filtered length.
+/// Its validity bitmap (buffers[0]) may be absent when neither the values nor
+/// the filter contain nulls; in that case Exec() takes the ExecNonNull() path,
+/// which never touches the output validity bitmap.
+template <typename ArrowType>
+class PrimitiveFilterImpl {
+ public:
+  using T = typename std::conditional<std::is_same<ArrowType, BooleanType>::value,
+                                      uint8_t, typename ArrowType::c_type>::type;
+
+  PrimitiveFilterImpl(const PrimitiveArg& values, const PrimitiveArg& filter,
+                      FilterOptions::NullSelectionBehavior null_selection,
+                      Datum* out_datum)
+      : values_is_valid_(values.is_valid),
+        values_data_(reinterpret_cast<const T*>(values.data)),
+        values_null_count_(values.null_count),
+        values_offset_(values.offset),
+        values_length_(values.length),
+        filter_is_valid_(filter.is_valid),
+        filter_data_(filter.data),
+        filter_null_count_(filter.null_count),
+        filter_offset_(filter.offset),
+        null_selection_(null_selection) {
+    ArrayData* out_arr = out_datum->mutable_array();
+    if (out_arr->buffers[0] != nullptr) {
+      // May not be allocated if neither filter nor values contains nulls; in
+      // that case out_is_valid_ keeps its nullptr default instead of garbage.
+      out_is_valid_ = out_arr->buffers[0]->mutable_data();
+    }
+    out_data_ = reinterpret_cast<T*>(out_arr->buffers[1]->mutable_data());
+    out_offset_ = out_arr->offset;
+    out_length_ = out_arr->length;
+    out_position_ = 0;
+  }
+
+  /// \brief Fast filter for when neither the values nor the filter contain
+  /// nulls (see Exec): only the filter data bits are examined and the output
+  /// validity bitmap is never written.
+  void ExecNonNull() {
+    BitBlockCounter filter_counter(filter_data_, filter_offset_, values_length_);
+
+    int64_t in_position = 0;
+    BitBlockCount current_block = filter_counter.NextWord();
+    while (in_position < values_length_) {
+      if (current_block.AllSet()) {
+        int64_t run_length = 0;
+        // If we've found a all-true block, then we scan forward until we find
+        // a block that has some false values (or we reach the end)
+        while (current_block.length > 0 && current_block.AllSet()) {
+          run_length += current_block.length;
+          current_block = filter_counter.NextWord();
+        }
+        WriteValueSegment(in_position, run_length);
+        in_position += run_length;
+      } else if (current_block.NoneSet()) {
+        // Nothing selected
+        in_position += current_block.length;
+        current_block = filter_counter.NextWord();
+      } else {
+        // Some values selected
+        for (int64_t i = 0; i < current_block.length; ++i) {
+          if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+            WriteValue(in_position);
+          }
+          ++in_position;
+        }
+        current_block = filter_counter.NextWord();
+      }
+    }
+  }
+
+  /// \brief Run the filter, writing the selected values and (when the values
+  /// or the filter have nulls) the output validity bitmap.
+  void Exec() {
+    if (filter_null_count_ == 0 && values_null_count_ == 0) {
+      return ExecNonNull();
+    }
+
+    // Bit counters used for both null_selection behaviors
+    DropNullCounter drop_null_counter(filter_is_valid_, filter_data_, filter_offset_,
+                                      values_length_);
+    OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_,
+                                         values_length_);
+    OptionalBitBlockCounter filter_valid_counter(filter_is_valid_, filter_offset_,
+                                                 values_length_);
+
+    auto WriteNotNull = [&](int64_t index) {
+      BitUtil::SetBit(out_is_valid_, out_offset_ + out_position_);
+      // Increments out_position_
+      WriteValue(index);
+    };
+
+    auto WriteMaybeNull = [&](int64_t index) {
+      BitUtil::SetBitTo(out_is_valid_, out_offset_ + out_position_,
+                        BitUtil::GetBit(values_is_valid_, values_offset_ + index));
+      // Increments out_position_
+      WriteValue(index);
+    };
+
+    int64_t in_position = 0;
+    while (in_position < values_length_) {
+      BitBlockCount filter_block = drop_null_counter.NextBlock();
+      BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+      BitBlockCount data_block = data_counter.NextWord();
+      if (filter_block.AllSet() && data_block.AllSet()) {
+        // Fastest path: all values in block are included and not null
+        BitUtil::SetBitsTo(out_is_valid_, out_offset_ + out_position_,
+                           filter_block.length, true);
+        WriteValueSegment(in_position, filter_block.length);
+        in_position += filter_block.length;
+      } else if (filter_block.AllSet()) {
+        // Faster: all values are selected, but some values are null
+        // Batch copy bits from values validity bitmap to output validity bitmap
+        CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length,
+                   out_is_valid_, out_offset_ + out_position_);
+        WriteValueSegment(in_position, filter_block.length);
+        in_position += filter_block.length;
+      } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) {
+        // For this exceedingly common case in low-selectivity filters we can
+        // skip further analysis of the data and move on to the next block.
+        in_position += filter_block.length;
+      } else {
+        // Some filter values are false or null
+        if (data_block.AllSet()) {
+          // No values are null
+          if (filter_valid_block.AllSet()) {
+            // Filter is non-null but some values are false
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                WriteNotNull(in_position);
+              }
+              ++in_position;
+            }
+          } else if (null_selection_ == FilterOptions::DROP) {
+            // If any values are selected, they ARE NOT null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                WriteNotNull(in_position);
+              }
+              ++in_position;
+            }
+          } else {  // null_selection == FilterOptions::EMIT_NULL
+            // Data values in this block are not null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              const bool is_valid =
+                  BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position);
+              if (is_valid &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                // Filter slot is non-null and set
+                WriteNotNull(in_position);
+              } else if (!is_valid) {
+                // Filter slot is null, so we have a null in the output
+                BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+                WriteNull();
+              }
+              ++in_position;
+            }
+          }
+        } else {  // !data_block.AllSet()
+          // Some values are null
+          if (filter_valid_block.AllSet()) {
+            // Filter is non-null but some values are false
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                WriteMaybeNull(in_position);
+              }
+              ++in_position;
+            }
+          } else if (null_selection_ == FilterOptions::DROP) {
+            // Null filter slots are dropped, but selected values may be null
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              if (BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position) &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                WriteMaybeNull(in_position);
+              }
+              ++in_position;
+            }
+          } else {  // null_selection == FilterOptions::EMIT_NULL
+            // Some data values in this block may be null, and null filter
+            // slots also produce nulls in the output
+            for (int64_t i = 0; i < filter_block.length; ++i) {
+              const bool is_valid =
+                  BitUtil::GetBit(filter_is_valid_, filter_offset_ + in_position);
+              if (is_valid &&
+                  BitUtil::GetBit(filter_data_, filter_offset_ + in_position)) {
+                // Filter slot is non-null and set
+                WriteMaybeNull(in_position);
+              } else if (!is_valid) {
+                // Filter slot is null, so we have a null in the output
+                BitUtil::ClearBit(out_is_valid_, out_offset_ + out_position_);
+                WriteNull();
+              }
+              ++in_position;
+            }
+          }
+        }
+      }  // !filter_block.AllSet()
+    }  // while(in_position < values_length_)
+  }
+
+  // Write the next out_position given the selected in_position for the input
+  // data and advance out_position
+  void WriteValue(int64_t in_position) {
+    out_data_[out_position_++] = values_data_[in_position];
+  }
+
+  // Copy `length` consecutive input values starting at `in_start` and advance
+  // out_position
+  void WriteValueSegment(int64_t in_start, int64_t length) {
+    std::memcpy(out_data_ + out_position_, values_data_ + in_start, length * sizeof(T));
+    out_position_ += length;
+  }
+
+  void WriteNull() {
+    // Zero the memory
+    out_data_[out_position_++] = T{};
+  }
+
+ private:
+  const uint8_t* values_is_valid_;
+  const T* values_data_;
+  int64_t values_null_count_;
+  int64_t values_offset_;
+  int64_t values_length_;
+  const uint8_t* filter_is_valid_;
+  const uint8_t* filter_data_;
+  int64_t filter_null_count_;
+  int64_t filter_offset_;
+  FilterOptions::NullSelectionBehavior null_selection_;
+  // Output validity bitmap; stays nullptr when the output has no buffers[0]
+  uint8_t* out_is_valid_ = nullptr;
+  T* out_data_;
+  int64_t out_offset_;
+  int64_t out_length_;
+  int64_t out_position_;
+};
+
+// Boolean specializations: both input and output data are bit-packed, so
+// single values and runs are moved with the bitmap helpers, honoring the
+// values and output bit offsets.
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValue(int64_t in_position) {
+  const bool bit = BitUtil::GetBit(values_data_, values_offset_ + in_position);
+  BitUtil::SetBitTo(out_data_, out_offset_ + out_position_, bit);
+  ++out_position_;
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteValueSegment(int64_t in_start,
+                                                                int64_t length) {
+  const int64_t dest_offset = out_offset_ + out_position_;
+  CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, dest_offset);
+  out_position_ += length;
+}
+
+template <>
+inline void PrimitiveFilterImpl<BooleanType>::WriteNull() {
+  // A null slot gets a cleared data bit
+  const int64_t dest_bit = out_offset_ + out_position_;
+  BitUtil::ClearBit(out_data_, dest_bit);
+  ++out_position_;
+}
+
+// Filter kernel for fixed-width types (including boolean). It sizes the
+// output from the filter, sets the output null count when it is known up
+// front, preallocates buffers, and dispatches on the values' bit width.
+void PrimitiveFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const FilterState&>(*ctx->state());
+  const FilterOptions::NullSelectionBehavior null_selection =
+      state.options.null_selection_behavior;
+  PrimitiveArg values = GetPrimitiveArg(*batch[0].array());
+  PrimitiveArg filter = GetPrimitiveArg(*batch[1].array());
+
+  const int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection);
+
+  const bool values_have_nulls = values.null_count != 0;
+  const bool filter_has_nulls = filter.null_count != 0;
+
+  // The output null count is only known ahead of time (as zero) when the
+  // values have no nulls and the filter cannot introduce any: either nulls in
+  // the filter are dropped, or the filter has none.
+  const bool output_is_null_free =
+      !values_have_nulls &&
+      (null_selection == FilterOptions::DROP || !filter_has_nulls);
+  out->mutable_array()->null_count = output_is_null_free ? 0 : kUnknownNullCount;
+
+  // Skip the validity bitmap when nothing may be null; the ExecNonNull path
+  // used in that case never writes one.
+  const bool allocate_validity = values_have_nulls || filter_has_nulls;
+  KERNEL_RETURN_IF_ERROR(
+      ctx, PreallocateData(ctx, output_length, values.bit_width, allocate_validity, out));
+
+  switch (values.bit_width) {
+    case 1:
+      PrimitiveFilterImpl<BooleanType>(values, filter, null_selection, out).Exec();
+      break;
+    case 8:
+      PrimitiveFilterImpl<UInt8Type>(values, filter, null_selection, out).Exec();
+      break;
+    case 16:
+      PrimitiveFilterImpl<UInt16Type>(values, filter, null_selection, out).Exec();
+      break;
+    case 32:
+      PrimitiveFilterImpl<UInt32Type>(values, filter, null_selection, out).Exec();
+      break;
+    case 64:
+      PrimitiveFilterImpl<UInt64Type>(values, filter, null_selection, out).Exec();
+      break;
+    default:
+      DCHECK(false) << "Invalid values bit width";
+      break;
+  }
+}
+
+// ----------------------------------------------------------------------
+// Null take and filter
+
+// Take kernel for the null type. After optional boundschecking, the output is
+// an all-null array with one slot per take index.
+void NullTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  if (state.options.boundscheck) {
+    KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
+  }
+  // The output length is the number of indices (batch[1]). batch.length is not
+  // guaranteed to equal it, because values and indices generally differ in
+  // length.
+  out->value = std::make_shared<NullArray>(batch[1].length())->data();
+}
+
+// Filter kernel for the null type. The output is an all-null array whose
+// length is the number of slots the filter selects.
+void NullFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const FilterState&>(*ctx->state());
+  const int64_t selected_count =
+      GetFilterOutputSize(*batch[1].array(), state.options.null_selection_behavior);
+  auto all_nulls = std::make_shared<NullArray>(selected_count);
+  out->value = all_nulls->data();
+}
+
+// ----------------------------------------------------------------------
+// Dictionary take and filter
+
+// Take on a dictionary array: apply Take to the dictionary indices and pair
+// the result with the original, unchanged dictionary.
+void DictionaryTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const TakeState&>(*ctx->state());
+  DictionaryArray dict_values(batch[0].array());
+  Datum taken_indices;
+  KERNEL_RETURN_IF_ERROR(ctx, Take(Datum(dict_values.indices()), batch[1], state.options,
+                                   ctx->exec_context())
+                                  .Value(&taken_indices));
+  DictionaryArray result(dict_values.type(), taken_indices.make_array(),
+                         dict_values.dictionary());
+  out->value = result.data();
+}
+
+// Filter on a dictionary array: apply Filter to the dictionary indices and
+// pair the result with the original, unchanged dictionary.
+void DictionaryFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& state = checked_cast<const FilterState&>(*ctx->state());
+  DictionaryArray dict_values(batch[0].array());
+  Datum filtered_indices;
+  KERNEL_RETURN_IF_ERROR(ctx, Filter(Datum(dict_values.indices()), batch[1].array(),
+                                     state.options, ctx->exec_context())
+                                  .Value(&filtered_indices));
+  DictionaryArray result(dict_values.type(), filtered_indices.make_array(),
+                         dict_values.dictionary());
+  out->value = result.data();
+}
+
+// ----------------------------------------------------------------------
+// Extension take and filter
+
+ // Take on an extension array: take from the storage array, then rewrap the
+ // result in the original extension type.
+ void ExtensionTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+   const auto& options = checked_cast<const TakeState&>(*ctx->state()).options;
+   ExtensionArray ext_array(batch[0].array());
+
+   Datum taken_storage;
+   KERNEL_RETURN_IF_ERROR(ctx, Take(Datum(ext_array.storage()), batch[1], options,
+                                    ctx->exec_context())
+                                   .Value(&taken_storage));
+   out->value = ExtensionArray(ext_array.type(), taken_storage.make_array()).data();
+ }
+
+ // Filter on an extension array: filter the storage array, then rewrap the
+ // result in the original extension type.
+ void ExtensionFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+   const auto& options = checked_cast<const FilterState&>(*ctx->state()).options;
+   ExtensionArray ext_array(batch[0].array());
+
+   Datum filtered_storage;
+   KERNEL_RETURN_IF_ERROR(ctx, Filter(Datum(ext_array.storage()), batch[1].array(),
+                                      options, ctx->exec_context())
+                                   .Value(&filtered_storage));
+   out->value = ExtensionArray(ext_array.type(), filtered_storage.make_array()).data();
+ }
+
+// ----------------------------------------------------------------------
+// Implement take for other data types where there is less performance
+// sensitivity by visiting the selected indices.
+
+// Use CRTP to dispatch to type-specific processing of take indices for each
+// unsigned integer type.
+template <typename Impl, typename Type>
+struct Selection {
+ using ValuesArrayType = typename TypeTraits<Type>::ArrayType;
+
+ // Forwards the generic value visitors to the take index visitor template
+ template <typename IndexCType>
+ struct TakeAdapter {
+ static constexpr bool is_take = true;
+
+ Impl* impl;
+ explicit TakeAdapter(Impl* impl) : impl(impl) {}
+ template <typename ValidVisitor, typename NullVisitor>
+ Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ return impl->template VisitTake<IndexCType>(std::forward<ValidVisitor>(visit_valid),
+ std::forward<NullVisitor>(visit_null));
+ }
+ };
+
+ // Forwards the generic value visitors to the VisitFilter template
+ struct FilterAdapter {
+ static constexpr bool is_take = false;
+
+ Impl* impl;
+ explicit FilterAdapter(Impl* impl) : impl(impl) {}
+ template <typename ValidVisitor, typename NullVisitor>
+ Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ return impl->VisitFilter(std::forward<ValidVisitor>(visit_valid),
+ std::forward<NullVisitor>(visit_null));
+ }
+ };
+
+ KernelContext* ctx;
+ std::shared_ptr<ArrayData> values;
+ std::shared_ptr<ArrayData> selection;
+ int64_t output_length;
+ ArrayData* out;
+ TypedBufferBuilder<bool> validity_builder;
+
+ Selection(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+ : ctx(ctx),
+ values(batch[0].array()),
+ selection(batch[1].array()),
+ output_length(output_length),
+ out(out->mutable_array()),
+ validity_builder(ctx->memory_pool()) {}
+
+ virtual ~Selection() = default;
+
+ Status FinishCommon() {
+ out->buffers.resize(values->buffers.size());
+ out->length = validity_builder.length();
+ out->null_count = validity_builder.false_count();
+ return validity_builder.Finish(&out->buffers[0]);
+ }
+
+ template <typename IndexCType, typename ValidVisitor, typename NullVisitor>
+ Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ const auto indices_values = selection->GetValues<IndexCType>(1);
+ const uint8_t* is_valid = GetValidityBitmap(*selection);
+ OptionalBitIndexer indices_is_valid(selection->buffers[0], selection->offset);
+ OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
+
+ const bool values_have_nulls = values->null_count.load() != 0;
+ OptionalBitBlockCounter bit_counter(is_valid, selection->offset, selection->length);
+ int64_t position = 0;
+ while (position < selection->length) {
+ BitBlockCount block = bit_counter.NextBlock();
+ const bool indices_have_nulls = block.popcount < block.length;
+ if (!indices_have_nulls && !values_have_nulls) {
+ // Fastest path, neither indices nor values have nulls
+ validity_builder.UnsafeAppend(block.length, true);
+ for (int64_t i = 0; i < block.length; ++i) {
+ RETURN_NOT_OK(visit_valid(indices_values[position++]));
+ }
+ } else if (block.popcount > 0) {
+ // Since we have to branch on whether the indices are null or not, we
+ // combine the "non-null indices block but some values null" and
+ // "some-null indices block but values non-null" into a single loop.
+ for (int64_t i = 0; i < block.length; ++i) {
+ if ((!indices_have_nulls || indices_is_valid[position]) &&
+ values_is_valid[indices_values[position]]) {
+ validity_builder.UnsafeAppend(true);
+ RETURN_NOT_OK(visit_valid(indices_values[position]));
+ } else {
+ validity_builder.UnsafeAppend(false);
+ RETURN_NOT_OK(visit_null());
+ }
+ ++position;
+ }
+ } else {
+ // The whole block is null
+ validity_builder.UnsafeAppend(block.length, false);
+ for (int64_t i = 0; i < block.length; ++i) {
+ RETURN_NOT_OK(visit_null());
+ }
+ position += block.length;
+ }
+ }
+ return Status::OK();
+ }
+
+ // We use the NullVisitor both for "selected" nulls as well as "emitted"
+ // nulls coming from the filter when using FilterOptions::EMIT_NULL
+ template <typename ValidVisitor, typename NullVisitor>
+ Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
+ const auto& state = checked_cast<const FilterState&>(*ctx->state());
+ auto null_selection = state.options.null_selection_behavior;
+
+ const auto filter_data = selection->buffers[1]->data();
+
+ const uint8_t* filter_is_valid = GetValidityBitmap(*selection);
+ const int64_t filter_offset = selection->offset;
+ OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
+
+ // We use 3 block counters for fast scanning of the filter
+ //
+ // * values_valid_counter: for values null/not-null
+ // * filter_valid_counter: for filter null/not-null
+ // * filter_counter: for filter true/false
+ OptionalBitBlockCounter values_valid_counter(GetValidityBitmap(*values),
+ values->offset, values->length);
+ OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset,
+ selection->length);
+ BitBlockCounter filter_counter(filter_data, filter_offset, selection->length);
+ int64_t in_position = 0;
+
+ auto AppendNotNull = [&](int64_t index) -> Status {
+ validity_builder.UnsafeAppend(true);
+ return visit_valid(index);
+ };
+
+ auto AppendNull = [&]() -> Status {
+ validity_builder.UnsafeAppend(false);
+ return visit_null();
+ };
+
+ auto AppendMaybeNull = [&](int64_t index) -> Status {
+ if (values_is_valid[index]) {
+ return AppendNotNull(index);
+ } else {
+ return AppendNull();
+ }
+ };
+
+ while (in_position < selection->length) {
+ BitBlockCount filter_valid_block = filter_valid_counter.NextWord();
+ BitBlockCount values_valid_block = values_valid_counter.NextWord();
+ BitBlockCount filter_block = filter_counter.NextWord();
+ if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) {
+ // For this exceedingly common case in low-selectivity filters we can
+ // skip further analysis of the data and move on to the next block.
+ in_position += filter_block.length;
+ } else if (filter_valid_block.AllSet()) {
+ // Simpler path: no filter values are null
+ if (filter_block.AllSet()) {
+ // Fastest path: filter values are all true and not null
+ if (values_valid_block.AllSet()) {
+ // The values aren't null either
+ validity_builder.UnsafeAppend(filter_block.length, true);
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ RETURN_NOT_OK(visit_valid(in_position++));
+ }
+ } else {
+ // Some of the values in this block are null
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position++));
+ }
+ }
+ } else { // !filter_block.AllSet()
+ // Some of the filter values are false, but all not null
+ if (values_valid_block.AllSet()) {
+ // All the values are not-null, so we can skip null checking for
+ // them
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendNotNull(in_position));
+ }
+ ++in_position;
+ }
+ } else {
+ // Some of the values in the block are null, so we have to check
+ // each one
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ }
+ ++in_position;
+ }
+ }
+ }
+ } else { // !filter_valid_block.AllSet()
+ // Some of the filter values are null, so we have to handle the DROP
+ // versus EMIT_NULL null selection behavior.
+ if (null_selection == FilterOptions::DROP) {
+ // Filter null values are treated as false.
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ if (BitUtil::GetBit(filter_is_valid, filter_offset + in_position) &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ }
+ ++in_position;
+ }
+ } else {
+ // Filter null values are appended to output as null whether the
+ // value in the corresponding slot is valid or not
+ for (int64_t i = 0; i < filter_block.length; ++i) {
+ const bool filter_not_null =
+ BitUtil::GetBit(filter_is_valid, filter_offset + in_position);
+ if (filter_not_null &&
+ BitUtil::GetBit(filter_data, filter_offset + in_position)) {
+ RETURN_NOT_OK(AppendMaybeNull(in_position));
+ } else if (!filter_not_null) {
+ // EMIT_NULL case
+ RETURN_NOT_OK(AppendNull());
+ }
+ ++in_position;
+ }
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ virtual Status Init() { return Status::OK(); }
+
+ // Implementation specific finish logic
+ virtual Status Finish() = 0;
+
+ Status ExecTake() {
+ RETURN_NOT_OK(this->validity_builder.Reserve(output_length));
+ RETURN_NOT_OK(Init());
+ int index_width =
+ checked_cast<const FixedWidthType&>(*this->selection->type).bit_width() / 8;
+
+ // CTRP dispatch here
+ switch (index_width) {
+ case 1: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint8_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 2: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint16_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 4: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint32_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ case 8: {
+ Status s =
+ static_cast<Impl*>(this)->template GenerateOutput<TakeAdapter<uint64_t>>();
+ RETURN_NOT_OK(s);
+ } break;
+ default:
+ DCHECK(false) << "Invalid index width";
+ break;
+ }
+ RETURN_NOT_OK(this->FinishCommon());
+ return Finish();
+ }
+
+ Status ExecFilter() {
+ RETURN_NOT_OK(this->validity_builder.Reserve(output_length));
+ RETURN_NOT_OK(Init());
+ // CRTP dispatch
+ Status s = static_cast<Impl*>(this)->template GenerateOutput<FilterAdapter>();
+ RETURN_NOT_OK(s);
+ RETURN_NOT_OK(this->FinishCommon());
+ return Finish();
+ }
+};
+
+ // Re-declares the members of the dependent base Selection<Impl, Type> inside a
+ // derived implementation so they can be named unqualified (members of a
+ // dependent base class are not found by unqualified lookup in a template).
+ // Expects the derived class to define `Base` first.
+ #define LIFT_BASE_MEMBERS() \
+ using ValuesArrayType = typename Base::ValuesArrayType; \
+ using Base::ctx; \
+ using Base::values; \
+ using Base::selection; \
+ using Base::output_length; \
+ using Base::out; \
+ using Base::validity_builder
+
+ // A null visitor that does nothing but report success; used when only the
+ // validity bitmap needs to be produced (as in StructImpl).
+ static inline Status VisitNoop() {
+   return Status::OK();
+ }
+
+ // A take/filter implementation for 32-bit and 64-bit variable binary types.
+ // Common generated kernels are shared between Binary/String and
+ // LargeBinary/LargeString: Init() views the values as Type (BinaryType or
+ // LargeBinaryType, per the kernel registrations) before output is generated.
+ template <typename Type>
+ struct VarBinaryImpl : public Selection<VarBinaryImpl<Type>, Type> {
+   using offset_type = typename Type::offset_type;
+
+   using Base = Selection<VarBinaryImpl<Type>, Type>;
+   LIFT_BASE_MEMBERS();
+
+   std::shared_ptr<ArrayData> values_as_binary;
+   TypedBufferBuilder<offset_type> offset_builder;
+   TypedBufferBuilder<uint8_t> data_builder;
+
+   // Largest output offset Take may produce before reporting an overflow
+   static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1;
+
+   VarBinaryImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length,
+                 Datum* out)
+       : Base(ctx, batch, output_length, out),
+         offset_builder(ctx->memory_pool()),
+         data_builder(ctx->memory_pool()) {}
+
+   template <typename Adapter>
+   Status GenerateOutput() {
+     ValuesArrayType typed_values(this->values_as_binary);
+
+     // Presize the data builder with a rough estimate of the required data size
+     if (values->length > 0) {
+       const double mean_value_length =
+           (typed_values.total_values_length() / static_cast<double>(values->length));
+
+       // TODO: See if possible to reduce output_length for take/filter cases
+       // where there are nulls in the selection array
+       RETURN_NOT_OK(
+           data_builder.Reserve(static_cast<int64_t>(mean_value_length * output_length)));
+     }
+     int64_t space_available = data_builder.capacity();
+
+     const offset_type* raw_offsets = typed_values.raw_value_offsets();
+     const uint8_t* raw_data = typed_values.raw_data();
+
+     offset_type offset = 0;
+     Adapter adapter(this);
+     RETURN_NOT_OK(adapter.Generate(
+         [&](int64_t index) {
+           offset_builder.UnsafeAppend(offset);
+           offset_type val_offset = raw_offsets[index];
+           offset_type val_size = raw_offsets[index + 1] - val_offset;
+
+           // Filter emits each input value at most once, so only Take can
+           // exceed the offset range; the static Adapter::is_take prunes this
+           // check from the filter path in optimized builds. The whole
+           // comparison must sit inside ARROW_PREDICT_FALSE: the macro yields
+           // the truth value of its argument (__builtin_expect(!!(x), 0) on
+           // GCC/Clang), so comparing its result against kOffsetLimit would
+           // never detect an overflow.
+           if (Adapter::is_take &&
+               ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) +
+                                       static_cast<int64_t>(val_size) >
+                                   kOffsetLimit)) {
+             return Status::Invalid("Take operation overflowed binary array capacity");
+           }
+           offset += val_size;
+           if (ARROW_PREDICT_FALSE(val_size > space_available)) {
+             RETURN_NOT_OK(data_builder.Reserve(val_size));
+             space_available = data_builder.capacity() - data_builder.length();
+           }
+           data_builder.UnsafeAppend(raw_data + val_offset, val_size);
+           space_available -= val_size;
+           return Status::OK();
+         },
+         [&]() {
+           offset_builder.UnsafeAppend(offset);
+           return Status::OK();
+         }));
+     offset_builder.UnsafeAppend(offset);
+     return Status::OK();
+   }
+
+   // Views the values as Type and reserves one offset per output slot plus the
+   // trailing end offset.
+   Status Init() override {
+     ARROW_ASSIGN_OR_RAISE(this->values_as_binary,
+                           GetArrayView(this->values, TypeTraits<Type>::type_singleton()));
+     return offset_builder.Reserve(output_length + 1);
+   }
+
+   Status Finish() override {
+     RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+     return data_builder.Finish(&out->buffers[2]);
+   }
+ };
+
+ // Take/filter for fixed-size binary values (also registered for DECIMAL).
+ // Every output slot, valid or null, occupies exactly byte_width() bytes.
+ struct FSBImpl : public Selection<FSBImpl, FixedSizeBinaryType> {
+   using Base = Selection<FSBImpl, FixedSizeBinaryType>;
+   LIFT_BASE_MEMBERS();
+
+   TypedBufferBuilder<uint8_t> data_builder;
+
+   FSBImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+       : Base(ctx, batch, output_length, out), data_builder(ctx->memory_pool()) {}
+
+   template <typename Adapter>
+   Status GenerateOutput() {
+     FixedSizeBinaryArray fsb_values(this->values);
+     const int32_t width = fsb_values.byte_width();
+
+     RETURN_NOT_OK(data_builder.Reserve(width * output_length));
+
+     // Copy the selected value's bytes
+     auto emit_value = [&](int64_t index) {
+       const auto view = fsb_values.GetView(index);
+       data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(view.data()), width);
+       return Status::OK();
+     };
+     // Null slots are zero-filled
+     auto emit_zeros = [&]() {
+       data_builder.UnsafeAppend(width, static_cast<uint8_t>(0x00));
+       return Status::OK();
+     };
+
+     Adapter adapter(this);
+     return adapter.Generate(emit_value, emit_zeros);
+   }
+
+   Status Finish() override { return data_builder.Finish(&out->buffers[1]); }
+ };
+
+ // Take/filter for variable-size lists (also used for LARGE_LIST and MAP):
+ // builds the output offsets directly and collects the selected child
+ // positions, which are taken from the child array in Finish().
+ template <typename Type>
+ struct ListImpl : public Selection<ListImpl<Type>, Type> {
+   using offset_type = typename Type::offset_type;
+
+   using Base = Selection<ListImpl<Type>, Type>;
+   LIFT_BASE_MEMBERS();
+
+   TypedBufferBuilder<offset_type> offset_builder;
+   typename TypeTraits<Type>::OffsetBuilderType child_index_builder;
+
+   // Largest output offset Take may produce before reporting an overflow
+   // (same limit as VarBinaryImpl).
+   static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1;
+
+   ListImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+       : Base(ctx, batch, output_length, out),
+         offset_builder(ctx->memory_pool()),
+         child_index_builder(ctx->memory_pool()) {}
+
+   template <typename Adapter>
+   Status GenerateOutput() {
+     ValuesArrayType typed_values(this->values);
+
+     // TODO presize child_index_builder with a similar heuristic as VarBinaryImpl
+
+     offset_type offset = 0;
+     Adapter adapter(this);
+     RETURN_NOT_OK(adapter.Generate(
+         [&](int64_t index) {
+           offset_builder.UnsafeAppend(offset);
+           offset_type value_offset = typed_values.value_offset(index);
+           offset_type value_length = typed_values.value_length(index);
+           // Take may repeat indices, so the accumulated output offset can leave
+           // the offset_type range (signed overflow); Filter emits each input
+           // list at most once and cannot. Adapter::is_take is a compile-time
+           // constant, so this check is pruned from the filter path.
+           if (Adapter::is_take &&
+               ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) +
+                                       static_cast<int64_t>(value_length) >
+                                   kOffsetLimit)) {
+             return Status::Invalid("Take operation overflowed list array capacity");
+           }
+           offset += value_length;
+           RETURN_NOT_OK(child_index_builder.Reserve(value_length));
+           for (offset_type j = value_offset; j < value_offset + value_length; ++j) {
+             child_index_builder.UnsafeAppend(j);
+           }
+           return Status::OK();
+         },
+         [&]() {
+           offset_builder.UnsafeAppend(offset);
+           return Status::OK();
+         }));
+     offset_builder.UnsafeAppend(offset);
+     return Status::OK();
+   }
+
+   // One offset per output slot plus the trailing end offset
+   Status Init() override { return offset_builder.Reserve(output_length + 1); }
+
+   Status Finish() override {
+     std::shared_ptr<Array> child_indices;
+     RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
+
+     ValuesArrayType typed_values(this->values);
+
+     // No need to boundscheck the child values indices
+     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
+                           Take(*typed_values.values(), *child_indices,
+                                TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+     RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
+     out->child_data = {taken_child->data()};
+     return Status::OK();
+   }
+ };
+
+ // Take/filter for fixed-size lists: records list_size child positions per
+ // output slot, then takes them from the child array in Finish().
+ struct FSLImpl : public Selection<FSLImpl, FixedSizeListType> {
+   Int64Builder child_index_builder;
+
+   using Base = Selection<FSLImpl, FixedSizeListType>;
+   LIFT_BASE_MEMBERS();
+
+   FSLImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, Datum* out)
+       : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {}
+
+   template <typename Adapter>
+   Status GenerateOutput() {
+     ValuesArrayType typed_values(this->values);
+     const int32_t list_size = typed_values.list_type()->list_size();
+     // typed_values.values() (taken from in Finish) is the whole child array,
+     // not sliced by the parent's offset (FixedSizeListArray::value_offset adds
+     // offset() for the same reason), so a sliced parent must contribute its
+     // offset when mapping an index to child positions.
+     const int64_t base_offset = typed_values.offset();
+
+     /// We must take list_size elements even for null elements of
+     /// indices.
+     RETURN_NOT_OK(child_index_builder.Reserve(output_length * list_size));
+
+     Adapter adapter(this);
+     return adapter.Generate(
+         [&](int64_t index) {
+           const int64_t offset = (base_offset + index) * list_size;
+           for (int64_t j = offset; j < offset + list_size; ++j) {
+             child_index_builder.UnsafeAppend(j);
+           }
+           return Status::OK();
+         },
+         // Null slots still occupy list_size (null) child slots
+         [&]() { return child_index_builder.AppendNulls(list_size); });
+   }
+
+   Status Finish() override {
+     std::shared_ptr<Array> child_indices;
+     RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
+
+     ValuesArrayType typed_values(this->values);
+
+     // No need to boundscheck the child values indices
+     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
+                           Take(*typed_values.values(), *child_indices,
+                                TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+     out->child_data = {taken_child->data()};
+     return Status::OK();
+   }
+ };
+
+// ----------------------------------------------------------------------
+// Struct selection implementations
+
+// We need a slightly different approach for StructType. For Take, we can
+// invoke Take on each struct field's data with boundschecking disabled. For
+// Filter on the other hand, if we naively call Filter on each field, then the
+// filter output length will have to be redundantly computed. Thus, for Filter
+// we instead convert the filter to selection indices and then invoke take.
+
+ // Struct selection implementation. ONLY used for Take (filtering structs goes
+ // through StructFilter, which converts the filter to take indices).
+ struct StructImpl : public Selection<StructImpl, StructType> {
+   using Base = Selection<StructImpl, StructType>;
+   LIFT_BASE_MEMBERS();
+   using Base::Base;
+
+   template <typename Adapter>
+   Status GenerateOutput() {
+     // There's nothing to do for Struct except to generate the validity bitmap,
+     // which the base visitors maintain; the children are taken in Finish().
+     Adapter adapter(this);
+     return adapter.Generate([](int64_t) { return Status::OK(); },
+                             /*visit_null=*/VisitNoop);
+   }
+
+   Status Finish() override {
+     StructArray typed_values(values);
+
+     // Select from children without boundschecking; TakeExec already checked
+     // the indices when boundscheck was requested.
+     const int num_fields = values->type->num_fields();
+     out->child_data.resize(num_fields);
+     for (int field_index = 0; field_index < num_fields; ++field_index) {
+       ARROW_ASSIGN_OR_RAISE(Datum taken_field,
+                             Take(Datum(typed_values.field(field_index)), Datum(selection),
+                                  TakeOptions::NoBoundsCheck(), ctx->exec_context()));
+       out->child_data[field_index] = taken_field.array();
+     }
+     return Status::OK();
+   }
+ };
+
+ // Filter on a struct array: convert the boolean filter into take indices once
+ // and run Take on the whole struct, instead of filtering every field (which
+ // would recompute the filter output length per field).
+ void StructFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+   const auto& options = checked_cast<const FilterState&>(*ctx->state()).options;
+
+   std::shared_ptr<ArrayData> take_indices;
+   KERNEL_RETURN_IF_ERROR(
+       ctx, GetTakeIndices(*batch[1].array(), options.null_selection_behavior)
+                .Value(&take_indices));
+
+   // The indices come from GetTakeIndices, so Take runs without boundschecking
+   Datum taken;
+   KERNEL_RETURN_IF_ERROR(ctx, Take(batch[0], Datum(take_indices),
+                                    TakeOptions::NoBoundsCheck(), ctx->exec_context())
+                                   .Value(&taken));
+   out->value = taken.array();
+ }
+
+#undef LIFT_BASE_MEMBERS
+
+// ----------------------------------------------------------------------
+// Implement Filter metafunction
+
+ // Filters every column of `batch` by converting the boolean filter into take
+ // indices once and taking each column with them.
+ //
+ // Returns Invalid if the filter length differs from the number of rows and
+ // NotImplemented if the filter is not a single Array.
+ Result<std::shared_ptr<RecordBatch>> FilterRecordBatch(const RecordBatch& batch,
+                                                        const Datum& filter,
+                                                        const FunctionOptions* options,
+                                                        ExecContext* ctx) {
+   if (batch.num_rows() != filter.length()) {
+     return Status::Invalid("Filter inputs must all be the same length");
+   }
+   // The selection-vector conversion below needs filter.array(), which is only
+   // valid for ARRAY datums (e.g. a ChunkedArray filter must not reach it).
+   if (filter.kind() != Datum::ARRAY) {
+     return Status::NotImplemented("Filter should be array-like");
+   }
+
+   // Convert filter to selection vector/indices and use Take
+   const auto& filter_opts = *static_cast<const FilterOptions*>(options);
+   ARROW_ASSIGN_OR_RAISE(
+       std::shared_ptr<ArrayData> indices,
+       GetTakeIndices(*filter.array(), filter_opts.null_selection_behavior));
+   std::vector<std::shared_ptr<Array>> columns(batch.num_columns());
+   for (int i = 0; i < batch.num_columns(); ++i) {
+     ARROW_ASSIGN_OR_RAISE(Datum out, Take(batch.column(i)->data(), Datum(indices),
+                                           TakeOptions::NoBoundsCheck(), ctx));
+     columns[i] = out.make_array();
+   }
+   return RecordBatch::Make(batch.schema(), indices->length, columns);
+ }
+
+ // Filters every column of `table` with `filter` (Array or ChunkedArray).
+ //
+ // The selection vector "trick" cannot currently be easily applied on Table
+ // because either the filter or the columns may be ChunkedArray, so we use
+ // Filter recursively on the columns for now until a more efficient
+ // implementation of Take with chunked data is available.
+ Result<std::shared_ptr<Table>> FilterTable(const Table& table, const Datum& filter,
+                                            const FunctionOptions* options,
+                                            ExecContext* ctx) {
+   if (table.num_rows() != filter.length()) {
+     return Status::Invalid("Filter inputs must all be the same length");
+   }
+
+   const auto& filter_opts = *static_cast<const FilterOptions*>(options);
+   std::vector<std::shared_ptr<ChunkedArray>> filtered_columns;
+   filtered_columns.reserve(table.num_columns());
+   for (const auto& column : table.columns()) {
+     ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(column, filter, filter_opts, ctx));
+     filtered_columns.push_back(filtered.chunked_array());
+   }
+   return Table::Make(table.schema(), std::move(filtered_columns));
+ }
+
+ // "filter" metafunction: RecordBatch values go to FilterRecordBatch, Table
+ // values to FilterTable, and everything else to the "array_filter" kernels.
+ class FilterMetaFunction : public MetaFunction {
+  public:
+   FilterMetaFunction() : MetaFunction("filter", Arity::Binary()) {}
+
+   Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+                             const FunctionOptions* options,
+                             ExecContext* ctx) const override {
+     // Datum::type() has no type to return for e.g. RecordBatch/Table datums,
+     // so check for null before dereferencing.
+     std::shared_ptr<DataType> filter_type = args[1].type();
+     if (filter_type == nullptr || filter_type->id() != Type::BOOL) {
+       return Status::NotImplemented("Filter argument must be boolean type");
+     }
+
+     if (args[0].kind() == Datum::RECORD_BATCH) {
+       ARROW_ASSIGN_OR_RAISE(
+           std::shared_ptr<RecordBatch> out_batch,
+           FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx));
+       return Datum(out_batch);
+     } else if (args[0].kind() == Datum::TABLE) {
+       ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> out_table,
+                             FilterTable(*args[0].table(), args[1], options, ctx));
+       return Datum(out_table);
+     } else {
+       return CallFunction("array_filter", args, options, ctx);
+     }
+   }
+ };
+
+// ----------------------------------------------------------------------
+// Implement Take metafunction
+
+// Shorthand naming of these functions
+// A -> Array
+// C -> ChunkedArray
+// R -> RecordBatch
+// T -> Table
+
+ // Array values, Array indices: runs the "array_take" vector kernels directly.
+ Result<std::shared_ptr<Array>> TakeAA(const Array& values, const Array& indices,
+                                       const TakeOptions& options, ExecContext* ctx) {
+   ARROW_ASSIGN_OR_RAISE(Datum taken,
+                         CallFunction("array_take", {values, indices}, &options, ctx));
+   return taken.make_array();
+ }
+
+ // ChunkedArray values, Array indices: reduces `values` to a single array and
+ // takes from it, returning a one-chunk ChunkedArray.
+ Result<std::shared_ptr<ChunkedArray>> TakeCA(const ChunkedArray& values,
+                                              const Array& indices,
+                                              const TakeOptions& options,
+                                              ExecContext* ctx) {
+   auto num_chunks = values.num_chunks();
+   std::vector<std::shared_ptr<Array>> new_chunks(1);  // Hard-coded 1 for now
+   std::shared_ptr<Array> current_chunk;
+
+   // Case 1: `values` has a single chunk, so just use it
+   if (num_chunks == 1) {
+     current_chunk = values.chunk(0);
+   } else {
+     // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it
+     // See
+     // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151
+     // TODO Case 3: If indices are sorted, can slice them and call Array Take
+
+     // Case 4: Else, concatenate chunks and call Array Take. Allocate from the
+     // execution context's pool, falling back to the default pool without one.
+     MemoryPool* pool = ctx != nullptr ? ctx->memory_pool() : default_memory_pool();
+     RETURN_NOT_OK(Concatenate(values.chunks(), pool, &current_chunk));
+   }
+   // Call Array Take on our single chunk
+   ARROW_ASSIGN_OR_RAISE(new_chunks[0], TakeAA(*current_chunk, indices, options, ctx));
+   return std::make_shared<ChunkedArray>(std::move(new_chunks));
+ }
+
+ // ChunkedArray values, ChunkedArray indices: each indices chunk produces one
+ // output chunk of the same type as `values`.
+ Result<std::shared_ptr<ChunkedArray>> TakeCC(const ChunkedArray& values,
+                                              const ChunkedArray& indices,
+                                              const TakeOptions& options,
+                                              ExecContext* ctx) {
+   // Concatenation buffers come from the execution context's pool when given
+   MemoryPool* pool = ctx != nullptr ? ctx->memory_pool() : default_memory_pool();
+   auto num_chunks = indices.num_chunks();
+   std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
+   for (int i = 0; i < num_chunks; i++) {
+     // Take with that indices chunk
+     // Note that as currently implemented, this is inefficient because `values`
+     // will get concatenated on every iteration of this loop
+     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> current_chunk,
+                           TakeCA(values, *indices.chunk(i), options, ctx));
+     // Concatenate the result to make a single array for this chunk
+     RETURN_NOT_OK(Concatenate(current_chunk->chunks(), pool, &new_chunks[i]));
+   }
+   // Pass the type explicitly: with zero index chunks there is no chunk to infer
+   // it from. Take's output type is the values type.
+   return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
+ }
+
+ // Array values, ChunkedArray indices: each indices chunk produces one output
+ // chunk of the same type as `values`.
+ Result<std::shared_ptr<ChunkedArray>> TakeAC(const Array& values,
+                                              const ChunkedArray& indices,
+                                              const TakeOptions& options,
+                                              ExecContext* ctx) {
+   auto num_chunks = indices.num_chunks();
+   std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
+   for (int i = 0; i < num_chunks; i++) {
+     // Take with that indices chunk
+     ARROW_ASSIGN_OR_RAISE(new_chunks[i], TakeAA(values, *indices.chunk(i), options, ctx));
+   }
+   // Pass the type explicitly: with zero index chunks there is no chunk to infer
+   // it from. Take's output type is the values type.
+   return std::make_shared<ChunkedArray>(std::move(new_chunks), values.type());
+ }
+
+ // RecordBatch values, Array indices: every column is taken with the same
+ // indices, so the output batch has indices.length() rows.
+ Result<std::shared_ptr<RecordBatch>> TakeRA(const RecordBatch& batch,
+                                             const Array& indices,
+                                             const TakeOptions& options,
+                                             ExecContext* ctx) {
+   const int num_columns = batch.num_columns();
+   std::vector<std::shared_ptr<Array>> taken_columns(num_columns);
+   for (int col = 0; col < num_columns; ++col) {
+     ARROW_ASSIGN_OR_RAISE(taken_columns[col],
+                           TakeAA(*batch.column(col), indices, options, ctx));
+   }
+   return RecordBatch::Make(batch.schema(), indices.length(), std::move(taken_columns));
+ }
+
+ // Table values, Array indices: each column goes through the
+ // ChunkedArray-values/Array-indices path.
+ Result<std::shared_ptr<Table>> TakeTA(const Table& table, const Array& indices,
+                                       const TakeOptions& options, ExecContext* ctx) {
+   const int num_columns = table.num_columns();
+   std::vector<std::shared_ptr<ChunkedArray>> taken_columns(num_columns);
+   for (int col = 0; col < num_columns; ++col) {
+     ARROW_ASSIGN_OR_RAISE(taken_columns[col],
+                           TakeCA(*table.column(col), indices, options, ctx));
+   }
+   return Table::Make(table.schema(), std::move(taken_columns));
+ }
+
+ // Table values, ChunkedArray indices: each column goes through the
+ // ChunkedArray-values/ChunkedArray-indices path.
+ Result<std::shared_ptr<Table>> TakeTC(const Table& table, const ChunkedArray& indices,
+                                       const TakeOptions& options, ExecContext* ctx) {
+   const int num_columns = table.num_columns();
+   std::vector<std::shared_ptr<ChunkedArray>> taken_columns(num_columns);
+   for (int col = 0; col < num_columns; ++col) {
+     ARROW_ASSIGN_OR_RAISE(taken_columns[col],
+                           TakeCC(*table.column(col), indices, options, ctx));
+   }
+   return Table::Make(table.schema(), std::move(taken_columns));
+ }
+
+ // Metafunction for dispatching to different Take implementations other than
+ // Array-Array.
+ //
+ // TODO: Revamp approach to executing Take operations. In addition to being
+ // overly complex dispatching, there is no parallelization.
+ class TakeMetaFunction : public MetaFunction {
+  public:
+   TakeMetaFunction() : MetaFunction("take", Arity::Binary()) {}
+
+   Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+                             const FunctionOptions* options,
+                             ExecContext* ctx) const override {
+     Datum::Kind index_kind = args[1].kind();
+     const TakeOptions& take_opts = static_cast<const TakeOptions&>(*options);
+     switch (args[0].kind()) {
+       case Datum::ARRAY:
+         if (index_kind == Datum::ARRAY) {
+           return TakeAA(*args[0].make_array(), *args[1].make_array(), take_opts, ctx);
+         } else if (index_kind == Datum::CHUNKED_ARRAY) {
+           return TakeAC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx);
+         }
+         break;
+       case Datum::CHUNKED_ARRAY:
+         if (index_kind == Datum::ARRAY) {
+           return TakeCA(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx);
+         } else if (index_kind == Datum::CHUNKED_ARRAY) {
+           return TakeCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts,
+                         ctx);
+         }
+         break;
+       case Datum::RECORD_BATCH:
+         if (index_kind == Datum::ARRAY) {
+           return TakeRA(*args[0].record_batch(), *args[1].make_array(), take_opts, ctx);
+         }
+         break;
+       case Datum::TABLE:
+         if (index_kind == Datum::ARRAY) {
+           return TakeTA(*args[0].table(), *args[1].make_array(), take_opts, ctx);
+         } else if (index_kind == Datum::CHUNKED_ARRAY) {
+           return TakeTC(*args[0].table(), *args[1].chunked_array(), take_opts, ctx);
+         }
+         break;
+       default:
+         break;
+     }
+     // Any other values/indices combination (e.g. RecordBatch values with
+     // ChunkedArray indices) falls through to here.
+     return Status::NotImplemented(
+         "Unsupported types for take operation: "
+         "values=",
+         args[0].ToString(), ", indices=", args[1].ToString());
+   }
+ };
+
+// ----------------------------------------------------------------------
+
+ // Kernel exec for filter implementations built on Selection: computes the
+ // output size up front (ExecFilter reserves that many validity bits) and runs
+ // the Impl.
+ template <typename Impl>
+ void FilterExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+   const auto null_selection =
+       checked_cast<const FilterState&>(*ctx->state()).options.null_selection_behavior;
+
+   // TODO: where are the values and filter length equality checked?
+   const int64_t output_length = GetFilterOutputSize(*batch[1].array(), null_selection);
+   Impl kernel(ctx, batch, output_length, out);
+   KERNEL_RETURN_IF_ERROR(ctx, kernel.ExecFilter());
+ }
+
+ // Kernel exec for take implementations built on Selection. The Impl visitors
+ // index the values without checking, so the indices are validated here unless
+ // the caller disabled boundschecking.
+ template <typename Impl>
+ void TakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+   const bool check_bounds =
+       checked_cast<const TakeState&>(*ctx->state()).options.boundscheck;
+   if (check_bounds) {
+     KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
+   }
+   // One output slot per index
+   const int64_t num_indices = batch[1].length();
+   Impl kernel(ctx, batch, /*output_length=*/num_indices, out);
+   KERNEL_RETURN_IF_ERROR(ctx, kernel.ExecTake());
+ }
+
+ // One kernel to register: the accepted values input type and its exec function.
+ struct SelectionKernelDescr {
+   InputType input;
+   ArrayKernelExec exec;
+ };
+
+ // Registers a binary vector function `name` with one kernel per descriptor.
+ // Every kernel copies `base_kernel`'s settings, takes (descr.input,
+ // selection_type) and outputs the type of its first argument.
+ void RegisterSelectionFunction(const std::string& name, VectorKernel base_kernel,
+                                InputType selection_type,
+                                const std::vector<SelectionKernelDescr>& descrs,
+                                FunctionRegistry* registry) {
+   auto func = std::make_shared<VectorFunction>(name, Arity::Binary());
+   for (const auto& descr : descrs) {
+     // descrs is const, so the input type is copied into each signature
+     base_kernel.signature =
+         KernelSignature::Make({descr.input, selection_type}, OutputType(FirstType));
+     base_kernel.exec = descr.exec;
+     DCHECK_OK(func->AddKernel(base_kernel));
+   }
+   DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+
+} // namespace
+
+ // Registers the selection functions: the "array_filter" and "array_take"
+ // vector functions (one kernel per input type family below) plus the "filter"
+ // and "take" metafunctions that also accept ChunkedArray/RecordBatch/Table.
+ void RegisterVectorSelection(FunctionRegistry* registry) {
+ // Filter kernels
+ // DECIMAL reuses the fixed-size binary implementation; STRUCT goes through
+ // StructFilter (filter -> take indices) rather than a Selection Impl.
+ std::vector<SelectionKernelDescr> filter_kernel_descrs = {
+ {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveFilter},
+ {InputType(match::BinaryLike(), ValueDescr::ARRAY),
+ FilterExec<VarBinaryImpl<BinaryType>>},
+ {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY),
+ FilterExec<VarBinaryImpl<LargeBinaryType>>},
+ {InputType::Array(Type::FIXED_SIZE_BINARY), FilterExec<FSBImpl>},
+ {InputType::Array(null()), NullFilter},
+ {InputType::Array(Type::DECIMAL), FilterExec<FSBImpl>},
+ {InputType::Array(Type::DICTIONARY), DictionaryFilter},
+ {InputType::Array(Type::EXTENSION), ExtensionFilter},
+ {InputType::Array(Type::LIST), FilterExec<ListImpl<ListType>>},
+ {InputType::Array(Type::LARGE_LIST), FilterExec<ListImpl<LargeListType>>},
+ {InputType::Array(Type::FIXED_SIZE_LIST), FilterExec<FSLImpl>},
+ {InputType::Array(Type::STRUCT), StructFilter},
+ // TODO: Reuse ListType kernel for MAP
+ {InputType::Array(Type::MAP), FilterExec<ListImpl<MapType>>},
+ };
+
+ // Filter kernels take a boolean array as the selection argument
+ VectorKernel filter_base;
+ filter_base.init = InitWrapOptions<FilterOptions>;
+ RegisterSelectionFunction("array_filter", filter_base,
+ /*selection_type=*/InputType::Array(boolean()),
+ filter_kernel_descrs, registry);
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<FilterMetaFunction>()));
+
+ // Take kernels
+ std::vector<SelectionKernelDescr> take_kernel_descrs = {
+ {InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveTake},
+ {InputType(match::BinaryLike(), ValueDescr::ARRAY),
+ TakeExec<VarBinaryImpl<BinaryType>>},
+ {InputType(match::LargeBinaryLike(), ValueDescr::ARRAY),
+ TakeExec<VarBinaryImpl<LargeBinaryType>>},
+ {InputType::Array(Type::FIXED_SIZE_BINARY), TakeExec<FSBImpl>},
+ {InputType::Array(null()), NullTake},
+ {InputType::Array(Type::DECIMAL), TakeExec<FSBImpl>},
+ {InputType::Array(Type::DICTIONARY), DictionaryTake},
+ {InputType::Array(Type::EXTENSION), ExtensionTake},
+ {InputType::Array(Type::LIST), TakeExec<ListImpl<ListType>>},
+ {InputType::Array(Type::LARGE_LIST), TakeExec<ListImpl<LargeListType>>},
+ {InputType::Array(Type::FIXED_SIZE_LIST), TakeExec<FSLImpl>},
+ {InputType::Array(Type::STRUCT), TakeExec<StructImpl>},
+ // TODO: Reuse ListType kernel for MAP
+ {InputType::Array(Type::MAP), TakeExec<ListImpl<MapType>>},
+ };
+
+ // Take kernels accept any integer index array and are not run chunkwise
+ VectorKernel take_base;
+ take_base.init = InitWrapOptions<TakeOptions>;
+ take_base.can_execute_chunkwise = false;
+ RegisterSelectionFunction(
+ "array_take", take_base,
+ /*selection_type=*/InputType(match::Integer(), ValueDescr::ARRAY),
+ take_kernel_descrs, registry);
+
+ DCHECK_OK(registry->AddFunction(std::make_shared<TakeMetaFunction>()));
+}
+
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc
index d03fa57..cc97c1c 100644
--- a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc
@@ -17,6 +17,8 @@
#include "benchmark/benchmark.h"
+#include <cstdint>
+
#include "arrow/compute/api_vector.h"
#include "arrow/compute/benchmark_util.h"
#include "arrow/compute/kernels/test_util.h"
@@ -43,10 +45,26 @@ std::vector<int64_t> g_data_sizes = {kL2Size};
// The benchmark state parameter references this vector of cases. Test high and
// low selectivity filters.
+
+// clang-format off
std::vector<FilterParams> g_filter_params = {
- {0., 0.95, 0.05}, {0., 0.10, 0.05}, {0.001, 0.95, 0.05}, {0.001, 0.10, 0.05},
- {0.01, 0.95, 0.05}, {0.01, 0.10, 0.05}, {0.1, 0.95, 0.05}, {0.1, 0.10, 0.05},
- {0.9, 0.95, 0.05}, {0.9, 0.10, 0.05}};
+ {0., 0.999, 0.05},
+ {0., 0.50, 0.05},
+ {0., 0.01, 0.05},
+ {0.001, 0.999, 0.05},
+ {0.001, 0.50, 0.05},
+ {0.001, 0.01, 0.05},
+ {0.01, 0.999, 0.05},
+ {0.01, 0.50, 0.05},
+ {0.01, 0.01, 0.05},
+ {0.1, 0.999, 0.05},
+ {0.1, 0.50, 0.05},
+ {0.1, 0.01, 0.05},
+ {0.9, 0.999, 0.05},
+ {0.9, 0.50, 0.05},
+ {0.9, 0.01, 0.05}
+};
+// clang-format on
// RAII struct to handle some of the boilerplate in filter
struct FilterArgs {
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h
deleted file mode 100644
index 8908b3b..0000000
--- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h
+++ /dev/null
@@ -1,819 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#pragma once
-
-#include <algorithm>
-#include <limits>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "arrow/builder.h"
-#include "arrow/compute/api_vector.h"
-#include "arrow/compute/kernels/common.h"
-#include "arrow/record_batch.h"
-#include "arrow/result.h"
-
-namespace arrow {
-namespace compute {
-namespace internal {
-
-template <typename T, typename R = void>
-using enable_if_not_base_binary =
- enable_if_t<!std::is_base_of<BaseBinaryType, T>::value, R>;
-
-// For non-binary builders, use regular value append
-template <typename Builder, typename Scalar>
-static enable_if_not_base_binary<typename Builder::TypeClass, Status> UnsafeAppend(
- Builder* builder, Scalar&& value) {
- builder->UnsafeAppend(std::forward<Scalar>(value));
- return Status::OK();
-}
-
-// For binary builders, need to reserve byte storage first
-template <typename Builder>
-static enable_if_base_binary<typename Builder::TypeClass, Status> UnsafeAppend(
- Builder* builder, util::string_view value) {
- RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
- builder->UnsafeAppend(value);
- return Status::OK();
-}
-
-/// \brief visit indices from an IndexSequence while bounds checking
-///
-/// \param[in] indices IndexSequence to visit
-/// \param[in] values array to bounds check against, if necessary
-/// \param[in] vis index visitor, signature must be Status(int64_t index, bool is_valid)
-template <bool SomeIndicesNull, bool SomeValuesNull, bool NeverOutOfBounds,
- typename IndexSequence, typename Visitor>
-Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
- for (int64_t i = 0; i < indices.length(); ++i) {
- auto index_valid = indices.Next();
- if (SomeIndicesNull && !index_valid.second) {
- RETURN_NOT_OK(vis(0, false));
- continue;
- }
-
- auto index = index_valid.first;
- if (!NeverOutOfBounds) {
- if (index < 0 || index >= values.length()) {
- return Status::IndexError("take index out of bounds");
- }
- } else {
- DCHECK_GE(index, 0);
- DCHECK_LT(index, values.length());
- }
-
- bool is_valid = !SomeValuesNull || values.IsValid(index);
- RETURN_NOT_OK(vis(index, is_valid));
- }
- return Status::OK();
-}
-
-template <bool SomeIndicesNull, bool SomeValuesNull, typename IndexSequence,
- typename Visitor>
-Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
- if (indices.never_out_of_bounds()) {
- return VisitIndices<SomeIndicesNull, SomeValuesNull, true>(
- indices, values, std::forward<Visitor>(vis));
- }
- return VisitIndices<SomeIndicesNull, SomeValuesNull, false>(indices, values,
- std::forward<Visitor>(vis));
-}
-
-template <bool SomeIndicesNull, typename IndexSequence, typename Visitor>
-Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
- if (values.null_count() == 0) {
- return VisitIndices<SomeIndicesNull, false>(indices, values,
- std::forward<Visitor>(vis));
- }
- return VisitIndices<SomeIndicesNull, true>(indices, values, std::forward<Visitor>(vis));
-}
-
-template <typename IndexSequence, typename Visitor>
-Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) {
- if (indices.null_count() == 0) {
- return VisitIndices<false>(indices, values, std::forward<Visitor>(vis));
- }
- return VisitIndices<true>(indices, values, std::forward<Visitor>(vis));
-}
-
-// Helper class for gathering values from an array
-template <typename IndexSequence>
-class Taker {
- public:
- explicit Taker(const std::shared_ptr<DataType>& type) : type_(type) {}
-
- virtual ~Taker() = default;
-
- // initialize this taker including constructing any children,
- // must be called once after construction before any other methods are called
- virtual Status Init() { return Status::OK(); }
-
- // reset this Taker and set KernelContext for taking an array
- // must be called each time the KernelContext may have changed
- virtual Status SetContext(KernelContext* ctx) = 0;
-
- // gather elements from an array at the provided indices
- virtual Status Take(const Array& values, IndexSequence indices) = 0;
-
- // assemble an array of all gathered values
- virtual Status Finish(std::shared_ptr<Array>*) = 0;
-
- // factory; the output Taker will support gathering values of the given type
- static Status Make(const std::shared_ptr<DataType>& type, std::unique_ptr<Taker>* out);
-
- static_assert(std::is_literal_type<IndexSequence>::value,
- "Index sequences must be literal type");
-
- static_assert(std::is_copy_constructible<IndexSequence>::value,
- "Index sequences must be copy constructible");
-
- static_assert(std::is_same<decltype(std::declval<IndexSequence>().Next()),
- std::pair<int64_t, bool>>::value,
- "An index sequence must yield pairs of indices:int64_t, validity:bool.");
-
- static_assert(std::is_same<decltype(std::declval<const IndexSequence>().length()),
- int64_t>::value,
- "An index sequence must provide its length.");
-
- static_assert(std::is_same<decltype(std::declval<const IndexSequence>().null_count()),
- int64_t>::value,
- "An index sequence must provide the number of nulls it will take.");
-
- static_assert(
- std::is_same<decltype(std::declval<const IndexSequence>().never_out_of_bounds()),
- bool>::value,
- "Index sequences must declare whether bounds checking is necessary");
-
- static_assert(
- std::is_same<decltype(std::declval<IndexSequence>().set_never_out_of_bounds()),
- void>::value,
- "An index sequence must support ignoring bounds checking.");
-
- protected:
- template <typename Builder>
- Status MakeBuilder(MemoryPool* pool, std::unique_ptr<Builder>* out) {
- std::unique_ptr<ArrayBuilder> builder;
- RETURN_NOT_OK(arrow::MakeBuilder(pool, type_, &builder));
- out->reset(checked_cast<Builder*>(builder.release()));
- return Status::OK();
- }
-
- std::shared_ptr<DataType> type_;
-};
-
-// an IndexSequence which yields indices from a specified range
-// or yields null for the length of that range
-class RangeIndexSequence {
- public:
- constexpr bool never_out_of_bounds() const { return true; }
- void set_never_out_of_bounds() {}
-
- constexpr RangeIndexSequence() = default;
-
- RangeIndexSequence(bool is_valid, int64_t offset, int64_t length)
- : is_valid_(is_valid), index_(offset), length_(length) {}
-
- std::pair<int64_t, bool> Next() { return std::make_pair(index_++, is_valid_); }
-
- int64_t length() const { return length_; }
-
- int64_t null_count() const { return is_valid_ ? 0 : length_; }
-
- private:
- bool is_valid_ = true;
- int64_t index_ = 0, length_ = -1;
-};
-
-// an IndexSequence which yields the values of an Array of integers
-template <typename IndexType>
-class ArrayIndexSequence {
- public:
- bool never_out_of_bounds() const { return never_out_of_bounds_; }
- void set_never_out_of_bounds() { never_out_of_bounds_ = true; }
-
- constexpr ArrayIndexSequence() = default;
-
- explicit ArrayIndexSequence(const Array& indices)
- : indices_(&checked_cast<const NumericArray<IndexType>&>(indices)) {}
-
- std::pair<int64_t, bool> Next() {
- if (indices_->IsNull(index_)) {
- ++index_;
- return std::make_pair(-1, false);
- }
- return std::make_pair(indices_->Value(index_++), true);
- }
-
- int64_t length() const { return indices_->length(); }
-
- int64_t null_count() const { return indices_->null_count(); }
-
- private:
- const NumericArray<IndexType>* indices_ = nullptr;
- int64_t index_ = 0;
- bool never_out_of_bounds_ = false;
-};
-
-// Default implementation: taking from a simple array into a builder requires only that
-// the array supports array.GetView() and the corresponding builder supports
-// builder.UnsafeAppend(array.GetView())
-template <typename IndexSequence, typename T>
-class TakerImpl : public Taker<IndexSequence> {
- public:
- using ArrayType = typename TypeTraits<T>::ArrayType;
- using BuilderType = typename TypeTraits<T>::BuilderType;
-
- using Taker<IndexSequence>::Taker;
-
- Status SetContext(KernelContext* ctx) override {
- return this->MakeBuilder(ctx->memory_pool(), &builder_);
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
- RETURN_NOT_OK(builder_->Reserve(indices.length()));
- return VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- if (!is_valid) {
- builder_->UnsafeAppendNull();
- return Status::OK();
- }
- auto value = checked_cast<const ArrayType&>(values).GetView(index);
- return UnsafeAppend(builder_.get(), value);
- });
- }
-
- Status Finish(std::shared_ptr<Array>* out) override { return builder_->Finish(out); }
-
- private:
- std::unique_ptr<BuilderType> builder_;
-};
-
-// Gathering from NullArrays is trivial; skip the builder and just
-// do bounds checking
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, NullType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status SetContext(KernelContext*) override { return Status::OK(); }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
-
- length_ += indices.length();
-
- if (indices.never_out_of_bounds()) {
- return Status::OK();
- }
-
- return VisitIndices(indices, values, [](int64_t, bool) { return Status::OK(); });
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- out->reset(new NullArray(length_));
- return Status::OK();
- }
-
- private:
- int64_t length_ = 0;
-};
-
-template <typename IndexSequence, typename TypeClass>
-class ListTakerImpl : public Taker<IndexSequence> {
- public:
- using offset_type = typename TypeClass::offset_type;
- using ArrayType = typename TypeTraits<TypeClass>::ArrayType;
-
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- const auto& list_type = checked_cast<const TypeClass&>(*this->type_);
- return Taker<RangeIndexSequence>::Make(list_type.value_type(), &value_taker_);
- }
-
- Status SetContext(KernelContext* ctx) override {
- auto pool = ctx->memory_pool();
- null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool));
- offset_builder_.reset(new TypedBufferBuilder<offset_type>(pool));
- RETURN_NOT_OK(offset_builder_->Append(0));
- return value_taker_->SetContext(ctx);
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
-
- const auto& list_array = checked_cast<const ArrayType&>(values);
-
- RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(offset_builder_->Reserve(indices.length()));
-
- offset_type offset = offset_builder_->data()[offset_builder_->length() - 1];
- return VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- null_bitmap_builder_->UnsafeAppend(is_valid);
-
- if (is_valid) {
- offset += list_array.value_length(index);
- RangeIndexSequence value_indices(true, list_array.value_offset(index),
- list_array.value_length(index));
- RETURN_NOT_OK(value_taker_->Take(*list_array.values(), value_indices));
- }
-
- offset_builder_->UnsafeAppend(offset);
- return Status::OK();
- });
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- auto null_count = null_bitmap_builder_->false_count();
- auto length = null_bitmap_builder_->length();
-
- std::shared_ptr<Buffer> offsets, null_bitmap;
- RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
- RETURN_NOT_OK(offset_builder_->Finish(&offsets));
-
- std::shared_ptr<Array> taken_values;
- RETURN_NOT_OK(value_taker_->Finish(&taken_values));
-
- out->reset(new ArrayType(this->type_, length, offsets, taken_values, null_bitmap,
- null_count));
- return Status::OK();
- }
-
- std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
- std::unique_ptr<TypedBufferBuilder<offset_type>> offset_builder_;
- std::unique_ptr<Taker<RangeIndexSequence>> value_taker_;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, ListType> : public ListTakerImpl<IndexSequence, ListType> {
- using ListTakerImpl<IndexSequence, ListType>::ListTakerImpl;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, LargeListType>
- : public ListTakerImpl<IndexSequence, LargeListType> {
- using ListTakerImpl<IndexSequence, LargeListType>::ListTakerImpl;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, MapType> : public ListTakerImpl<IndexSequence, MapType> {
- using ListTakerImpl<IndexSequence, MapType>::ListTakerImpl;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, FixedSizeListType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- const auto& list_type = checked_cast<const FixedSizeListType&>(*this->type_);
- return Taker<RangeIndexSequence>::Make(list_type.value_type(), &value_taker_);
- }
-
- Status SetContext(KernelContext* ctx) override {
- auto pool = ctx->memory_pool();
- null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool));
- return value_taker_->SetContext(ctx);
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
-
- const auto& list_array = checked_cast<const FixedSizeListArray&>(values);
- auto list_size = list_array.list_type()->list_size();
-
- RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
- return VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- null_bitmap_builder_->UnsafeAppend(is_valid);
-
- // for FixedSizeList, null lists are not empty (they also span a segment of
- // list_size in the child data), so we must append to value_taker_ even if !is_valid
- RangeIndexSequence value_indices(is_valid, list_array.value_offset(index),
- list_size);
- return value_taker_->Take(*list_array.values(), value_indices);
- });
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- auto null_count = null_bitmap_builder_->false_count();
- auto length = null_bitmap_builder_->length();
-
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
-
- std::shared_ptr<Array> taken_values;
- RETURN_NOT_OK(value_taker_->Finish(&taken_values));
-
- out->reset(new FixedSizeListArray(this->type_, length, taken_values, null_bitmap,
- null_count));
- return Status::OK();
- }
-
- protected:
- std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
- std::unique_ptr<Taker<RangeIndexSequence>> value_taker_;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, StructType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- children_.resize(this->type_->num_fields());
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(
- Taker<IndexSequence>::Make(this->type_->field(i)->type(), &children_[i]));
- }
- return Status::OK();
- }
-
- Status SetContext(KernelContext* ctx) override {
- null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(ctx->memory_pool()));
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->SetContext(ctx));
- }
- return Status::OK();
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
-
- RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(VisitIndices(indices, values, [&](int64_t, bool is_valid) {
- null_bitmap_builder_->UnsafeAppend(is_valid);
- return Status::OK();
- }));
-
- // bounds checking was done while appending to the null bitmap
- indices.set_never_out_of_bounds();
-
- const auto& struct_array = checked_cast<const StructArray&>(values);
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->Take(*struct_array.field(i), indices));
- }
- return Status::OK();
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- auto null_count = null_bitmap_builder_->false_count();
- auto length = null_bitmap_builder_->length();
- std::shared_ptr<Buffer> null_bitmap;
- RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
-
- ArrayVector fields(this->type_->num_fields());
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->Finish(&fields[i]));
- }
-
- out->reset(
- new StructArray(this->type_, length, std::move(fields), null_bitmap, null_count));
- return Status::OK();
- }
-
- protected:
- std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
- std::vector<std::unique_ptr<Taker<IndexSequence>>> children_;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, SparseUnionType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- union_type_ = checked_cast<const SparseUnionType*>(this->type_.get());
- children_.resize(this->type_->num_fields());
-
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(
- Taker<IndexSequence>::Make(this->type_->field(i)->type(), &children_[i]));
- }
- return Status::OK();
- }
-
- Status SetContext(KernelContext* ctx) override {
- pool_ = ctx->memory_pool();
- null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool_));
- type_code_builder_.reset(new TypedBufferBuilder<int8_t>(pool_));
-
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->SetContext(ctx));
- }
- return Status::OK();
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
- const auto& union_array = checked_cast<const UnionArray&>(values);
- auto type_codes = union_array.raw_type_codes();
-
- RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(type_code_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- null_bitmap_builder_->UnsafeAppend(is_valid);
- type_code_builder_->UnsafeAppend(type_codes[index]);
- return Status::OK();
- }));
-
- // bounds checking was done while appending to the null bitmap
- indices.set_never_out_of_bounds();
-
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->Take(*union_array.field(i), indices));
- }
- return Status::OK();
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- auto null_count = null_bitmap_builder_->false_count();
- auto length = null_bitmap_builder_->length();
- std::shared_ptr<Buffer> null_bitmap, type_codes;
- RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
- RETURN_NOT_OK(type_code_builder_->Finish(&type_codes));
-
- ArrayVector fields(this->type_->num_fields());
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(children_[i]->Finish(&fields[i]));
- }
-
- out->reset(new SparseUnionArray(this->type_, length, std::move(fields), type_codes,
- null_bitmap, null_count));
- return Status::OK();
- }
-
- protected:
- int32_t* GetInt32(const std::shared_ptr<Buffer>& b) const {
- return reinterpret_cast<int32_t*>(b->mutable_data());
- }
-
- const SparseUnionType* union_type_ = nullptr;
- MemoryPool* pool_ = nullptr;
- std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
- std::unique_ptr<TypedBufferBuilder<int8_t>> type_code_builder_;
- std::unique_ptr<TypedBufferBuilder<int32_t>> offset_builder_;
- std::vector<std::unique_ptr<Taker<IndexSequence>>> children_;
- std::vector<int32_t> child_length_;
-};
-
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, DenseUnionType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- union_type_ = checked_cast<const DenseUnionType*>(this->type_.get());
-
- dense_children_.resize(this->type_->num_fields());
- child_length_.resize(union_type_->max_type_code() + 1);
-
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(Taker<ArrayIndexSequence<Int32Type>>::Make(
- this->type_->field(i)->type(), &dense_children_[i]));
- }
-
- return Status::OK();
- }
-
- Status SetContext(KernelContext* ctx) override {
- pool_ = ctx->memory_pool();
- null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool_));
- type_code_builder_.reset(new TypedBufferBuilder<int8_t>(pool_));
- offset_builder_.reset(new TypedBufferBuilder<int32_t>(pool_));
-
- std::fill(child_length_.begin(), child_length_.end(), 0);
-
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(dense_children_[i]->SetContext(ctx));
- }
-
- return Status::OK();
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
- const auto& union_array = checked_cast<const UnionArray&>(values);
- auto type_codes = union_array.raw_type_codes();
-
- // Gathering from the offsets into child arrays is a bit tricky.
- std::vector<uint32_t> child_counts(union_type_->max_type_code() + 1);
- RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(type_code_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- null_bitmap_builder_->UnsafeAppend(is_valid);
- type_code_builder_->UnsafeAppend(type_codes[index]);
- child_counts[type_codes[index]] += is_valid;
- return Status::OK();
- }));
-
- // bounds checking was done while appending to the null bitmap
- indices.set_never_out_of_bounds();
-
- // Allocate temporary storage for the offsets of all valid slots
- auto child_offsets_storage_size =
- std::accumulate(child_counts.begin(), child_counts.end(), 0);
- ARROW_ASSIGN_OR_RAISE(
- std::shared_ptr<Buffer> child_offsets_storage,
- AllocateBuffer(child_offsets_storage_size * sizeof(int32_t), pool_));
-
- // Partition offsets by type_code: child_offset_partitions[type_code] will
- // point to storage for child_counts[type_code] offsets
- std::vector<int32_t*> child_offset_partitions(child_counts.size());
- auto child_offsets_storage_data = GetInt32(child_offsets_storage);
- for (auto type_code : union_type_->type_codes()) {
- child_offset_partitions[type_code] = child_offsets_storage_data;
- child_offsets_storage_data += child_counts[type_code];
- }
- DCHECK_EQ(child_offsets_storage_data - GetInt32(child_offsets_storage),
- child_offsets_storage_size);
-
- // Fill child_offsets_storage with the taken offsets
- RETURN_NOT_OK(offset_builder_->Reserve(indices.length()));
- RETURN_NOT_OK(VisitIndices(indices, values, [&](int64_t index, bool is_valid) {
- auto type_code = type_codes[index];
- if (is_valid) {
- offset_builder_->UnsafeAppend(child_length_[type_code]++);
- *child_offset_partitions[type_code] =
- checked_cast<const DenseUnionArray&>(union_array).value_offset(index);
- ++child_offset_partitions[type_code];
- } else {
- offset_builder_->UnsafeAppend(0);
- }
- return Status::OK();
- }));
-
- // Take from each child at those offsets
- int64_t taken_offset_begin = 0;
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- auto type_code = union_type_->type_codes()[i];
- auto length = child_counts[type_code];
- Int32Array taken_offsets(
- length, SliceBuffer(child_offsets_storage, sizeof(int32_t) * taken_offset_begin,
- sizeof(int32_t) * length));
- ArrayIndexSequence<Int32Type> child_indices(taken_offsets);
- child_indices.set_never_out_of_bounds();
- RETURN_NOT_OK(dense_children_[i]->Take(*union_array.field(i), child_indices));
- taken_offset_begin += length;
- }
-
- return Status::OK();
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- auto null_count = null_bitmap_builder_->false_count();
- auto length = null_bitmap_builder_->length();
- std::shared_ptr<Buffer> null_bitmap, type_codes, offsets;
- RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap));
- RETURN_NOT_OK(type_code_builder_->Finish(&type_codes));
- RETURN_NOT_OK(offset_builder_->Finish(&offsets));
-
- ArrayVector fields(this->type_->num_fields());
- for (int i = 0; i < this->type_->num_fields(); ++i) {
- RETURN_NOT_OK(dense_children_[i]->Finish(&fields[i]));
- }
-
- out->reset(new DenseUnionArray(this->type_, length, std::move(fields), type_codes,
- offsets, null_bitmap, null_count));
- return Status::OK();
- }
-
- protected:
- int32_t* GetInt32(const std::shared_ptr<Buffer>& b) const {
- return reinterpret_cast<int32_t*>(b->mutable_data());
- }
-
- const DenseUnionType* union_type_ = nullptr;
- MemoryPool* pool_ = nullptr;
- std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_;
- std::unique_ptr<TypedBufferBuilder<int8_t>> type_code_builder_;
- std::unique_ptr<TypedBufferBuilder<int32_t>> offset_builder_;
- std::vector<std::unique_ptr<Taker<IndexSequence>>> sparse_children_;
- std::vector<std::unique_ptr<Taker<ArrayIndexSequence<Int32Type>>>> dense_children_;
- std::vector<int32_t> child_length_;
-};
-
-// taking from a DictionaryArray is accomplished by taking from its indices
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, DictionaryType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- const auto& dict_type = checked_cast<const DictionaryType&>(*this->type_);
- return Taker<IndexSequence>::Make(dict_type.index_type(), &index_taker_);
- }
-
- Status SetContext(KernelContext* ctx) override {
- dictionary_ = nullptr;
- return index_taker_->SetContext(ctx);
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
- const auto& dict_array = checked_cast<const DictionaryArray&>(values);
-
- if (dictionary_ != nullptr && dictionary_ != dict_array.dictionary()) {
- return Status::NotImplemented(
- "taking from DictionaryArrays with different dictionaries");
- } else {
- dictionary_ = dict_array.dictionary();
- }
- return index_taker_->Take(*dict_array.indices(), indices);
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- std::shared_ptr<Array> taken_indices;
- RETURN_NOT_OK(index_taker_->Finish(&taken_indices));
- out->reset(new DictionaryArray(this->type_, taken_indices, dictionary_));
- return Status::OK();
- }
-
- protected:
- std::shared_ptr<Array> dictionary_;
- std::unique_ptr<Taker<IndexSequence>> index_taker_;
-};
-
-// taking from an ExtensionArray is accomplished by taking from its storage
-template <typename IndexSequence>
-class TakerImpl<IndexSequence, ExtensionType> : public Taker<IndexSequence> {
- public:
- using Taker<IndexSequence>::Taker;
-
- Status Init() override {
- const auto& ext_type = checked_cast<const ExtensionType&>(*this->type_);
- return Taker<IndexSequence>::Make(ext_type.storage_type(), &storage_taker_);
- }
-
- Status SetContext(KernelContext* ctx) override {
- return storage_taker_->SetContext(ctx);
- }
-
- Status Take(const Array& values, IndexSequence indices) override {
- DCHECK(this->type_->Equals(values.type()));
- const auto& ext_array = checked_cast<const ExtensionArray&>(values);
- return storage_taker_->Take(*ext_array.storage(), indices);
- }
-
- Status Finish(std::shared_ptr<Array>* out) override {
- std::shared_ptr<Array> taken_storage;
- RETURN_NOT_OK(storage_taker_->Finish(&taken_storage));
- out->reset(new ExtensionArray(this->type_, taken_storage));
- return Status::OK();
- }
-
- protected:
- std::unique_ptr<Taker<IndexSequence>> storage_taker_;
-};
-
-template <typename IndexSequence>
-struct TakerMakeImpl {
- template <typename T>
- Status Visit(const T&) {
- out_->reset(new TakerImpl<IndexSequence, T>(type_));
- return (*out_)->Init();
- }
-
- std::shared_ptr<DataType> type_;
- std::unique_ptr<Taker<IndexSequence>>* out_;
-};
-
-template <typename IndexSequence>
-Status Taker<IndexSequence>::Make(const std::shared_ptr<DataType>& type,
- std::unique_ptr<Taker>* out) {
- TakerMakeImpl<IndexSequence> visitor{type, out};
- return VisitTypeInline(*type, &visitor);
-}
-
-int64_t FilterOutputSize(FilterOptions::NullSelectionBehavior null_selection,
- const Array& filter);
-
-template <typename IndexSequence>
-Status Select(KernelContext* ctx, const Array& values, IndexSequence sequence,
- std::shared_ptr<Array>* out) {
- std::unique_ptr<Taker<IndexSequence>> taker;
- RETURN_NOT_OK(Taker<IndexSequence>::Make(values.type(), &taker));
- RETURN_NOT_OK(taker->SetContext(ctx));
- RETURN_NOT_OK(taker->Take(values, std::move(sequence)));
- return taker->Finish(out);
-}
-
-} // namespace internal
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
new file mode 100644
index 0000000..fb8182e
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -0,0 +1,1638 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "arrow/compute/api.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/testing/util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+using util::string_view;
+
+namespace compute {
+
+// ----------------------------------------------------------------------
+// Some random data generation helpers
+
+// Random numeric array of `length` values drawn from [0, 127], each slot null
+// with probability `null_probability`.
+template <typename Type>
+std::shared_ptr<Array> RandomNumeric(int64_t length, double null_probability,
+                                     random::RandomArrayGenerator* rng) {
+  return rng->Numeric<Type>(length, /*min=*/0, /*max=*/127, null_probability);
+}
+
+// Random boolean array with an even chance of true and false among valid slots.
+std::shared_ptr<Array> RandomBoolean(int64_t length, double null_probability,
+                                     random::RandomArrayGenerator* rng) {
+  return rng->Boolean(length, /*true_probability=*/0.5, null_probability);
+}
+
+// Random utf8 array whose strings are between 0 and 32 characters long.
+std::shared_ptr<Array> RandomString(int64_t length, double null_probability,
+                                    random::RandomArrayGenerator* rng) {
+  return rng->String(length, /*min_length=*/0, /*max_length=*/32, null_probability);
+}
+
+// Random large_utf8 array whose strings are between 0 and 32 characters long.
+std::shared_ptr<Array> RandomLargeString(int64_t length, double null_probability,
+                                         random::RandomArrayGenerator* rng) {
+  return rng->LargeString(length, /*min_length=*/0, /*max_length=*/32,
+                          null_probability);
+}
+
+// Random fixed_size_binary(16) array: the payload is random bytes (fixed seed
+// 0) and each slot is valid with probability 1 - null_probability.
+std::shared_ptr<Array> RandomFixedSizeBinary(int64_t length, double null_probability,
+                                             random::RandomArrayGenerator* rng) {
+  const int32_t value_size = 16;
+  const int64_t total_bytes = length * value_size;
+
+  std::shared_ptr<Buffer> payload = *AllocateBuffer(total_bytes);
+  random_bytes(total_bytes, /*seed=*/0, payload->mutable_data());
+
+  // The data bitmap (buffer 1) of a random boolean array serves as validity
+  std::shared_ptr<Buffer> validity_bitmap =
+      rng->Boolean(length, 1 - null_probability)->data()->buffers[1];
+
+  auto array_data = std::make_shared<ArrayData>(fixed_size_binary(value_size), length);
+  array_data->buffers = {validity_bitmap, payload};
+  return MakeArray(array_data);
+}
+
+// ----------------------------------------------------------------------
+
+TEST(GetTakeIndices, Basics) {
+  // Convert a JSON boolean filter to take indices with the given null
+  // selection behavior and compare against the expected indices (uint16
+  // unless another index type is passed).
+  auto check_case = [&](const std::string& filter_json, const std::string& indices_json,
+                        FilterOptions::NullSelectionBehavior null_selection,
+                        const std::shared_ptr<DataType>& indices_type = uint16()) {
+    auto filter = ArrayFromJSON(boolean(), filter_json);
+    auto expected_indices = ArrayFromJSON(indices_type, indices_json);
+    ASSERT_OK_AND_ASSIGN(std::shared_ptr<ArrayData> indices,
+                         internal::GetTakeIndices(*filter->data(), null_selection));
+    AssertArraysEqual(*expected_indices, *MakeArray(indices), /*verbose=*/true);
+  };
+
+  // DROP: null filter slots contribute nothing
+  check_case("[]", "[]", FilterOptions::DROP);
+  check_case("[null]", "[]", FilterOptions::DROP);
+  check_case("[null, false, true, true, false, true]", "[2, 3, 5]", FilterOptions::DROP);
+
+  // EMIT_NULL: each null filter slot contributes a null index
+  check_case("[]", "[]", FilterOptions::EMIT_NULL);
+  check_case("[null]", "[null]", FilterOptions::EMIT_NULL);
+  check_case("[null, false, true, true]", "[null, 2, 3]", FilterOptions::EMIT_NULL);
+}
+
+// TODO: Add sliced (non-zero offset) filters to the hand-written cases above
+
+// Check internal::GetTakeIndices against a scalar reference loop for both null
+// selection behaviors:
+//  - DROP: the indices are exactly the positions i where the filter is valid
+//    and true, in increasing order.
+//  - EMIT_NULL: as DROP, plus a null index at each position where the filter
+//    is null.
+// In both cases the index array must have exactly the expected length, which
+// must also agree with internal::GetFilterOutputSize. IndexArrayType is the
+// array class used to read the produced indices (e.g. UInt16Array).
+template <typename IndexArrayType>
+void CheckGetTakeIndicesCase(const Array& untyped_filter) {
+  const auto& filter = checked_cast<const BooleanArray&>(untyped_filter);
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<ArrayData> drop_indices,
+                       internal::GetTakeIndices(*filter.data(), FilterOptions::DROP));
+  // Verify DROP indices
+  {
+    IndexArrayType indices(drop_indices);
+    int64_t out_position = 0;
+    for (int64_t i = 0; i < filter.length(); ++i) {
+      if (filter.IsValid(i)) {
+        if (filter.Value(i)) {
+          // Bounds check first so a too-short result fails cleanly instead of
+          // reading past the end of the index buffer
+          ASSERT_LT(out_position, indices.length());
+          ASSERT_EQ(indices.Value(out_position), i);
+          ++out_position;
+        }
+      }
+    }
+    // No extra indices may follow the expected ones
+    ASSERT_EQ(out_position, indices.length());
+    // Check that the end length agrees with the output of GetFilterOutputSize
+    ASSERT_EQ(out_position,
+              internal::GetFilterOutputSize(*filter.data(), FilterOptions::DROP));
+  }
+
+  ASSERT_OK_AND_ASSIGN(
+      std::shared_ptr<ArrayData> emit_indices,
+      internal::GetTakeIndices(*filter.data(), FilterOptions::EMIT_NULL));
+
+  // Verify EMIT_NULL indices
+  {
+    IndexArrayType indices(emit_indices);
+    int64_t out_position = 0;
+    for (int64_t i = 0; i < filter.length(); ++i) {
+      if (filter.IsValid(i)) {
+        if (filter.Value(i)) {
+          ASSERT_LT(out_position, indices.length());
+          ASSERT_EQ(indices.Value(out_position), i);
+          ++out_position;
+        }
+      } else {
+        ASSERT_LT(out_position, indices.length());
+        ASSERT_TRUE(indices.IsNull(out_position));
+        ++out_position;
+      }
+    }
+    // No extra indices may follow the expected ones
+    ASSERT_EQ(out_position, indices.length());
+    // Check that the end length agrees with the output of GetFilterOutputSize
+    ASSERT_EQ(out_position,
+              internal::GetFilterOutputSize(*filter.data(), FilterOptions::EMIT_NULL));
+  }
+}
+
+TEST(GetTakeIndices, RandomlyGenerated) {
+  random::RandomArrayGenerator rng(kRandomSeed);
+
+  // 6401 = 100 * 64 + 1: one past a multiple of the 64-bit word size
+  const int64_t length = 6401;
+  const double probabilities[] = {0.0, 0.01, 0.999, 1.0};
+  for (double null_prob : probabilities) {
+    for (double true_prob : probabilities) {
+      std::shared_ptr<Array> filter = rng.Boolean(length, true_prob, null_prob);
+      CheckGetTakeIndicesCase<UInt16Array>(*filter);
+      // Same filter with a non-zero offset
+      CheckGetTakeIndicesCase<UInt16Array>(*filter->Slice(7));
+    }
+  }
+
+  // A filter of length uint16_max + 1 is expected to take the uint32 index
+  // path; its slice of length uint16_max is expected to stay on uint16.
+  const int64_t uint16_max = std::numeric_limits<uint16_t>::max();
+  auto big_filter =
+      std::static_pointer_cast<BooleanArray>(rng.Boolean(uint16_max + 1, 0.99, 0.01));
+  CheckGetTakeIndicesCase<UInt16Array>(*big_filter->Slice(1));
+  CheckGetTakeIndicesCase<UInt32Array>(*big_filter);
+}
+
+// ----------------------------------------------------------------------
+// Filter tests
+
+// Return a boolean array equal to `filter` with every null slot replaced by
+// false, i.e. true exactly where the slot is both valid and true. A filter
+// without nulls is returned unchanged.
+std::shared_ptr<Array> CoalesceNullToFalse(std::shared_ptr<Array> filter) {
+  if (filter->null_count() == 0) {
+    return filter;
+  }
+  const auto& data = *filter->data();
+  // Reinterpret the values bitmap and the validity bitmap as standalone boolean
+  // arrays. Both must carry the parent's offset; otherwise a sliced filter
+  // would be read from the wrong bit positions.
+  auto is_true = std::make_shared<BooleanArray>(data.length, data.buffers[1],
+                                                /*null_bitmap=*/nullptr,
+                                                /*null_count=*/0, data.offset);
+  auto is_valid = std::make_shared<BooleanArray>(data.length, data.buffers[0],
+                                                 /*null_bitmap=*/nullptr,
+                                                 /*null_count=*/0, data.offset);
+  EXPECT_OK_AND_ASSIGN(Datum out_datum, And(is_true, is_valid));
+  return out_datum.make_array();
+}
+
+// Base fixture for Filter kernel tests. AssertFilter checks the EMIT_NULL
+// output against `expected`. It then checks that DROP agrees with EMIT_NULL
+// applied to the same filter with nulls coalesced to false. Every produced
+// array is fully validated.
+template <typename ArrowType>
+class TestFilterKernel : public ::testing::Test {
+ protected:
+  TestFilterKernel() : emit_null_(FilterOptions::EMIT_NULL), drop_(FilterOptions::DROP) {}
+
+  void AssertFilter(std::shared_ptr<Array> values, std::shared_ptr<Array> filter,
+                    std::shared_ptr<Array> expected) {
+    // test with EMIT_NULL
+    ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, filter, emit_null_));
+    auto actual = out_datum.make_array();
+    ASSERT_OK(actual->ValidateFull());
+    AssertArraysEqual(*expected, *actual);
+
+    // test with DROP using EMIT_NULL and a coalesced filter
+    auto coalesced_filter = CoalesceNullToFalse(filter);
+    ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, coalesced_filter, emit_null_));
+    expected = out_datum.make_array();
+    ASSERT_OK(expected->ValidateFull());
+    ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, filter, drop_));
+    actual = out_datum.make_array();
+    // The DROP output must be structurally valid, not merely equal
+    ASSERT_OK(actual->ValidateFull());
+    AssertArraysEqual(*expected, *actual);
+  }
+
+  // JSON convenience overload: `values` and `expected` are parsed as `type`,
+  // `filter` as boolean.
+  void AssertFilter(std::shared_ptr<DataType> type, const std::string& values,
+                    const std::string& filter, const std::string& expected) {
+    AssertFilter(ArrayFromJSON(type, values), ArrayFromJSON(boolean(), filter),
+                 ArrayFromJSON(type, expected));
+  }
+
+  FilterOptions emit_null_, drop_;
+};
+
+// Cross-check Filter against Take. For EMIT_NULL and for DROP, filtering
+// `values` by `filter_boxed` must equal taking `values` at the indices
+// GetTakeIndices derives from the filter. Both filter outputs are also fully
+// validated.
+void ValidateFilter(const std::shared_ptr<Array>& values,
+                    const std::shared_ptr<Array>& filter_boxed) {
+  const FilterOptions emit_null(FilterOptions::EMIT_NULL);
+  const FilterOptions drop(FilterOptions::DROP);
+
+  ASSERT_OK_AND_ASSIGN(Datum emit_null_datum, Filter(values, filter_boxed, emit_null));
+  std::shared_ptr<Array> filtered_emit_null = emit_null_datum.make_array();
+  ASSERT_OK(filtered_emit_null->ValidateFull());
+
+  ASSERT_OK_AND_ASSIGN(Datum drop_datum, Filter(values, filter_boxed, drop));
+  std::shared_ptr<Array> filtered_drop = drop_datum.make_array();
+  ASSERT_OK(filtered_drop->ValidateFull());
+
+  // Reference results: Take with the indices derived from the filter
+  const ArrayData& filter_data = *filter_boxed->data();
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<ArrayData> drop_indices,
+                       internal::GetTakeIndices(filter_data, FilterOptions::DROP));
+  ASSERT_OK_AND_ASSIGN(Datum expected_drop, Take(values, Datum(drop_indices)));
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<ArrayData> emit_null_indices,
+                       internal::GetTakeIndices(filter_data, FilterOptions::EMIT_NULL));
+  ASSERT_OK_AND_ASSIGN(Datum expected_emit_null, Take(values, Datum(emit_null_indices)));
+
+  AssertArraysEqual(*expected_drop.make_array(), *filtered_drop, /*verbose=*/true);
+  AssertArraysEqual(*expected_emit_null.make_array(), *filtered_emit_null,
+                    /*verbose=*/true);
+}
+
+// Fixture for filtering null-typed arrays. Values and expectations are parsed
+// as null() and the filter as boolean().
+class TestFilterKernelWithNull : public TestFilterKernel<NullType> {
+ protected:
+  void AssertFilter(const std::string& values, const std::string& filter,
+                    const std::string& expected) {
+    auto values_array = ArrayFromJSON(null(), values);
+    auto filter_array = ArrayFromJSON(boolean(), filter);
+    auto expected_array = ArrayFromJSON(null(), expected);
+    TestFilterKernel<NullType>::AssertFilter(values_array, filter_array, expected_array);
+  }
+};
+
+TEST_F(TestFilterKernelWithNull, FilterNull) {
+  AssertFilter("[]", "[]", "[]");
+
+  AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]");
+  AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]");
+}
+
+// Fixture for filtering boolean arrays; every JSON input is parsed as boolean().
+class TestFilterKernelWithBoolean : public TestFilterKernel<BooleanType> {
+ protected:
+  void AssertFilter(const std::string& values, const std::string& filter,
+                    const std::string& expected) {
+    auto values_array = ArrayFromJSON(boolean(), values);
+    auto filter_array = ArrayFromJSON(boolean(), filter);
+    auto expected_array = ArrayFromJSON(boolean(), expected);
+    TestFilterKernel<BooleanType>::AssertFilter(values_array, filter_array,
+                                                expected_array);
+  }
+};
+
+TEST_F(TestFilterKernelWithBoolean, FilterBoolean) {
+  AssertFilter("[]", "[]", "[]");
+
+  AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]");
+  AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]");
+  // Expected values are for EMIT_NULL: a null filter slot yields a null
+  AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]");
+}
+
+// Typed fixture for the numeric Arrow types; type_singleton() returns the
+// DataType instance corresponding to ArrowType.
+template <typename ArrowType>
+class TestFilterKernelWithNumeric : public TestFilterKernel<ArrowType> {
+ protected:
+  std::shared_ptr<DataType> type_singleton() {
+    return TypeTraits<ArrowType>::type_singleton();
+  }
+};
+
+TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes);
+TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) {
+  const auto type = this->type_singleton();
+  this->AssertFilter(type, "[]", "[]", "[]");
+
+  // Single element (expected values are for EMIT_NULL)
+  this->AssertFilter(type, "[9]", "[0]", "[]");
+  this->AssertFilter(type, "[9]", "[1]", "[9]");
+  this->AssertFilter(type, "[9]", "[null]", "[null]");
+  this->AssertFilter(type, "[null]", "[0]", "[]");
+  this->AssertFilter(type, "[null]", "[1]", "[null]");
+  this->AssertFilter(type, "[null]", "[null]", "[null]");
+
+  // Several elements, with nulls in the values or in the filter
+  this->AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]");
+  this->AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]");
+  this->AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]");
+  this->AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]");
+  this->AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]");
+
+  // Filter with a non-zero offset: the slice is [1, 0, 1]
+  auto sliced_filter = ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3);
+  this->AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), sliced_filter,
+                     ArrayFromJSON(type, "[7, 9]"));
+
+  // A filter whose length differs from the values is rejected
+  auto values = ArrayFromJSON(type, "[7, 8, 9]");
+  auto empty_filter = ArrayFromJSON(boolean(), "[]");
+  ASSERT_RAISES(Invalid, Filter(values, empty_filter, this->emit_null_));
+  ASSERT_RAISES(Invalid, Filter(values, empty_filter, this->drop_));
+}
+
+// Run ValidateFilter on random data for lengths 8, 16, ..., 512 over a grid of
+// null and true probabilities. Both whole and sliced (mismatched offset)
+// inputs are covered. `generate_values(length, null_probability, &rng)`
+// produces the values array.
+template <typename DataGenerator>
+void DoRandomFilterTests(DataGenerator&& generate_values) {
+  auto rand = random::RandomArrayGenerator(kRandomSeed);
+  for (size_t exponent = 3; exponent < 10; ++exponent) {
+    const auto length = static_cast<int64_t>(1ULL << exponent);
+    for (double null_probability : {0.0, 0.01, 0.1, 0.999, 1.0}) {
+      for (double true_probability : {0.0, 0.1, 0.999, 1.0}) {
+        auto values = generate_values(length, null_probability, &rand);
+        // Filters are one slot longer than the values, so filter->Slice(4)
+        // has the same length as values->Slice(3)
+        auto filter = rand.Boolean(length + 1, true_probability, null_probability);
+        auto filter_no_nulls = rand.Boolean(length + 1, true_probability, 0.0);
+        const int64_t num_values = values->length();
+        ValidateFilter(values, filter->Slice(0, num_values));
+        ValidateFilter(values, filter_no_nulls->Slice(0, num_values));
+        // Values and filter with different offsets
+        ValidateFilter(values->Slice(3), filter->Slice(4));
+        ValidateFilter(values->Slice(3), filter_no_nulls->Slice(4));
+      }
+    }
+  }
+}
+
+// Randomized Filter-vs-Take checks, one per value generator.
+TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) {
+  DoRandomFilterTests(&RandomNumeric<TypeParam>);
+}
+
+TEST(TestFilter, RandomBoolean) {
+  DoRandomFilterTests(&RandomBoolean);
+}
+
+TEST(TestFilter, RandomString) {
+  DoRandomFilterTests(&RandomString);
+  DoRandomFilterTests(&RandomLargeString);
+}
+
+TEST(TestFilter, RandomFixedSizeBinary) {
+  DoRandomFilterTests(&RandomFixedSizeBinary);
+}
+
+// Signature of a binary predicate over two CType values.
+template <typename CType>
+using Comparator = bool(CType, CType);
+
+// Return a plain-function predicate implementing `op`. The table is indexed by
+// the enum value of `op`, so its entries must follow the declaration order of
+// CompareOperator.
+template <typename CType>
+Comparator<CType>* GetComparator(CompareOperator op) {
+  static Comparator<CType>* const kComparators[] = {
+      /*EQUAL=*/[](CType lhs, CType rhs) { return lhs == rhs; },
+      /*NOT_EQUAL=*/[](CType lhs, CType rhs) { return lhs != rhs; },
+      /*GREATER=*/[](CType lhs, CType rhs) { return lhs > rhs; },
+      /*GREATER_EQUAL=*/[](CType lhs, CType rhs) { return lhs >= rhs; },
+      /*LESS=*/[](CType lhs, CType rhs) { return lhs < rhs; },
+      /*LESS_EQUAL=*/[](CType lhs, CType rhs) { return lhs <= rhs; },
+  };
+  return kComparators[op];
+}
+
+// Reference filter: copy the elements of data[0, length) that satisfy `fn`,
+// in order, into a new array of Arrow type T.
+template <typename T, typename Fn, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, Fn&& fn) {
+  std::vector<CType> kept;
+  kept.reserve(static_cast<size_t>(length));
+  std::copy_if(data, data + length, std::back_inserter(kept), std::forward<Fn>(fn));
+  std::shared_ptr<Array> result;
+  ArrayFromVector<T, CType>(kept, &result);
+  return result;
+}
+
+// Keep each element e of data[0, length) for which `e op val` holds.
+template <typename T, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length, CType val,
+                                        CompareOperator op) {
+  Comparator<CType>* const predicate = GetComparator<CType>(op);
+  return CompareAndFilter<T>(data, length,
+                             [&](CType element) { return predicate(element, val); });
+}
+
+// Keep each data[i] for which `data[i] op other[i]` holds. `other` must point
+// to at least `length` values; it advances once per element visited, relying on
+// copy_if visiting elements in order.
+template <typename T, typename CType = typename TypeTraits<T>::CType>
+std::shared_ptr<Array> CompareAndFilter(const CType* data, int64_t length,
+                                        const CType* other, CompareOperator op) {
+  Comparator<CType>* const predicate = GetComparator<CType>(op);
+  return CompareAndFilter<T>(data, length,
+                             [&](CType element) { return predicate(element, *other++); });
+}
+
+// Filter with a selection computed by comparing against the scalar 50, and
+// check the result against a scalar reference filter.
+TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) {
+  using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+  using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+  using CType = typename TypeTraits<TypeParam>::CType;
+
+  auto rand = random::RandomArrayGenerator(kRandomSeed);
+  const CType c_fifty = 50;
+  auto fifty = std::make_shared<ScalarType>(c_fifty);
+  for (size_t exponent = 3; exponent < 10; ++exponent) {
+    const auto length = static_cast<int64_t>(1ULL << exponent);
+    // TODO(bkietz) rewrite with some nulls
+    auto array =
+        checked_pointer_cast<ArrayType>(rand.Numeric<TypeParam>(length, 0, 100, 0));
+    for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+      ASSERT_OK_AND_ASSIGN(Datum selection,
+                           Compare(array, Datum(fifty), CompareOptions(op)));
+      ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
+      std::shared_ptr<Array> filtered_array = filtered.make_array();
+      ASSERT_OK(filtered_array->ValidateFull());
+      // Reference: element-wise comparison on the raw (null-free) values
+      auto expected =
+          CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op);
+      ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+    }
+  }
+}
+
+// Filter `lhs` with a selection computed by comparing it element-wise against
+// `rhs`, and check the result against a scalar reference filter.
+TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) {
+  using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+
+  auto rand = random::RandomArrayGenerator(kRandomSeed);
+  // Operands are null-free, so raw_values() can be compared directly
+  auto make_operand = [&](int64_t length) {
+    return checked_pointer_cast<ArrayType>(
+        rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+  };
+  for (size_t exponent = 3; exponent < 10; ++exponent) {
+    const auto length = static_cast<int64_t>(1ULL << exponent);
+    auto lhs = make_operand(length);
+    auto rhs = make_operand(length);
+    for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+      ASSERT_OK_AND_ASSIGN(Datum selection, Compare(lhs, rhs, CompareOptions(op)));
+      ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(lhs, selection));
+      std::shared_ptr<Array> filtered_array = filtered.make_array();
+      ASSERT_OK(filtered_array->ValidateFull());
+      // Reference: keep lhs[i] where `lhs[i] op rhs[i]`
+      auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(),
+                                                  rhs->raw_values(), op);
+      ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+    }
+  }
+}
+
+// Filter with the selection `50 < x < 100`, built from two scalar comparisons
+// combined with And, and check against a scalar reference filter.
+TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) {
+  using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+  using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+  using CType = typename TypeTraits<TypeParam>::CType;
+
+  auto rand = random::RandomArrayGenerator(kRandomSeed);
+  for (size_t exponent = 3; exponent < 10; ++exponent) {
+    const auto length = static_cast<int64_t>(1ULL << exponent);
+    auto array = checked_pointer_cast<ArrayType>(
+        rand.Numeric<TypeParam>(length, 0, 100, /*null_probability=*/0.0));
+    const CType c_fifty = 50;
+    const CType c_hundred = 100;
+    auto fifty = std::make_shared<ScalarType>(c_fifty);
+    auto hundred = std::make_shared<ScalarType>(c_hundred);
+
+    ASSERT_OK_AND_ASSIGN(Datum greater_than_fifty,
+                         Compare(array, Datum(fifty), CompareOptions(GREATER)));
+    ASSERT_OK_AND_ASSIGN(Datum less_than_hundred,
+                         Compare(array, Datum(hundred), CompareOptions(LESS)));
+    ASSERT_OK_AND_ASSIGN(Datum selection, And(greater_than_fifty, less_than_hundred));
+    ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection));
+    std::shared_ptr<Array> filtered_array = filtered.make_array();
+    ASSERT_OK(filtered_array->ValidateFull());
+
+    // Reference: the same range predicate applied element-wise
+    auto in_range = [&](CType e) { return (e > c_fifty) && (e < c_hundred); };
+    auto expected =
+        CompareAndFilter<TypeParam>(array->raw_values(), array->length(), in_range);
+    ASSERT_ARRAYS_EQUAL(*filtered_array, *expected);
+  }
+}
+
+// Binary-like value types exercised by the string filter tests.
+using StringTypes =
+    ::testing::Types<BinaryType, StringType, LargeBinaryType, LargeStringType>;
+
+// Typed fixture for binary/string values, plain and dictionary-encoded.
+template <typename TypeClass>
+class TestFilterKernelWithString : public TestFilterKernel<TypeClass> {
+ protected:
+  std::shared_ptr<DataType> value_type() {
+    return TypeTraits<TypeClass>::type_singleton();
+  }
+
+  void AssertFilter(const std::string& values, const std::string& filter,
+                    const std::string& expected) {
+    const auto type = value_type();
+    TestFilterKernel<TypeClass>::AssertFilter(ArrayFromJSON(type, values),
+                                              ArrayFromJSON(boolean(), filter),
+                                              ArrayFromJSON(type, expected));
+  }
+
+  // Filter an int8-indexed dictionary array over `dictionary_values`. Despite
+  // their names, `dictionary_filter` and `expected_filter` hold the dictionary
+  // indices of the input and of the expected output; `filter` is the boolean
+  // filter itself.
+  void AssertFilterDictionary(const std::string& dictionary_values,
+                              const std::string& dictionary_filter,
+                              const std::string& filter,
+                              const std::string& expected_filter) {
+    auto dict = ArrayFromJSON(value_type(), dictionary_values);
+    auto dict_type = dictionary(int8(), value_type());
+    ASSERT_OK_AND_ASSIGN(auto values,
+                         DictionaryArray::FromArrays(
+                             dict_type, ArrayFromJSON(int8(), dictionary_filter), dict));
+    ASSERT_OK_AND_ASSIGN(auto expected,
+                         DictionaryArray::FromArrays(
+                             dict_type, ArrayFromJSON(int8(), expected_filter), dict));
+    auto filter_array = ArrayFromJSON(boolean(), filter);
+    TestFilterKernel<TypeClass>::AssertFilter(values, filter_array, expected);
+  }
+};
+
+TYPED_TEST_SUITE(TestFilterKernelWithString, StringTypes);
+
+TYPED_TEST(TestFilterKernelWithString, FilterString) {
+  this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])");
+  this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])");
+  this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])");
+}
+
+TYPED_TEST(TestFilterKernelWithString, FilterDictionary) {
+  const char* dict_json = R"(["a", "b", "c", "d", "e"])";
+  this->AssertFilterDictionary(dict_json, "[3, 4, 2]", "[0, 1, 0]", "[4]");
+  this->AssertFilterDictionary(dict_json, "[null, 4, 2]", "[0, 1, 0]", "[4]");
+  this->AssertFilterDictionary(dict_json, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]");
+}
+
+// List fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithList : public TestFilterKernel<ListType> {};
+
+TEST_F(TestFilterKernelWithList, FilterListInt32) {
+  const auto type = list(int32());
+  const std::string list_json = "[[], [1,2], null, [3]]";
+  this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+  this->AssertFilter(type, list_json, "[0, 1, 1, null]", "[[1,2], null, null]");
+  this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
+  this->AssertFilter(type, list_json, "[1, 0, 0, 1]", "[[], [3]]");
+  this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
+  this->AssertFilter(type, list_json, "[0, 1, 0, 1]", "[[1,2], [3]]");
+}
+
+// Filter a nested list<list<int32>> column whose fixture contains an empty
+// outer list, a null outer list, and nulls inside the inner lists. Expected
+// values are for EMIT_NULL (TestFilterKernel::AssertFilter also derives and
+// checks the DROP result).
+TEST_F(TestFilterKernelWithList, FilterListListInt32) {
+ std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+ auto type = list(list(int32()));
+ this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(type, list_json, "[0, 1, 1, null]", R"([
+ [[1], [2, null, 2], []],
+ null,
+ null
+ ])");
+ this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]");
+ this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([
+ [],
+ [[3, null], null]
+ ])");
+ this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
+ this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([
+ [[1], [2, null, 2], []],
+ [[3, null], null]
+ ])");
+}
+
+// Large-list fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithLargeList : public TestFilterKernel<LargeListType> {};
+
+TEST_F(TestFilterKernelWithLargeList, FilterListInt32) {
+  const auto type = large_list(int32());
+  const std::string list_json = "[[], [1,2], null, [3]]";
+  this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+  this->AssertFilter(type, list_json, "[0, 1, 1, null]", "[[1,2], null, null]");
+}
+
+// Fixed-size-list fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithFixedSizeList : public TestFilterKernel<FixedSizeListType> {};
+
+TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) {
+  const auto type = fixed_size_list(int32(), 3);
+  const std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+  this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]");
+  this->AssertFilter(type, list_json, "[0, 1, 1, null]",
+                     "[[1, null, 3], [4, 5, 6], null]");
+  this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[[4, 5, 6], null]");
+  this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json);
+  this->AssertFilter(type, list_json, "[0, 1, 0, 1]", "[[1, null, 3], [7, 8, null]]");
+}
+
+// Map fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithMap : public TestFilterKernel<MapType> {};
+
+// Filter a map<utf8, int32> column containing a null map, an empty map and a
+// null item value. Expected values are for EMIT_NULL.
+TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) {
+ std::string map_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([
+ null,
+ [["cap", 8]],
+ null
+ ])");
+ this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json);
+ this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]");
+}
+
+// Struct fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithStruct : public TestFilterKernel<StructType> {};
+
+// Filter a struct<a: int32, b: utf8> column whose first row is a null struct.
+// Expected values are for EMIT_NULL.
+TEST_F(TestFilterKernelWithStruct, FilterStruct) {
+ auto struct_type = struct_({field("a", int32()), field("b", utf8())});
+ auto struct_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]");
+ this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ null
+ ])");
+ this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json);
+ this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([
+ null,
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+// Union fixture; it reuses TestFilterKernel::AssertFilter unchanged.
+class TestFilterKernelWithUnion : public TestFilterKernel<UnionType> {};
+
+// Filter union arrays built by each factory in UnionTypeFactories(). Type
+// codes 2 and 5 select fields "a" (int32) and "b" (utf8); the JSON rows appear
+// to be [type_code, value] pairs. The DISABLED_ prefix keeps GoogleTest from
+// running this by default; the reason is not stated here (TODO confirm).
+TEST_F(TestFilterKernelWithUnion, DISABLED_FilterUnion) {
+ for (auto union_ : UnionTypeFactories()) {
+ auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5});
+ auto union_json = R"([
+ null,
+ [2, 222],
+ [5, "hello"],
+ [5, "eh"],
+ null,
+ [2, 111]
+ ])";
+ // Expected values are for EMIT_NULL
+ this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0]", "[]");
+ this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1]", R"([
+ [2, 222],
+ [5, "hello"],
+ null,
+ [2, 111]
+ ])");
+ this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0]", R"([
+ null,
+ [5, "hello"],
+ null
+ ])");
+ this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1]", union_json);
+ }
+}
+
+// Filter fixture for record batches. The batch and the boolean selection are
+// parsed from JSON, filtered with the given options, validated and compared to
+// the expected batch.
+class TestFilterKernelWithRecordBatch : public TestFilterKernel<RecordBatch> {
+ public:
+  void AssertFilter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                    const std::string& selection, FilterOptions options,
+                    const std::string& expected_batch) {
+    std::shared_ptr<RecordBatch> filtered;
+    ASSERT_OK(DoFilter(schm, batch_json, selection, options, &filtered));
+    ASSERT_OK(filtered->ValidateFull());
+    ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *filtered);
+  }
+
+  Status DoFilter(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                  const std::string& selection, FilterOptions options,
+                  std::shared_ptr<RecordBatch>* out) {
+    auto batch = RecordBatchFromJSON(schm, batch_json);
+    auto selection_array = ArrayFromJSON(boolean(), selection);
+    ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(batch, selection_array, options));
+    *out = filtered.record_batch();
+    return Status::OK();
+  }
+};
+
+// Filter a two-column record batch. Null-free filters must give the same
+// result under EMIT_NULL and DROP. With a null filter slot, DROP omits the row
+// while EMIT_NULL emits a row whose columns are all null.
+TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ auto batch_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+ // Filters without nulls: both null selection behaviors agree
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]");
+ this->AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json);
+ this->AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([
+ {"a": null, "b": "yo"},
+ {"a": 2, "b": "hello"}
+ ])");
+ }
+
+ // DROP: the row at the null filter slot is omitted
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->drop_, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"}
+ ])");
+
+ // EMIT_NULL: the null filter slot yields an all-null row
+ this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->emit_null_, R"([
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": null, "b": null}
+ ])");
+}
+
+// Filter fixture for chunked arrays. Values are given as one JSON string per
+// chunk. The filter is either a single boolean array (AssertFilter /
+// FilterWithArray) or itself chunked (AssertChunkedFilter /
+// FilterWithChunkedArray). Filtering uses Filter's default options.
+class TestFilterKernelWithChunkedArray : public TestFilterKernel<ChunkedArray> {
+ public:
+  void AssertFilter(const std::shared_ptr<DataType>& type,
+                    const std::vector<std::string>& values, const std::string& filter,
+                    const std::vector<std::string>& expected) {
+    std::shared_ptr<ChunkedArray> filtered;
+    ASSERT_OK(FilterWithArray(type, values, filter, &filtered));
+    ASSERT_OK(filtered->ValidateFull());
+    AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *filtered);
+  }
+
+  void AssertChunkedFilter(const std::shared_ptr<DataType>& type,
+                           const std::vector<std::string>& values,
+                           const std::vector<std::string>& filter,
+                           const std::vector<std::string>& expected) {
+    std::shared_ptr<ChunkedArray> filtered;
+    ASSERT_OK(FilterWithChunkedArray(type, values, filter, &filtered));
+    ASSERT_OK(filtered->ValidateFull());
+    AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *filtered);
+  }
+
+  Status FilterWithArray(const std::shared_ptr<DataType>& type,
+                         const std::vector<std::string>& values,
+                         const std::string& filter, std::shared_ptr<ChunkedArray>* out) {
+    auto chunked_values = ChunkedArrayFromJSON(type, values);
+    auto filter_array = ArrayFromJSON(boolean(), filter);
+    ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(chunked_values, filter_array));
+    *out = filtered.chunked_array();
+    return Status::OK();
+  }
+
+  Status FilterWithChunkedArray(const std::shared_ptr<DataType>& type,
+                                const std::vector<std::string>& values,
+                                const std::vector<std::string>& filter,
+                                std::shared_ptr<ChunkedArray>* out) {
+    auto chunked_values = ChunkedArrayFromJSON(type, values);
+    auto chunked_filter = ChunkedArrayFromJSON(boolean(), filter);
+    ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(chunked_values, chunked_filter));
+    *out = filtered.chunked_array();
+    return Status::OK();
+  }
+};
+
+TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) {
+  // Empty values
+  this->AssertFilter(int8(), {"[]"}, "[]", {});
+  this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {});
+
+  // Unchunked filter, then chunked filters whose boundaries match or differ
+  // from the values' chunking
+  this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[8]"});
+  this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"}, {"[8]"});
+  this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"}, {"[8]"});
+
+  // Filters longer than the values are rejected
+  std::shared_ptr<ChunkedArray> out;
+  ASSERT_RAISES(Invalid, this->FilterWithArray(int8(), {"[7]", "[8, 9]"},
+                                               "[0, 1, 0, 1, 1]", &out));
+  ASSERT_RAISES(Invalid, this->FilterWithChunkedArray(int8(), {"[7]", "[8, 9]"},
+                                                      {"[0, 1, 0]", "[1, 1]"}, &out));
+}
+
+// Filter fixture for tables. Tables are given as one JSON string per chunk.
+// The filter is a single boolean array (AssertFilter / FilterWithArray) or a
+// chunked one (AssertChunkedFilter / FilterWithChunkedArray).
+class TestFilterKernelWithTable : public TestFilterKernel<Table> {
+ public:
+  void AssertFilter(const std::shared_ptr<Schema>& schm,
+                    const std::vector<std::string>& table_json, const std::string& filter,
+                    FilterOptions options,
+                    const std::vector<std::string>& expected_table) {
+    std::shared_ptr<Table> filtered;
+    ASSERT_OK(FilterWithArray(schm, table_json, filter, options, &filtered));
+    ASSERT_OK(filtered->ValidateFull());
+    ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *filtered);
+  }
+
+  void AssertChunkedFilter(const std::shared_ptr<Schema>& schm,
+                           const std::vector<std::string>& table_json,
+                           const std::vector<std::string>& filter, FilterOptions options,
+                           const std::vector<std::string>& expected_table) {
+    std::shared_ptr<Table> filtered;
+    ASSERT_OK(FilterWithChunkedArray(schm, table_json, filter, options, &filtered));
+    ASSERT_OK(filtered->ValidateFull());
+    // Chunk layouts are not required to match here
+    AssertTablesEqual(*TableFromJSON(schm, expected_table), *filtered,
+                      /*same_chunk_layout=*/false);
+  }
+
+  Status FilterWithArray(const std::shared_ptr<Schema>& schm,
+                         const std::vector<std::string>& values,
+                         const std::string& filter, FilterOptions options,
+                         std::shared_ptr<Table>* out) {
+    auto table = TableFromJSON(schm, values);
+    auto filter_array = ArrayFromJSON(boolean(), filter);
+    ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(table, filter_array, options));
+    *out = filtered.table();
+    return Status::OK();
+  }
+
+  Status FilterWithChunkedArray(const std::shared_ptr<Schema>& schm,
+                                const std::vector<std::string>& values,
+                                const std::vector<std::string>& filter,
+                                FilterOptions options, std::shared_ptr<Table>* out) {
+    auto table = TableFromJSON(schm, values);
+    auto chunked_filter = ChunkedArrayFromJSON(boolean(), filter);
+    ARROW_ASSIGN_OR_RAISE(Datum filtered, Filter(table, chunked_filter, options));
+    *out = filtered.table();
+    return Status::OK();
+  }
+};
+
+// Filter a two-chunk table with unchunked and chunked filters. Null-free
+// filters must agree under EMIT_NULL and DROP. With a null filter slot,
+// EMIT_NULL emits an all-null row and DROP omits the row.
+TEST_F(TestFilterKernelWithTable, FilterTable) {
+ std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
+ auto schm = schema(fields);
+
+ std::vector<std::string> table_json = {R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])"};
+ // Filters without nulls: both null selection behaviors agree
+ for (auto options : {this->emit_null_, this->drop_}) {
+ this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {});
+ this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options, {});
+ this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json);
+ this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options,
+ table_json);
+ }
+
+ // EMIT_NULL: the null filter slot yields an all-null row
+ std::vector<std::string> expected_emit_null = {R"([
+ {"a": 1, "b": ""}
+ ])",
+ R"([
+ {"a": 2, "b": "hello"},
+ {"a": null, "b": null}
+ ])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->emit_null_,
+ expected_emit_null);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->emit_null_,
+ expected_emit_null);
+
+ // DROP: the row at the null filter slot is omitted
+ std::vector<std::string> expected_drop = {R"([{"a": 1, "b": ""}])",
+ R"([{"a": 2, "b": "hello"}])"};
+ this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->drop_, expected_drop);
+ this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->drop_,
+ expected_drop);
+}
+
+// ----------------------------------------------------------------------
+// Take tests
+
+// Takes `indices` from `values`, fully validates the output and compares it
+// with `expected`, printing a verbose diff on mismatch.
+void AssertTakeArrays(const std::shared_ptr<Array>& values,
+                      const std::shared_ptr<Array>& indices,
+                      const std::shared_ptr<Array>& expected) {
+  std::shared_ptr<Array> taken;
+  ASSERT_OK_AND_ASSIGN(taken, Take(*values, *indices));
+  ASSERT_OK(taken->ValidateFull());
+  AssertArraysEqual(*expected, *taken, /*verbose=*/true);
+}
+
+// Parses `values` as `type` and `indices` as `index_type`, then stores
+// Take(values, indices) in *out. A failing Take is reported via the Status.
+Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& values,
+                const std::shared_ptr<DataType>& index_type, const std::string& indices,
+                std::shared_ptr<Array>* out) {
+  auto values_array = ArrayFromJSON(type, values);
+  auto indices_array = ArrayFromJSON(index_type, indices);
+  return Take(*values_array, *indices_array).Value(out);
+}
+
+// Asserts that taking the JSON `indices` from the JSON `values` (both parsed
+// with `type` for values) equals `expected`. The check runs once with int8
+// and once with uint32 index arrays, and each result is fully validated.
+void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values,
+               const std::string& indices, const std::string& expected) {
+  const std::vector<std::shared_ptr<DataType>> index_types = {int8(), uint32()};
+  for (const auto& index_type : index_types) {
+    std::shared_ptr<Array> taken;
+    ASSERT_OK(TakeJSON(type, values, index_type, indices, &taken));
+    ASSERT_OK(taken->ValidateFull());
+    AssertArraysEqual(*ArrayFromJSON(type, expected), *taken, /*verbose=*/true);
+  }
+}
+
+// CheckTake specialized to null-typed values.
+void AssertTakeNull(const std::string& values, const std::string& indices,
+                    const std::string& expected) {
+  const auto& value_type = null();
+  CheckTake(value_type, values, indices, expected);
+}
+
+// CheckTake specialized to boolean values.
+void AssertTakeBoolean(const std::string& values, const std::string& indices,
+                       const std::string& expected) {
+  const auto& value_type = boolean();
+  CheckTake(value_type, values, indices, expected);
+}
+
+// Element-wise check of a Take result: slot i of `result` must be null when
+// either indices[i] or values[indices[i]] is null, and must otherwise view
+// the same value as values[indices[i]].
+template <typename ValuesType, typename IndexType>
+void ValidateTakeImpl(const std::shared_ptr<Array>& values,
+                      const std::shared_ptr<Array>& indices,
+                      const std::shared_ptr<Array>& result) {
+  using ValuesArrayType = typename TypeTraits<ValuesType>::ArrayType;
+  using IndexArrayType = typename TypeTraits<IndexType>::ArrayType;
+  auto typed_indices = checked_pointer_cast<IndexArrayType>(indices);
+  auto typed_values = checked_pointer_cast<ValuesArrayType>(values);
+  auto typed_result = checked_pointer_cast<ValuesArrayType>(result);
+  for (int64_t i = 0; i < indices->length(); ++i) {
+    // Null index: the slot's index value is not meaningful, so don't read it.
+    if (typed_indices->IsNull(i)) {
+      ASSERT_TRUE(result->IsNull(i)) << i;
+      continue;
+    }
+    const auto source = typed_indices->Value(i);
+    if (typed_values->IsNull(source)) {
+      ASSERT_TRUE(result->IsNull(i)) << i;
+      continue;
+    }
+    ASSERT_FALSE(result->IsNull(i)) << i;
+    ASSERT_EQ(typed_result->GetView(i), typed_values->GetView(source)) << i;
+  }
+}
+
+// Runs Take(values, indices), validates the output and its length, then
+// dispatches on the runtime index type to the statically typed element check.
+template <typename ValuesType>
+void ValidateTake(const std::shared_ptr<Array>& values,
+                  const std::shared_ptr<Array>& indices) {
+  ASSERT_OK_AND_ASSIGN(Datum result, Take(values, indices));
+  std::shared_ptr<Array> taken_array = result.make_array();
+  ASSERT_OK(taken_array->ValidateFull());
+  ASSERT_EQ(indices->length(), taken_array->length());
+  switch (indices->type_id()) {
+    case Type::INT8:
+      return ValidateTakeImpl<ValuesType, Int8Type>(values, indices, taken_array);
+    case Type::INT16:
+      return ValidateTakeImpl<ValuesType, Int16Type>(values, indices, taken_array);
+    case Type::INT32:
+      return ValidateTakeImpl<ValuesType, Int32Type>(values, indices, taken_array);
+    case Type::INT64:
+      return ValidateTakeImpl<ValuesType, Int64Type>(values, indices, taken_array);
+    case Type::UINT8:
+      return ValidateTakeImpl<ValuesType, UInt8Type>(values, indices, taken_array);
+    case Type::UINT16:
+      return ValidateTakeImpl<ValuesType, UInt16Type>(values, indices, taken_array);
+    case Type::UINT32:
+      return ValidateTakeImpl<ValuesType, UInt32Type>(values, indices, taken_array);
+    case Type::UINT64:
+      return ValidateTakeImpl<ValuesType, UInt64Type>(values, indices, taken_array);
+    default:
+      FAIL() << "Invalid index type";
+  }
+}
+
+// Largest index usable for a values array of `values_length` elements
+// (values_length - 1), clamped to the maximum of the index C type T.
+template <typename T>
+T GetMaxIndex(int64_t values_length) {
+  const int64_t last_position = values_length - 1;
+  const int64_t type_limit = static_cast<int64_t>(std::numeric_limits<T>::max());
+  return static_cast<T>(last_position > type_limit ? type_limit : last_position);
+}
+
+// uint64_t gets its own definition: its maximum is not representable as
+// int64_t (the conversion in the primary template would produce a negative
+// limit on common platforms), and every non-negative int64_t fits anyway.
+template <>
+uint64_t GetMaxIndex(int64_t values_length) {
+  return static_cast<uint64_t>(values_length - 1);
+}
+
+// Draws two random IndexType index arrays of `indices_length` within
+// [0, max valid index]: one with `null_probability` nulls and one without.
+// Both are validated against `values`, first whole and then (when at least
+// two long) sliced to drop the first and last element.
+template <typename ValuesType, typename IndexType>
+void CheckTakeRandom(const std::shared_ptr<Array>& values, int64_t indices_length,
+                     double null_probability, random::RandomArrayGenerator* rand) {
+  using IndexCType = typename IndexType::c_type;
+  const IndexCType lowest = static_cast<IndexCType>(0);
+  const IndexCType highest = GetMaxIndex<IndexCType>(values->length());
+  // Draw order matters for reproducibility: nullable indices first.
+  std::shared_ptr<Array> nullable_indices =
+      rand->Numeric<IndexType>(indices_length, lowest, highest, null_probability);
+  std::shared_ptr<Array> dense_indices = rand->Numeric<IndexType>(
+      indices_length, lowest, highest, /*null_probability=*/0.0);
+  ValidateTake<ValuesType>(values, nullable_indices);
+  ValidateTake<ValuesType>(values, dense_indices);
+  if (indices_length < 2) {
+    return;
+  }
+  // Repeat with index arrays that carry a non-zero offset.
+  const int64_t inner_length = indices_length - 2;
+  ValidateTake<ValuesType>(values, nullable_indices->Slice(1, inner_length));
+  ValidateTake<ValuesType>(values, dense_indices->Slice(1, inner_length));
+}
+
+// Runs CheckTakeRandom for every integer index type over a grid of value
+// lengths, index lengths and null probabilities. Values are produced by
+// generate_values(length, null_probability, &rand). When length > 2, a
+// sliced (offset) values array is also checked with uint64 indices.
+template <typename ValuesType, typename DataGenerator>
+void DoRandomTakeTests(DataGenerator&& generate_values) {
+  auto rand = random::RandomArrayGenerator(kRandomSeed);
+  const int64_t value_lengths[] = {1, 16, 59};
+  const int64_t index_lengths[] = {0, 5, 30};
+  const double null_probabilities[] = {0.0, 0.05, 0.25, 0.95, 1.0};
+  for (const int64_t length : value_lengths) {
+    for (const int64_t indices_length : index_lengths) {
+      for (const double p : null_probabilities) {
+        auto values = generate_values(length, p, &rand);
+        CheckTakeRandom<ValuesType, Int8Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, Int16Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, Int32Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, Int64Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, UInt8Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, UInt16Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, UInt32Type>(values, indices_length, p, &rand);
+        CheckTakeRandom<ValuesType, UInt64Type>(values, indices_length, p, &rand);
+        if (length <= 2) {
+          continue;
+        }
+        // Values array with a non-zero offset.
+        auto sliced_values = values->Slice(1, length - 2);
+        CheckTakeRandom<ValuesType, UInt64Type>(sliced_values, indices_length, p, &rand);
+      }
+    }
+  }
+}
+
+// Base fixture for the Take tests. The class body is empty: the type
+// parameter only distinguishes the per-type fixtures derived from it.
+template <typename ArrowType>
+class TestTakeKernel : public ::testing::Test {};
+
+TEST(TestTakeKernel, TakeNull) {
+  AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
+
+  // Out-of-range indices on null-typed values must be rejected on both sides:
+  // past the end and negative. (The negative case previously used boolean(),
+  // so the null-type kernel's bounds check for it was never exercised; the
+  // boolean kernel is covered by TakeBoolean.)
+  std::shared_ptr<Array> arr;
+  ASSERT_RAISES(IndexError,
+                TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr));
+  ASSERT_RAISES(IndexError,
+                TakeJSON(null(), "[null, null, null]", int8(), "[0, -1, 0]", &arr));
+}
+
+TEST(TestTakeKernel, InvalidIndexType) {
+  // A floating-point index array is expected to be rejected as NotImplemented.
+  std::shared_ptr<Array> taken;
+  Status status =
+      TakeJSON(null(), "[null, null, null]", float32(), "[0.0, 1.0, 0.1]", &taken);
+  ASSERT_RAISES(NotImplemented, status);
+}
+
+TEST(TestTakeKernel, TakeBoolean) {
+  const std::string values = "[true, false, true]";
+
+  // Empty selection, plain selection, null values, and null indices.
+  AssertTakeBoolean("[7, 8, 9]", "[]", "[]");
+  AssertTakeBoolean(values, "[0, 1, 0]", "[true, false, true]");
+  AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
+  AssertTakeBoolean(values, "[null, 1, 0]", "[null, false, true]");
+
+  // Indices past the end or negative are rejected.
+  std::shared_ptr<Array> arr;
+  ASSERT_RAISES(IndexError, TakeJSON(boolean(), values, int8(), "[0, 9, 0]", &arr));
+  ASSERT_RAISES(IndexError, TakeJSON(boolean(), values, int8(), "[0, -1, 0]", &arr));
+}
+
+// Typed fixture for Take on numeric value types; JSON values are parsed with
+// the fixture's ArrowType.
+template <typename ArrowType>
+class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {
+ protected:
+  std::shared_ptr<DataType> type_singleton() {
+    return TypeTraits<ArrowType>::type_singleton();
+  }
+
+  void AssertTake(const std::string& values, const std::string& indices,
+                  const std::string& expected) {
+    const std::shared_ptr<DataType> value_type = type_singleton();
+    CheckTake(value_type, values, indices, expected);
+  }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes);
+TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
+  const std::string values = "[7, 8, 9]";
+
+  // Empty selections, repeats, null values and null indices.
+  this->AssertTake(values, "[]", "[]");
+  this->AssertTake(values, "[0, 1, 0]", "[7, 8, 7]");
+  this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
+  this->AssertTake(values, "[null, 1, 0]", "[null, 8, 7]");
+  this->AssertTake("[null, 8, 9]", "[]", "[]");
+  this->AssertTake(values, "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]");
+
+  // Indices past the end or negative are rejected.
+  auto value_type = this->type_singleton();
+  std::shared_ptr<Array> arr;
+  ASSERT_RAISES(IndexError, TakeJSON(value_type, values, int8(), "[0, 9, 0]", &arr));
+  ASSERT_RAISES(IndexError, TakeJSON(value_type, values, int8(), "[0, -1, 0]", &arr));
+}
+
+// Randomized Take tests: each feeds a value generator to DoRandomTakeTests.
+TYPED_TEST(TestTakeKernelWithNumeric, TakeRandomNumeric) {
+  DoRandomTakeTests<TypeParam>(RandomNumeric<TypeParam>);
+}
+
+TEST(TestTakeKernel, TakeBooleanRandom) {
+  DoRandomTakeTests<BooleanType>(RandomBoolean);
+}
+
+TEST(TestTakeKernelString, Random) {
+  // Covers both StringType and LargeStringType values.
+  DoRandomTakeTests<StringType>(RandomString);
+  DoRandomTakeTests<LargeStringType>(RandomLargeString);
+}
+
+TEST(TestTakeKernelFixedSizeBinary, Random) {
+  DoRandomTakeTests<FixedSizeBinaryType>(RandomFixedSizeBinary);
+}
+
+// Typed fixture for Take on string-like value types, including dictionary
+// arrays with int8 codes and a dictionary of the fixture's value type.
+template <typename TypeClass>
+class TestTakeKernelWithString : public TestTakeKernel<TypeClass> {
+ public:
+  std::shared_ptr<DataType> value_type() {
+    return TypeTraits<TypeClass>::type_singleton();
+  }
+
+  void AssertTake(const std::string& values, const std::string& indices,
+                  const std::string& expected) {
+    CheckTake(value_type(), values, indices, expected);
+  }
+
+  // Builds dictionary(int8, value_type) arrays that share one dictionary:
+  // the input uses `dictionary_indices` as codes, the expected result uses
+  // `expected_indices`. Then checks that taking `indices` maps one onto the other.
+  void AssertTakeDictionary(const std::string& dictionary_values,
+                            const std::string& dictionary_indices,
+                            const std::string& indices,
+                            const std::string& expected_indices) {
+    auto shared_dict = ArrayFromJSON(value_type(), dictionary_values);
+    auto dict_type = dictionary(int8(), value_type());
+    auto input_codes = ArrayFromJSON(int8(), dictionary_indices);
+    ASSERT_OK_AND_ASSIGN(auto values,
+                         DictionaryArray::FromArrays(dict_type, input_codes, shared_dict));
+    auto expected_codes = ArrayFromJSON(int8(), expected_indices);
+    ASSERT_OK_AND_ASSIGN(
+        auto expected,
+        DictionaryArray::FromArrays(dict_type, expected_codes, shared_dict));
+    auto take_indices = ArrayFromJSON(int8(), indices);
+    AssertTakeArrays(values, take_indices, expected);
+  }
+};
+
+TYPED_TEST_SUITE(TestTakeKernelWithString, TestingStringTypes);
+
+TYPED_TEST(TestTakeKernelWithString, TakeString) {
+  const std::string abc = R"(["a", "b", "c"])";
+
+  // Plain selection, null values and null indices.
+  this->AssertTake(abc, "[0, 1, 0]", R"(["a", "b", "a"])");
+  this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
+  this->AssertTake(abc, "[null, 1, 0]", R"([null, "b", "a"])");
+
+  // Indices at or past the end are rejected, for int8 and int64 indices.
+  std::shared_ptr<DataType> type = this->value_type();
+  std::shared_ptr<Array> arr;
+  ASSERT_RAISES(IndexError, TakeJSON(type, abc, int8(), "[0, 9, 0]", &arr));
+  ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])", int64(),
+                                     "[2, 5]", &arr));
+}
+
+TYPED_TEST(TestTakeKernelWithString, TakeDictionary) {
+  // Arguments: dictionary, input codes, take indices, expected codes.
+  auto dict = R"(["a", "b", "c", "d", "e"])";
+  this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
+  this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]");
+  this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]");
+}
+
+// Fixture for Take on fixed_size_binary(3) values.
+class TestTakeKernelFSB : public TestTakeKernel<FixedSizeBinaryType> {
+ public:
+  std::shared_ptr<DataType> value_type() { return fixed_size_binary(3); }
+
+  void AssertTake(const std::string& values, const std::string& indices,
+                  const std::string& expected) {
+    const std::shared_ptr<DataType> fsb_type = value_type();
+    CheckTake(fsb_type, values, indices, expected);
+  }
+};
+
+TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) {
+  const std::string abc = R"(["aaa", "bbb", "ccc"])";
+
+  // Plain selection, null values and null indices.
+  this->AssertTake(abc, "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])");
+  this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]");
+  this->AssertTake(abc, "[null, 1, 0]", R"([null, "bbb", "aaa"])");
+
+  // Indices at or past the end are rejected, for int8 and int64 indices.
+  std::shared_ptr<DataType> type = this->value_type();
+  std::shared_ptr<Array> arr;
+  ASSERT_RAISES(IndexError, TakeJSON(type, abc, int8(), "[0, 9, 0]", &arr));
+  ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd", "eee"])",
+                                     int64(), "[2, 5]", &arr));
+}
+
+class TestTakeKernelWithList : public TestTakeKernel<ListType> {};
+
+TEST_F(TestTakeKernelWithList, TakeListInt32) {
+  auto type = list(int32());
+  // Rows: empty list, two values, null, one value.
+  std::string list_json = "[[], [1,2], null, [3]]";
+  CheckTake(type, list_json, "[]", "[]");
+  CheckTake(type, list_json, "[3, 2, 1]", "[[3], null, [1,2]]");
+  CheckTake(type, list_json, "[null, 3, 0]", "[null, [3], []]");
+  CheckTake(type, list_json, "[null, null]", "[null, null]");
+  CheckTake(type, list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
+  // Identity selection reproduces the input.
+  CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
+  CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
+            "[[], [], [], [], [], [], [1, 2]]");
+}
+
+TEST_F(TestTakeKernelWithList, TakeListListInt32) {
+  auto type = list(list(int32()));
+  // Rows: empty, three inner lists, null, two inner lists (one null).
+  std::string list_json = R"([
+ [],
+ [[1], [2, null, 2], []],
+ null,
+ [[3, null], null]
+ ])";
+  CheckTake(type, list_json, "[]", "[]");
+  CheckTake(type, list_json, "[3, 2, 1]", R"([
+ [[3, null], null],
+ null,
+ [[1], [2, null, 2], []]
+ ])");
+  CheckTake(type, list_json, "[null, 3, 0]", R"([
+ null,
+ [[3, null], null],
+ []
+ ])");
+  CheckTake(type, list_json, "[null, null]", "[null, null]");
+  CheckTake(type, list_json, "[3, 0, 0, 3]",
+            "[[[3, null], null], [], [], [[3, null], null]]");
+  // Identity selection reproduces the input.
+  CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
+  CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
+            "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
+}
+
+class TestTakeKernelWithLargeList : public TestTakeKernel<LargeListType> {};
+
+TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) {
+  auto type = large_list(int32());
+  std::string list_json = "[[], [1,2], null, [3]]";
+  CheckTake(type, list_json, "[]", "[]");
+  // Null index, then rows 1, 2 (a null list) and 0 (an empty list).
+  CheckTake(type, list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
+}
+
+class TestTakeKernelWithFixedSizeList : public TestTakeKernel<FixedSizeListType> {};
+
+TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
+  auto type = fixed_size_list(int32(), 3);
+  // Rows: null, then three 3-element lists (some with null children).
+  std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+  CheckTake(type, list_json, "[]", "[]");
+  CheckTake(type, list_json, "[3, 2, 1]", "[[7, 8, null], [4, 5, 6], [1, null, 3]]");
+  CheckTake(type, list_json, "[null, 2, 0]", "[null, [4, 5, 6], null]");
+  CheckTake(type, list_json, "[null, null]", "[null, null]");
+  CheckTake(type, list_json, "[3, 0, 0, 3]", "[[7, 8, null], null, null, [7, 8, null]]");
+  // Identity selection reproduces the input.
+  CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
+  CheckTake(
+      type, list_json, "[2, 2, 2, 2, 2, 2, 1]",
+      "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]");
+}
+
+class TestTakeKernelWithMap : public TestTakeKernel<MapType> {};
+
+TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
+  auto type = map(utf8(), int32());
+  // Rows: a two-entry map (one null item), null, a one-entry map, an empty map.
+  std::string map_json = R"([
+ [["joe", 0], ["mark", null]],
+ null,
+ [["cap", 8]],
+ []
+ ])";
+  CheckTake(type, map_json, "[]", "[]");
+  CheckTake(type, map_json, "[3, 1, 3, 1, 3]", "[[], null, [], null, []]");
+  CheckTake(type, map_json, "[2, 1, null]", R"([
+ [["cap", 8]],
+ null,
+ null
+ ])");
+  CheckTake(type, map_json, "[2, 1, 0]", R"([
+ [["cap", 8]],
+ null,
+ [["joe", 0], ["mark", null]]
+ ])");
+  // Identity selection reproduces the input.
+  CheckTake(type, map_json, "[0, 1, 2, 3]", map_json);
+  CheckTake(type, map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ [["joe", 0], ["mark", null]],
+ []
+ ])");
+}
+
+class TestTakeKernelWithStruct : public TestTakeKernel<StructType> {};
+
+TEST_F(TestTakeKernelWithStruct, TakeStruct) {
+  auto struct_type = struct_({field("a", int32()), field("b", utf8())});
+  // Rows: a null struct followed by three populated structs.
+  auto rows_json = R"([
+ null,
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+  CheckTake(struct_type, rows_json, "[]", "[]");
+  CheckTake(struct_type, rows_json, "[3, 1, 3, 1, 3]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"}
+ ])");
+  CheckTake(struct_type, rows_json, "[3, 1, 0]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ null
+ ])");
+  // Identity selection reproduces the input.
+  CheckTake(struct_type, rows_json, "[0, 1, 2, 3]", rows_json);
+  CheckTake(struct_type, rows_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ null,
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+class TestTakeKernelWithUnion : public TestTakeKernel<UnionType> {};
+
+// TODO: Restore Union take functionality
+TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) {
+  // Repeated for each union type factory returned by UnionTypeFactories().
+  for (auto make_union : UnionTypeFactories()) {
+    auto union_type = make_union({field("a", int32()), field("b", utf8())}, {2, 5});
+    // Each row is [type code, value]: code 2 selects "a", code 5 selects "b".
+    auto union_json = R"([
+ null,
+ [2, 222],
+ [5, "hello"],
+ [5, "eh"],
+ null,
+ [2, 111]
+ ])";
+    CheckTake(union_type, union_json, "[]", "[]");
+    CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
+ [5, "eh"],
+ [2, 222],
+ [5, "eh"],
+ [2, 222],
+ [5, "eh"]
+ ])");
+    CheckTake(union_type, union_json, "[4, 2, 1]", R"([
+ null,
+ [5, "hello"],
+ [2, 222]
+ ])");
+    // Identity selection reproduces the input.
+    CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5]", union_json);
+    CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ null,
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"],
+ [5, "hello"]
+ ])");
+  }
+}
+
+// Treats Take on int16 arrays as permutation composition: DoTake(a, p)[i] is
+// a[p[i]], so repeated takes compute powers of a permutation.
+class TestPermutationsWithTake : public TestBase {
+ protected:
+  // Takes `indices` from `values`, fully validates the result and downcasts
+  // it to Int16Array into *out.
+  void DoTake(const Int16Array& values, const Int16Array& indices,
+              std::shared_ptr<Int16Array>* out) {
+    ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> boxed_out, Take(values, indices));
+    ASSERT_OK(boxed_out->ValidateFull());
+    *out = checked_pointer_cast<Int16Array>(std::move(boxed_out));
+  }
+
+  // Value-returning DoTake; yields nullptr if an assertion inside failed.
+  std::shared_ptr<Int16Array> DoTake(const Int16Array& values,
+                                     const Int16Array& indices) {
+    std::shared_ptr<Int16Array> out;
+    DoTake(values, indices, &out);
+    return out;
+  }
+
+  // Returns `array` composed with itself n times, using exponentiation by
+  // squaring; n == 0 yields the identity of the same length.
+  std::shared_ptr<Int16Array> DoTakeN(uint64_t n, std::shared_ptr<Int16Array> array) {
+    auto power_of_2 = array;
+    array = Identity(array->length());
+    while (n != 0) {
+      if (n & 1) {
+        array = DoTake(*array, *power_of_2);
+      }
+      power_of_2 = DoTake(*power_of_2, *power_of_2);
+      n >>= 1;
+    }
+    return array;
+  }
+
+  // Writes to *shuffled a new Int16Array holding a shuffled copy of `array`'s
+  // values. The validity bitmap is not carried over.
+  template <typename Rng>
+  void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>* shuffled) {
+    // Copy starting at the array's offset so that sliced inputs shuffle their
+    // own values rather than the beginning of the parent buffer.
+    const int64_t byte_offset = array.offset() * static_cast<int64_t>(sizeof(int16_t));
+    const int64_t byte_length = array.length() * static_cast<int64_t>(sizeof(int16_t));
+    ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(byte_offset, byte_length));
+    auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data());
+    std::shuffle(mutable_data, mutable_data + array.length(), gen);
+    shuffled->reset(new Int16Array(array.length(), data));
+  }
+
+  template <typename Rng>
+  std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) {
+    std::shared_ptr<Int16Array> out;
+    Shuffle(array, gen, &out);
+    return out;
+  }
+
+  // Builds [0, 1, ..., length - 1] as an Int16Array into *identity.
+  void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) {
+    Int16Builder identity_builder;
+    ASSERT_OK(identity_builder.Resize(length));
+    // Count with int64_t: an int16_t counter compared against `length` never
+    // terminates once length exceeds INT16_MAX.
+    for (int64_t i = 0; i < length; ++i) {
+      identity_builder.UnsafeAppend(static_cast<int16_t>(i));
+    }
+    ASSERT_OK(identity_builder.Finish(identity));
+  }
+
+  std::shared_ptr<Int16Array> Identity(int64_t length) {
+    std::shared_ptr<Int16Array> out;
+    Identity(length, &out);
+    return out;
+  }
+
+  // Computes the inverse of `permutation` as permutation^(m - 1), where m is a
+  // common multiple (not necessarily the least) of its cycle lengths, so that
+  // permutation^m is the identity. Returns nullptr if m would overflow uint64_t.
+  std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>& permutation) {
+    auto length = static_cast<int16_t>(permutation->length());
+
+    // cycle_lengths[k] records whether permutation^k has a fixed point, which
+    // holds for every k that is a multiple of some cycle length.
+    std::vector<bool> cycle_lengths(length + 1, false);
+    auto permutation_to_the_i = permutation;
+    for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) {
+      cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i);
+      permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i);
+    }
+
+    // Multiply in every marked length that does not already divide the
+    // running product; the result is divisible by every cycle length.
+    uint64_t cycle_to_identity_length = 1;
+    for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) {
+      if (!cycle_lengths[cycle_length]) {
+        continue;
+      }
+      if (cycle_to_identity_length % cycle_length == 0) {
+        continue;
+      }
+      if (cycle_to_identity_length >
+          std::numeric_limits<uint64_t>::max() / cycle_length) {
+        // overflow, can't compute Inverse
+        return nullptr;
+      }
+      cycle_to_identity_length *= cycle_length;
+    }
+
+    return DoTakeN(cycle_to_identity_length - 1, permutation);
+  }
+
+  // True if some position i maps to itself (permutation[i] == i).
+  bool HasTrivialCycle(const Int16Array& permutation) {
+    for (int64_t i = 0; i < permutation.length(); ++i) {
+      if (permutation.Value(i) == static_cast<int16_t>(i)) {
+        return true;
+      }
+    }
+    return false;
+  }
+};
+
+TEST_F(TestPermutationsWithTake, InvertPermutation) {
+  const std::vector<random::SeedType> seeds({0, kRandomSeed, kRandomSeed * 2 - 1});
+  const int16_t kLengthLimit = 1 << 10;
+  for (const random::SeedType seed : seeds) {
+    std::default_random_engine gen(seed);
+    for (int16_t length = 0; length < kLengthLimit; ++length) {
+      auto identity = Identity(length);
+      auto permutation = Shuffle(*identity, gen);
+      auto inverse = Inverse(permutation);
+      if (inverse == nullptr) {
+        // Inverse could not be computed without overflow; stop this seed.
+        break;
+      }
+      // inverse[permutation[i]] must equal i for every i.
+      ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity));
+    }
+  }
+}
+
+// Fixture for Take on record batches parsed from JSON.
+class TestTakeKernelWithRecordBatch : public TestTakeKernel<RecordBatch> {
+ public:
+  // Checks that taking `indices` from the batch parsed from `batch_json`
+  // yields `expected_batch`, once with int8 and once with uint32 indices.
+  void AssertTake(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                  const std::string& indices, const std::string& expected_batch) {
+    const std::vector<std::shared_ptr<DataType>> index_types = {int8(), uint32()};
+    for (const auto& index_type : index_types) {
+      std::shared_ptr<RecordBatch> taken;
+      ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &taken));
+      ASSERT_OK(taken->ValidateFull());
+      ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *taken);
+    }
+  }
+
+  // Parses the batch and the `index_type` indices from JSON and stores
+  // Take(batch, indices) in *out; Take errors are returned as the Status.
+  Status TakeJSON(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
+                  const std::shared_ptr<DataType>& index_type, const std::string& indices,
+                  std::shared_ptr<RecordBatch>* out) {
+    auto batch = RecordBatchFromJSON(schm, batch_json);
+    auto take_indices = ArrayFromJSON(index_type, indices);
+    ARROW_ASSIGN_OR_RAISE(Datum taken, Take(Datum(batch), Datum(take_indices)));
+    *out = taken.record_batch();
+    return Status::OK();
+  }
+};
+
+TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
+  std::vector<std::shared_ptr<Field>> columns = {field("a", int32()), field("b", utf8())};
+  auto schm = schema(columns);
+
+  // Four rows; the first has a null "a".
+  auto batch_json = R"([
+ {"a": null, "b": "yo"},
+ {"a": 1, "b": ""},
+ {"a": 2, "b": "hello"},
+ {"a": 4, "b": "eh"}
+ ])";
+  this->AssertTake(schm, batch_json, "[]", "[]");
+  this->AssertTake(schm, batch_json, "[3, 1, 3, 1, 3]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": 4, "b": "eh"}
+ ])");
+  this->AssertTake(schm, batch_json, "[3, 1, 0]", R"([
+ {"a": 4, "b": "eh"},
+ {"a": 1, "b": ""},
+ {"a": null, "b": "yo"}
+ ])");
+  // Identity selection reproduces the input.
+  this->AssertTake(schm, batch_json, "[0, 1, 2, 3]", batch_json);
+  this->AssertTake(schm, batch_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+ {"a": null, "b": "yo"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"},
+ {"a": 2, "b": "hello"}
+ ])");
+}
+
+// Fixture for Take on chunked arrays, with int8 indices given either as a
+// single array or as a chunked array.
+class TestTakeKernelWithChunkedArray : public TestTakeKernel<ChunkedArray> {
+ public:
+  // Take with a plain int8 index array; checks the result against `expected`.
+  void AssertTake(const std::shared_ptr<DataType>& type,
+                  const std::vector<std::string>& values, const std::string& indices,
+                  const std::vector<std::string>& expected) {
+    std::shared_ptr<ChunkedArray> taken;
+    ASSERT_OK(this->TakeWithArray(type, values, indices, &taken));
+    ASSERT_OK(taken->ValidateFull());
+    AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *taken);
+  }
+
+  // Take with a chunked int8 index array; checks the result against `expected`.
+  void AssertChunkedTake(const std::shared_ptr<DataType>& type,
+                         const std::vector<std::string>& values,
+                         const std::vector<std::string>& indices,
+                         const std::vector<std::string>& expected) {
+    std::shared_ptr<ChunkedArray> taken;
+    ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &taken));
+    ASSERT_OK(taken->ValidateFull());
+    AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *taken);
+  }
+
+  Status TakeWithArray(const std::shared_ptr<DataType>& type,
+                       const std::vector<std::string>& values, const std::string& indices,
+                       std::shared_ptr<ChunkedArray>* out) {
+    auto chunked_values = ChunkedArrayFromJSON(type, values);
+    auto take_indices = ArrayFromJSON(int8(), indices);
+    ARROW_ASSIGN_OR_RAISE(Datum taken, Take(chunked_values, take_indices));
+    *out = taken.chunked_array();
+    return Status::OK();
+  }
+
+  Status TakeWithChunkedArray(const std::shared_ptr<DataType>& type,
+                              const std::vector<std::string>& values,
+                              const std::vector<std::string>& indices,
+                              std::shared_ptr<ChunkedArray>* out) {
+    auto chunked_values = ChunkedArrayFromJSON(type, values);
+    auto take_indices = ChunkedArrayFromJSON(int8(), indices);
+    ARROW_ASSIGN_OR_RAISE(Datum taken, Take(chunked_values, take_indices));
+    *out = taken.chunked_array();
+    return Status::OK();
+  }
+};
+
+TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) {
+  // Empty values and empty indices.
+  this->AssertTake(int8(), {"[]"}, "[]", {"[]"});
+  this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"});
+
+  // Values split over two chunks. The expected result is written as one chunk
+  // for plain indices and as one chunk per index chunk for chunked indices.
+  const std::vector<std::string> values = {"[7]", "[8, 9]"};
+  this->AssertTake(int8(), values, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"});
+  this->AssertChunkedTake(int8(), values, {"[0, 1, 0]", "[]", "[2]"},
+                          {"[7, 8, 7]", "[]", "[9]"});
+  this->AssertTake(int8(), values, "[2, 1]", {"[9, 8]"});
+
+  // Index 5 is past the end of the three values.
+  std::shared_ptr<ChunkedArray> arr;
+  ASSERT_RAISES(IndexError, this->TakeWithArray(int8(), values, "[0, 5]", &arr));
+  ASSERT_RAISES(IndexError,
+                this->TakeWithChunkedArray(int8(), values, {"[0, 1, 0]", "[5, 1]"}, &arr));
+}
+
+// Fixture for Take on tables parsed from JSON batches, with int8 indices given
+// either as a single array or as a chunked array.
+class TestTakeKernelWithTable : public TestTakeKernel<Table> {
+ public:
+  // Takes the rows named by the int8 `indices` array from the table parsed
+  // from `table_json` and compares with the table parsed from `expected_table`.
+  // (The index parameter was previously misnamed `filter`.)
+  void AssertTake(const std::shared_ptr<Schema>& schm,
+                  const std::vector<std::string>& table_json, const std::string& indices,
+                  const std::vector<std::string>& expected_table) {
+    std::shared_ptr<Table> actual;
+
+    ASSERT_OK(this->TakeWithArray(schm, table_json, indices, &actual));
+    ASSERT_OK(actual->ValidateFull());
+    ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+  }
+
+  // Same as AssertTake, but `indices` holds one JSON string per index chunk.
+  void AssertChunkedTake(const std::shared_ptr<Schema>& schm,
+                         const std::vector<std::string>& table_json,
+                         const std::vector<std::string>& indices,
+                         const std::vector<std::string>& expected_table) {
+    std::shared_ptr<Table> actual;
+
+    ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, indices, &actual));
+    ASSERT_OK(actual->ValidateFull());
+    ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
+  }
+
+  // Stores Take(table, int8 indices) in *out; Take errors are returned.
+  Status TakeWithArray(const std::shared_ptr<Schema>& schm,
+                       const std::vector<std::string>& values, const std::string& indices,
+                       std::shared_ptr<Table>* out) {
+    ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
+                                             Datum(ArrayFromJSON(int8(), indices))));
+    *out = result.table();
+    return Status::OK();
+  }
+
+  // Stores Take(table, chunked int8 indices) in *out; Take errors are returned.
+  Status TakeWithChunkedArray(const std::shared_ptr<Schema>& schm,
+                              const std::vector<std::string>& values,
+                              const std::vector<std::string>& indices,
+                              std::shared_ptr<Table>* out) {
+    ARROW_ASSIGN_OR_RAISE(Datum result,
+                          Take(Datum(TableFromJSON(schm, values)),
+                               Datum(ChunkedArrayFromJSON(int8(), indices))));
+    *out = result.table();
+    return Status::OK();
+  }
+};
+
+TEST_F(TestTakeKernelWithTable, TakeTable) {
+  std::vector<std::shared_ptr<Field>> columns = {field("a", int32()), field("b", utf8())};
+  auto schm = schema(columns);
+
+  // Two JSON batches of two rows each.
+  std::vector<std::string> table_json = {
+      "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]",
+      "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"};
+
+  // No indices: no rows.
+  this->AssertTake(schm, table_json, "[]", {"[]"});
+
+  // Reordering rows across the batch boundary.
+  const std::vector<std::string> expected_310 = {
+      "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"};
+  this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310);
+
+  // Chunked indices selecting every row in order reproduce the input.
+  this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json);
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_take.cc b/cpp/src/arrow/compute/kernels/vector_take.cc
deleted file mode 100644
index 536797c..0000000
--- a/cpp/src/arrow/compute/kernels/vector_take.cc
+++ /dev/null
@@ -1,989 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <algorithm>
-#include <limits>
-#include <type_traits>
-
-#include "arrow/array/array_base.h"
-#include "arrow/array/builder_primitive.h"
-#include "arrow/array/concatenate.h"
-#include "arrow/array/data.h"
-#include "arrow/buffer_builder.h"
-#include "arrow/compute/api_vector.h"
-#include "arrow/compute/kernels/common.h"
-#include "arrow/record_batch.h"
-#include "arrow/result.h"
-#include "arrow/util/bit_block_counter.h"
-#include "arrow/util/bitmap_reader.h"
-#include "arrow/util/int_util.h"
-
-namespace arrow {
-
-using internal::BitBlockCount;
-using internal::BitmapReader;
-using internal::GetArrayView;
-using internal::IndexBoundsCheck;
-using internal::OptionalBitBlockCounter;
-using internal::OptionalBitIndexer;
-
-namespace compute {
-namespace internal {
-
-using TakeState = OptionsWrapper<TakeOptions>;
-
-// ----------------------------------------------------------------------
-// Optimized take for primitive types, from boolean up to 1/2/4/8-byte C types.
-// A single implementation is shared per byte width, and code is generated only
-// for unsigned integer indices: once boundschecking has ruled out negative
-// indices, signed integers can safely be reinterpreted as unsigned.
-
-/// Unboxed description of the two take inputs (values and indices) that is
-/// handed to the width-specialized primitive take implementations.
-struct PrimitiveTakeArgs {
- // Values array: data pointer, optional validity bitmap and metadata.
- const uint8_t* values;
- const uint8_t* values_bitmap{nullptr};
- int values_bit_width;
- int64_t values_length;
- int64_t values_offset;
- int64_t values_null_count;
-
- // Indices array: data pointer, optional validity bitmap and metadata.
- const uint8_t* indices;
- const uint8_t* indices_bitmap{nullptr};
- int indices_bit_width;
- int64_t indices_length;
- int64_t indices_offset;
- int64_t indices_null_count;
-};
-
-// Unboxes the kernel inputs (batch[0] = values, batch[1] = indices) into a
-// PrimitiveTakeArgs once, so this bookkeeping is compiled a single time rather
-// than once per generated take kernel.
-PrimitiveTakeArgs GetPrimitiveTakeArgs(const ExecBatch& batch) {
- const ArrayData& values_data = *batch[0].array();
- const ArrayData& indices_data = *batch[1].array();
- PrimitiveTakeArgs args;
-
- // Values. For bit-packed (1-bit) values the data pointer stays at the start
- // of the buffer and values_offset is applied per bit by the consumer.
- const int value_bits = checked_cast<const FixedWidthType&>(*values_data.type).bit_width();
- const uint8_t* value_bytes = values_data.buffers[1]->data();
- if (value_bits > 1) {
- value_bytes += values_data.offset * value_bits / 8;
- }
- args.values = value_bytes;
- args.values_bit_width = value_bits;
- args.values_length = values_data.length;
- args.values_offset = values_data.offset;
- args.values_null_count = values_data.GetNullCount();
- args.values_bitmap =
- values_data.buffers[0] ? values_data.buffers[0]->data() : nullptr;
-
- // Indices. The data pointer is advanced by the array offset directly.
- const int index_bits = checked_cast<const FixedWidthType&>(*indices_data.type).bit_width();
- args.indices = indices_data.buffers[1]->data() + indices_data.offset * index_bits / 8;
- args.indices_bit_width = index_bits;
- args.indices_length = indices_data.length;
- args.indices_offset = indices_data.offset;
- args.indices_null_count = indices_data.GetNullCount();
- args.indices_bitmap =
- indices_data.buffers[0] ? indices_data.buffers[0]->data() : nullptr;
-
- return args;
-}
-
-/// \brief The Take implementation for primitive (fixed-width) types does not
-/// use the logical Arrow type but rather the physical C type. This way we
-/// only generate one take function for each byte width.
-///
-/// This function assumes that the indices have been boundschecked.
-///
-/// Every output slot that ends up null (null index or null value) has its data
-/// slot written as zero (ValueCType{} or memset), so the output data buffer
-/// never exposes whatever the preallocated buffer contained. The output
-/// null_count is derived from the number of valid slots written.
-template <typename IndexCType, typename ValueCType>
-struct PrimitiveTakeImpl {
- static void Exec(const PrimitiveTakeArgs& args, Datum* out_datum) {
- auto values = reinterpret_cast<const ValueCType*>(args.values);
- auto values_bitmap = args.values_bitmap;
- auto values_offset = args.values_offset;
-
- auto indices = reinterpret_cast<const IndexCType*>(args.indices);
- auto indices_bitmap = args.indices_bitmap;
- auto indices_offset = args.indices_offset;
-
- ArrayData* out_arr = out_datum->mutable_array();
- auto out = out_arr->GetMutableValues<ValueCType>(1);
- auto out_bitmap = out_arr->buffers[0]->mutable_data();
- auto out_offset = out_arr->offset;
-
- // If either the values or indices have nulls, we preemptively zero out the
- // out validity bitmap so that we don't have to use ClearBit in each
- // iteration for nulls.
- if (args.values_null_count > 0 || args.indices_null_count > 0) {
- BitUtil::SetBitsTo(out_bitmap, out_offset, args.indices_length, false);
- }
-
- OptionalBitBlockCounter indices_bit_counter(indices_bitmap, indices_offset,
- args.indices_length);
- int64_t position = 0;
- int64_t valid_count = 0;
- while (position < args.indices_length) {
- BitBlockCount block = indices_bit_counter.NextBlock();
- if (args.values_null_count == 0) {
- // Values are never null, so things are easier
- valid_count += block.popcount;
- if (block.popcount == block.length) {
- // Fastest path: neither values nor index nulls
- BitUtil::SetBitsTo(out_bitmap, out_offset + position, block.length, true);
- for (int64_t i = 0; i < block.length; ++i) {
- out[position] = values[indices[position]];
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some indices but not all are null
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(indices_bitmap, indices_offset + position)) {
- // index is not null
- BitUtil::SetBit(out_bitmap, out_offset + position);
- out[position] = values[indices[position]];
- } else {
- out[position] = ValueCType{};
- }
- ++position;
- }
- } else {
- // Every index in this block is null. Zero the data slots like the
- // other null-producing paths do, instead of leaving the preallocated
- // buffer's previous contents in the output.
- memset(out + position, 0, sizeof(ValueCType) * block.length);
- position += block.length;
- }
- } else {
- // Values have nulls, so we must do random access into the values bitmap
- if (block.popcount == block.length) {
- // Faster path: indices are not null but values may be
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(values_bitmap, values_offset + indices[position])) {
- // value is not null
- out[position] = values[indices[position]];
- BitUtil::SetBit(out_bitmap, out_offset + position);
- ++valid_count;
- } else {
- out[position] = ValueCType{};
- }
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null. Since we are doing
- // random access in general we have to check the value nullness one by
- // one.
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(indices_bitmap, indices_offset + position) &&
- BitUtil::GetBit(values_bitmap, values_offset + indices[position])) {
- // index is not null && value is not null
- out[position] = values[indices[position]];
- BitUtil::SetBit(out_bitmap, out_offset + position);
- ++valid_count;
- } else {
- out[position] = ValueCType{};
- }
- ++position;
- }
- } else {
- memset(out + position, 0, sizeof(ValueCType) * block.length);
- position += block.length;
- }
- }
- }
- out_arr->null_count = out_arr->length - valid_count;
- }
-};
-
-/// \brief Take implementation for bit-packed boolean values.
-///
-/// Mirrors PrimitiveTakeImpl, but values and output data are bitmaps: each
-/// data bit is read with GetBit and written with SetBitTo. The output data
-/// bitmap is cleared up front, so slots that produce a null keep a zero data
-/// bit without extra work. Like the other primitive take paths, this assumes
-/// the indices were already boundschecked by the caller.
-template <typename IndexCType>
-struct BooleanTakeImpl {
- static void Exec(const PrimitiveTakeArgs& args, Datum* out_datum) {
- // Bit-packed values: args.values is the unadjusted buffer start, so the
- // values offset is applied per bit below.
- auto values = args.values;
- auto values_bitmap = args.values_bitmap;
- auto values_offset = args.values_offset;
-
- auto indices = reinterpret_cast<const IndexCType*>(args.indices);
- auto indices_bitmap = args.indices_bitmap;
- auto indices_offset = args.indices_offset;
-
- ArrayData* out_arr = out_datum->mutable_array();
- auto out = out_arr->buffers[1]->mutable_data();
- auto out_bitmap = out_arr->buffers[0]->mutable_data();
- auto out_offset = out_arr->offset;
-
- // If either the values or indices have nulls, we preemptively zero out the
- // out validity bitmap so that we don't have to use ClearBit in each
- // iteration for nulls.
- if (args.values_null_count > 0 || args.indices_null_count > 0) {
- BitUtil::SetBitsTo(out_bitmap, out_offset, args.indices_length, false);
- }
- // Avoid uninitialized data in values array
- BitUtil::SetBitsTo(out, out_offset, args.indices_length, false);
-
- // Copies the value bit at position `index` into output slot `loc`.
- auto PlaceDataBit = [&](int64_t loc, IndexCType index) {
- BitUtil::SetBitTo(out, out_offset + loc,
- BitUtil::GetBit(values, values_offset + index));
- };
-
- // Walk the indices one validity block at a time so fully-valid and
- // fully-null blocks can skip per-element validity checks.
- OptionalBitBlockCounter indices_bit_counter(indices_bitmap, indices_offset,
- args.indices_length);
- int64_t position = 0;
- int64_t valid_count = 0;
- while (position < args.indices_length) {
- BitBlockCount block = indices_bit_counter.NextBlock();
- if (args.values_null_count == 0) {
- // Values are never null, so things are easier
- // (output validity equals index validity, so popcount is the count).
- valid_count += block.popcount;
- if (block.popcount == block.length) {
- // Fastest path: neither values nor index nulls
- BitUtil::SetBitsTo(out_bitmap, out_offset + position, block.length, true);
- for (int64_t i = 0; i < block.length; ++i) {
- PlaceDataBit(position, indices[position]);
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(indices_bitmap, indices_offset + position)) {
- // index is not null
- BitUtil::SetBit(out_bitmap, out_offset + position);
- PlaceDataBit(position, indices[position]);
- }
- ++position;
- }
- } else {
- // All indices null: validity and data bits are already zero.
- position += block.length;
- }
- } else {
- // Values have nulls, so we must do random access into the values bitmap
- if (block.popcount == block.length) {
- // Faster path: indices are not null but values may be
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(values_bitmap, values_offset + indices[position])) {
- // value is not null
- BitUtil::SetBit(out_bitmap, out_offset + position);
- PlaceDataBit(position, indices[position]);
- ++valid_count;
- }
- ++position;
- }
- } else if (block.popcount > 0) {
- // Slow path: some but not all indices are null. Since we are doing
- // random access in general we have to check the value nullness one by
- // one.
- for (int64_t i = 0; i < block.length; ++i) {
- if (BitUtil::GetBit(indices_bitmap, indices_offset + position)) {
- // index is not null
- if (BitUtil::GetBit(values_bitmap, values_offset + indices[position])) {
- // value is not null
- PlaceDataBit(position, indices[position]);
- BitUtil::SetBit(out_bitmap, out_offset + position);
- ++valid_count;
- }
- }
- ++position;
- }
- } else {
- // All indices null: validity and data bits are already zero.
- position += block.length;
- }
- }
- }
- out_arr->null_count = out_arr->length - valid_count;
- }
-};
-
-/// Runs TakeImpl instantiated for the unsigned C type matching the physical
-/// width of the indices. This assumes boundschecking already happened at a
-/// higher level: with non-negative indices, signed integers can be read as
-/// unsigned, which halves the number of instantiations.
-template <template <typename...> class TakeImpl, typename... Args>
-void TakeIndexDispatch(const PrimitiveTakeArgs& args, Datum* out) {
- const int index_bits = args.indices_bit_width;
- if (index_bits == 8) {
- TakeImpl<uint8_t, Args...>::Exec(args, out);
- } else if (index_bits == 16) {
- TakeImpl<uint16_t, Args...>::Exec(args, out);
- } else if (index_bits == 32) {
- TakeImpl<uint32_t, Args...>::Exec(args, out);
- } else if (index_bits == 64) {
- TakeImpl<uint64_t, Args...>::Exec(args, out);
- } else {
- DCHECK(false) << "Invalid indices byte width";
- }
-}
-
-// Sizes the output array for `length` slots and allocates its validity bitmap
-// and data buffer. The data buffer is a bitmap when bit_width == 1 and holds
-// length * bit_width / 8 bytes otherwise.
-Status PreallocateData(KernelContext* ctx, int64_t length, int bit_width, Datum* out) {
- ArrayData* out_arr = out->mutable_array();
- out_arr->length = length;
- out_arr->buffers.resize(2);
- ARROW_ASSIGN_OR_RAISE(out_arr->buffers[0], ctx->AllocateBitmap(length));
- if (bit_width != 1) {
- ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->Allocate(length * bit_width / 8));
- return Status::OK();
- }
- ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1], ctx->AllocateBitmap(length));
- return Status::OK();
-}
-
-// Kernel entry point shared by every primitive value type: optional
-// boundschecking, output preallocation, then dispatch on the value bit width.
-static void PrimitiveTakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
- if (take_state.options.boundscheck) {
- KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
- }
- const PrimitiveTakeArgs args = GetPrimitiveTakeArgs(batch);
- KERNEL_RETURN_IF_ERROR(
- ctx, PreallocateData(ctx, args.indices_length, args.values_bit_width, out));
- const int value_bits = args.values_bit_width;
- if (value_bits == 1) {
- TakeIndexDispatch<BooleanTakeImpl>(args, out);
- } else if (value_bits == 8) {
- TakeIndexDispatch<PrimitiveTakeImpl, int8_t>(args, out);
- } else if (value_bits == 16) {
- TakeIndexDispatch<PrimitiveTakeImpl, int16_t>(args, out);
- } else if (value_bits == 32) {
- TakeIndexDispatch<PrimitiveTakeImpl, int32_t>(args, out);
- } else if (value_bits == 64) {
- TakeIndexDispatch<PrimitiveTakeImpl, int64_t>(args, out);
- } else {
- DCHECK(false) << "Invalid values byte width";
- }
-}
-
-// Take on a null-typed array: every output slot is null, so only the output
-// length matters. That length is the number of indices (batch[1]), which can
-// differ from the number of values, so batch.length is not used here.
-static void NullTakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& state = checked_cast<const TakeState&>(*ctx->state());
- if (state.options.boundscheck) {
- KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
- }
- out->value = std::make_shared<NullArray>(batch[1].length())->data();
-}
-
-// Take on dictionary values: only the dictionary indices are taken (with the
-// caller's TakeOptions); the result reuses the original dictionary unchanged.
-static void DictionaryTakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
- DictionaryArray dict_values(batch[0].array());
- Result<Datum> taken_indices = Take(Datum(dict_values.indices()), batch[1],
- take_state.options, ctx->exec_context());
- if (taken_indices.ok()) {
- DictionaryArray taken_values(dict_values.type(), (*taken_indices).make_array(),
- dict_values.dictionary());
- out->value = taken_values.data();
- } else {
- ctx->SetStatus(taken_indices.status());
- }
-}
-
-// Take on extension values: takes the storage array (with the caller's
-// TakeOptions) and re-wraps the result in the same extension type.
-static void ExtensionTakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
- ExtensionArray ext_values(batch[0].array());
- Result<Datum> taken_storage = Take(Datum(ext_values.storage()), batch[1],
- take_state.options, ctx->exec_context());
- if (taken_storage.ok()) {
- ExtensionArray taken_values(ext_values.type(), (*taken_storage).make_array());
- out->value = taken_values.data();
- } else {
- ctx->SetStatus(taken_storage.status());
- }
-}
-
-// ----------------------------------------------------------------------
-
-// Use CRTP to dispatch to type-specific processing of indices for each
-// unsigned integer type.
-//
-// Exec() reserves one validity bit per index, calls Init(), then invokes
-// Impl::ProcessIndices<IndexCType>() for the unsigned type matching the index
-// byte width, and finally FinishCommon() (validity/length/null_count) followed
-// by the implementation-specific Finish(). Implementations build their output
-// buffers through VisitIndices, which appends the output validity bits.
-template <typename Impl, typename Type>
-struct GenericTakeImpl {
- using ValuesArrayType = typename TypeTraits<Type>::ArrayType;
-
- KernelContext* ctx;
- std::shared_ptr<ArrayData> values;
- std::shared_ptr<ArrayData> indices;
- ArrayData* out;
- TypedBufferBuilder<bool> validity_builder;
-
- GenericTakeImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out)
- : ctx(ctx),
- values(batch[0].array()),
- indices(batch[1].array()),
- out(out->mutable_array()),
- validity_builder(ctx->memory_pool()) {}
-
- virtual ~GenericTakeImpl() = default;
-
- // Sets the output's buffer count, length and null count from the values
- // layout and the accumulated validity bits, and stores the validity bitmap.
- Status FinishCommon() {
- out->buffers.resize(values->buffers.size());
- out->length = validity_builder.length();
- out->null_count = validity_builder.false_count();
- return validity_builder.Finish(&out->buffers[0]);
- }
-
- // Walks the indices, appending one output validity bit per index. Calls
- // visit_valid(index) when both the index and the value it refers to are
- // non-null, and visit_null() otherwise; stops at the first non-OK Status.
- // Uses UnsafeAppend, relying on the Reserve(indices->length) done in Exec().
- template <typename IndexCType, typename ValidVisitor, typename NullVisitor>
- Status VisitIndices(ValidVisitor&& visit_valid, NullVisitor&& visit_null) {
- const auto indices_values = indices->GetValues<IndexCType>(1);
- const uint8_t* bitmap = nullptr;
- if (indices->buffers[0]) {
- bitmap = indices->buffers[0]->data();
- }
- OptionalBitIndexer indices_is_valid(indices->buffers[0], indices->offset);
- OptionalBitIndexer values_is_valid(values->buffers[0], values->offset);
- const bool values_have_nulls = (values->GetNullCount() > 0);
-
- OptionalBitBlockCounter bit_counter(bitmap, indices->offset, indices->length);
- int64_t position = 0;
- while (position < indices->length) {
- BitBlockCount block = bit_counter.NextBlock();
- const bool indices_have_nulls = block.popcount < block.length;
- if (!indices_have_nulls && !values_have_nulls) {
- // Fastest path, neither indices nor values have nulls
- validity_builder.UnsafeAppend(block.length, true);
- for (int64_t i = 0; i < block.length; ++i) {
- RETURN_NOT_OK(visit_valid(indices_values[position++]));
- }
- } else if (block.popcount > 0) {
- // Since we have to branch on whether the indices are null or not, we
- // combine the "non-null indices block but some values null" and
- // "some-null indices block but values non-null" into a single loop.
- // The && short-circuits, so a null index's (arbitrary) value is never
- // used to look up the values bitmap.
- for (int64_t i = 0; i < block.length; ++i) {
- if ((!indices_have_nulls || indices_is_valid[position]) &&
- values_is_valid[indices_values[position]]) {
- validity_builder.UnsafeAppend(true);
- RETURN_NOT_OK(visit_valid(indices_values[position]));
- } else {
- validity_builder.UnsafeAppend(false);
- RETURN_NOT_OK(visit_null());
- }
- ++position;
- }
- } else {
- // The whole block is null
- validity_builder.UnsafeAppend(block.length, false);
- for (int64_t i = 0; i < block.length; ++i) {
- RETURN_NOT_OK(visit_null());
- }
- position += block.length;
- }
- }
- return Status::OK();
- }
-
- // Optional hook run before index processing (e.g. to reserve buffers).
- virtual Status Init() { return Status::OK(); }
-
- // Implementation specific finish logic
- virtual Status Finish() = 0;
-
- Status Exec() {
- RETURN_NOT_OK(this->validity_builder.Reserve(indices->length));
- RETURN_NOT_OK(Init());
- int index_width =
- checked_cast<const FixedWidthType&>(*this->indices->type).bit_width() / 8;
-
- // CRTP dispatch here: signed indices are processed as the unsigned type of
- // the same width.
- switch (index_width) {
- case 1:
- RETURN_NOT_OK(static_cast<Impl*>(this)->template ProcessIndices<uint8_t>());
- break;
- case 2:
- RETURN_NOT_OK(static_cast<Impl*>(this)->template ProcessIndices<uint16_t>());
- break;
- case 4:
- RETURN_NOT_OK(static_cast<Impl*>(this)->template ProcessIndices<uint32_t>());
- break;
- case 8:
- RETURN_NOT_OK(static_cast<Impl*>(this)->template ProcessIndices<uint64_t>());
- break;
- default:
- DCHECK(false) << "Invalid index width";
- break;
- }
- RETURN_NOT_OK(this->FinishCommon());
- return Finish();
- }
-};
-
-// Re-exports the GenericTakeImpl members used by the implementations below so
-// they can be named without `this->`/`Base::` qualification (needed for the
-// templated implementations, whose base class is dependent).
-#define LIFT_BASE_MEMBERS() \
- using Base::ctx; \
- using Base::out; \
- using Base::values; \
- using Base::indices; \
- using Base::validity_builder; \
- using ValuesArrayType = typename Base::ValuesArrayType
-
-// Null-slot visitor for implementations that have nothing to emit for nulls.
-static inline Status VisitNoop() {
- return Status::OK();
-}
-
-// A take implementation for 32-bit and 64-bit variable binary types. Common
-// generated kernels are shared between Binary/String and
-// LargeBinary/LargeString
-//
-// Valid slots copy the selected value's bytes into data_builder; null slots
-// only repeat the current offset. Fails with Status::Invalid if the output
-// data would exceed what offset_type can address.
-template <typename Type>
-struct VarBinaryTakeImpl : public GenericTakeImpl<VarBinaryTakeImpl<Type>, Type> {
- using offset_type = typename Type::offset_type;
-
- using Base = GenericTakeImpl<VarBinaryTakeImpl<Type>, Type>;
- LIFT_BASE_MEMBERS();
-
- std::shared_ptr<ArrayData> values_as_binary;
- TypedBufferBuilder<offset_type> offset_builder;
- TypedBufferBuilder<uint8_t> data_builder;
-
- static constexpr int64_t kOffsetLimit = std::numeric_limits<offset_type>::max() - 1;
-
- VarBinaryTakeImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out)
- : Base(ctx, batch, out),
- offset_builder(ctx->memory_pool()),
- data_builder(ctx->memory_pool()) {}
-
- template <typename IndexCType>
- Status ProcessIndices() {
- ValuesArrayType typed_values(this->values_as_binary);
-
- // Presize the data builder with a rough estimate of the required data size
- const auto values_length = values->length;
- const auto mean_value_length =
- (values_length > 0) ? ((typed_values.raw_value_offsets()[values_length] -
- typed_values.raw_value_offsets()[0]) /
- static_cast<double>(values_length))
- : 0.0;
- RETURN_NOT_OK(data_builder.Reserve(static_cast<int64_t>(
- mean_value_length * (indices->length - indices->GetNullCount()))));
-
- int64_t space_available = data_builder.capacity();
-
- offset_type offset = 0;
- RETURN_NOT_OK(this->template VisitIndices<IndexCType>(
- [&](IndexCType index) {
- offset_builder.UnsafeAppend(offset);
- auto val = typed_values.GetView(index);
- offset_type value_size = static_cast<offset_type>(val.size());
- // The whole comparison must be inside ARROW_PREDICT_FALSE. That macro
- // is a branch hint (typically __builtin_expect(!!(x), 0)), so
- // comparing its result against kOffsetLimit compares a 0/1 value
- // and never detects the overflow.
- if (ARROW_PREDICT_FALSE(static_cast<int64_t>(offset) +
- static_cast<int64_t>(value_size) >
- kOffsetLimit)) {
- return Status::Invalid("Take operation overflowed binary array capacity");
- }
- offset += value_size;
- if (ARROW_PREDICT_FALSE(value_size > space_available)) {
- RETURN_NOT_OK(data_builder.Reserve(value_size));
- space_available = data_builder.capacity() - data_builder.length();
- }
- data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(val.data()),
- value_size);
- space_available -= value_size;
- return Status::OK();
- },
- [&]() {
- offset_builder.UnsafeAppend(offset);
- return Status::OK();
- }));
- // Closing offset after the last slot.
- offset_builder.UnsafeAppend(offset);
- return Status::OK();
- }
-
- Status Init() override {
- ARROW_ASSIGN_OR_RAISE(this->values_as_binary,
- GetArrayView(this->values, TypeTraits<Type>::type_singleton()));
- return offset_builder.Reserve(indices->length + 1);
- }
-
- Status Finish() override {
- RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
- return data_builder.Finish(&out->buffers[2]);
- }
-};
-
-// Take for fixed-size binary values (also registered for decimals): copies
-// byte_width bytes per valid slot and writes byte_width zero bytes per null
-// slot.
-struct FSBTakeImpl : public GenericTakeImpl<FSBTakeImpl, FixedSizeBinaryType> {
- using Base = GenericTakeImpl<FSBTakeImpl, FixedSizeBinaryType>;
- LIFT_BASE_MEMBERS();
-
- TypedBufferBuilder<uint8_t> data_builder;
-
- FSBTakeImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out)
- : Base(ctx, batch, out), data_builder(ctx->memory_pool()) {}
-
- template <typename IndexCType>
- Status ProcessIndices() {
- FixedSizeBinaryArray typed_values(this->values);
- const int32_t byte_width = typed_values.byte_width();
- RETURN_NOT_OK(data_builder.Reserve(byte_width * indices->length));
-
- auto append_value = [&](IndexCType index) {
- const auto view = typed_values.GetView(index);
- data_builder.UnsafeAppend(reinterpret_cast<const uint8_t*>(view.data()),
- byte_width);
- return Status::OK();
- };
- auto append_zeros = [&]() {
- data_builder.UnsafeAppend(byte_width, static_cast<uint8_t>(0x00));
- return Status::OK();
- };
- return this->template VisitIndices<IndexCType>(append_value, append_zeros);
- }
-
- Status Finish() override { return data_builder.Finish(&out->buffers[1]); }
-};
-
-// Take for variable-size list values (registered for List, LargeList and
-// Map): builds the output offsets directly, records the child positions of
-// every selected list, and takes the child values with them in Finish().
-template <typename Type>
-struct ListTakeImpl : public GenericTakeImpl<ListTakeImpl<Type>, Type> {
- using offset_type = typename Type::offset_type;
-
- using Base = GenericTakeImpl<ListTakeImpl<Type>, Type>;
- LIFT_BASE_MEMBERS();
-
- TypedBufferBuilder<offset_type> offset_builder;
- typename TypeTraits<Type>::OffsetBuilderType child_index_builder;
-
- ListTakeImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out)
- : Base(ctx, batch, out),
- offset_builder(ctx->memory_pool()),
- child_index_builder(ctx->memory_pool()) {}
-
- template <typename IndexCType>
- Status ProcessIndices() {
- ValuesArrayType typed_values(this->values);
-
- // TODO presize child_index_builder with a similar heuristic as VarBinaryTakeImpl
-
- offset_type next_offset = 0;
- auto visit_valid = [&](IndexCType index) {
- offset_builder.UnsafeAppend(next_offset);
- const offset_type first = typed_values.value_offset(index);
- const offset_type count = typed_values.value_length(index);
- next_offset += count;
- RETURN_NOT_OK(child_index_builder.Reserve(count));
- const offset_type last = first + count;
- for (offset_type child = first; child < last; ++child) {
- child_index_builder.UnsafeAppend(child);
- }
- return Status::OK();
- };
- auto visit_null = [&]() {
- offset_builder.UnsafeAppend(next_offset);
- return Status::OK();
- };
-
- RETURN_NOT_OK(this->template VisitIndices<IndexCType>(std::move(visit_valid),
- std::move(visit_null)));
- // Closing offset after the last slot.
- offset_builder.UnsafeAppend(next_offset);
- return Status::OK();
- }
-
- Status Init() override { return offset_builder.Reserve(indices->length + 1); }
-
- Status Finish() override {
- std::shared_ptr<Array> child_indices;
- RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
-
- ValuesArrayType typed_values(this->values);
-
- // The child positions come from the list offsets themselves, so they are
- // not boundschecked again.
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
- Take(*typed_values.values(), *child_indices,
- TakeOptions::NoBoundsCheck(), ctx->exec_context()));
- RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1]));
- out->child_data = {taken_child->data()};
- return Status::OK();
- }
-};
-
-// Take for fixed-size list values. Every output slot, valid or null, takes
-// exactly list_size child slots (null slots get null child indices); the child
-// values are then taken in Finish() from the unsliced child array.
-struct FSLTakeImpl : public GenericTakeImpl<FSLTakeImpl, FixedSizeListType> {
- Int64Builder child_index_builder;
-
- using Base = GenericTakeImpl<FSLTakeImpl, FixedSizeListType>;
- LIFT_BASE_MEMBERS();
-
- FSLTakeImpl(KernelContext* ctx, const ExecBatch& batch, Datum* out)
- : Base(ctx, batch, out), child_index_builder(ctx->memory_pool()) {}
-
- template <typename IndexCType>
- Status ProcessIndices() {
- ValuesArrayType typed_values(this->values);
- int32_t list_size = typed_values.list_type()->list_size();
- // typed_values.values() (used in Finish) is not sliced, so a sliced parent
- // must shift the child positions by its own offset.
- const int64_t base_offset = typed_values.offset();
-
- /// We must take list_size elements even for null elements of
- /// indices.
- RETURN_NOT_OK(child_index_builder.Reserve(indices->length * list_size));
- return this->template VisitIndices<IndexCType>(
- [&](IndexCType index) {
- // Compute in int64_t: with uint32_t indices the product would
- // otherwise be evaluated (and wrap) in 32-bit unsigned arithmetic.
- int64_t offset = (base_offset + static_cast<int64_t>(index)) * list_size;
- for (int64_t j = offset; j < offset + list_size; ++j) {
- child_index_builder.UnsafeAppend(j);
- }
- return Status::OK();
- },
- [&]() { return child_index_builder.AppendNulls(list_size); });
- }
-
- Status Finish() override {
- std::shared_ptr<Array> child_indices;
- RETURN_NOT_OK(child_index_builder.Finish(&child_indices));
-
- ValuesArrayType typed_values(this->values);
-
- // No need to boundscheck the child values indices
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken_child,
- Take(*typed_values.values(), *child_indices,
- TakeOptions::NoBoundsCheck(), ctx->exec_context()));
- out->child_data = {taken_child->data()};
- return Status::OK();
- }
-};
-
-// Take for struct values. ProcessIndices only produces the top-level validity
-// bitmap (through VisitIndices); Finish() then takes every child field with
-// the same indices.
-struct StructTakeImpl : public GenericTakeImpl<StructTakeImpl, StructType> {
- using Base = GenericTakeImpl<StructTakeImpl, StructType>;
- LIFT_BASE_MEMBERS();
-
- using Base::Base;
-
- template <typename IndexCType>
- Status ProcessIndices() {
- // Nothing to record per slot beyond validity, which VisitIndices appends.
- return this->template VisitIndices<IndexCType>(
- [](IndexCType /*index*/) { return Status::OK(); },
- /*visit_null=*/VisitNoop);
- }
-
- Status Finish() override {
- StructArray typed_values(values);
-
- // Select from children without boundschecking
- out->child_data.resize(values->type->num_fields());
- for (int field_index = 0; field_index < values->type->num_fields(); ++field_index) {
- ARROW_ASSIGN_OR_RAISE(Datum taken_field,
- Take(Datum(typed_values.field(field_index)), Datum(indices),
- TakeOptions::NoBoundsCheck(), ctx->exec_context()));
- out->child_data[field_index] = taken_field.array();
- }
- return Status::OK();
- }
-};
-
-#undef LIFT_BASE_MEMBERS
-
-// Kernel entry point for every GenericTakeImpl-based implementation: optional
-// boundschecking of the indices, then the implementation's Exec().
-template <typename Impl>
-static void GenericTakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& take_state = checked_cast<const TakeState&>(*ctx->state());
- const bool check_bounds = take_state.options.boundscheck;
- if (check_bounds) {
- KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length()));
- }
- Impl impl(ctx, batch, out);
- KERNEL_RETURN_IF_ERROR(ctx, impl.Exec());
-}
-
-// Shorthand naming of these functions
-// A -> Array
-// C -> ChunkedArray
-// R -> RecordBatch
-// T -> Table
-
-// Array values, Array indices: runs the "array_take" vector function.
-Result<std::shared_ptr<Array>> TakeAA(const Array& values, const Array& indices,
- const TakeOptions& options, ExecContext* ctx) {
- const std::vector<Datum> call_args = {values, indices};
- ARROW_ASSIGN_OR_RAISE(Datum taken,
- CallFunction("array_take", call_args, &options, ctx));
- return taken.make_array();
-}
-
-// ChunkedArray values, Array indices. The result currently always has exactly
-// one chunk: `values` is used directly when it has a single chunk and is
-// concatenated (with the default memory pool) otherwise.
-Result<std::shared_ptr<ChunkedArray>> TakeCA(const ChunkedArray& values,
- const Array& indices,
- const TakeOptions& options,
- ExecContext* ctx) {
- auto num_chunks = values.num_chunks();
- std::vector<std::shared_ptr<Array>> new_chunks(1); // Hard-coded 1 for now
- std::shared_ptr<Array> current_chunk;
-
- // Case 1: `values` has a single chunk, so just use it
- if (num_chunks == 1) {
- current_chunk = values.chunk(0);
- } else {
- // TODO Case 2: See if all `indices` fall in the same chunk and call Array Take on it
- // See
- // https://github.com/apache/arrow/blob/6f2c9041137001f7a9212f244b51bc004efc29af/r/src/compute.cpp#L123-L151
- // TODO Case 3: If indices are sorted, can slice them and call Array Take
-
- // Case 4: Else, concatenate chunks and call Array Take
- RETURN_NOT_OK(Concatenate(values.chunks(), default_memory_pool(), &current_chunk));
- }
- // Call Array Take on our single chunk
- ARROW_ASSIGN_OR_RAISE(new_chunks[0], TakeAA(*current_chunk, indices, options, ctx));
- return std::make_shared<ChunkedArray>(std::move(new_chunks));
-}
-
-// ChunkedArray values, ChunkedArray indices: one output chunk per indices
-// chunk, each taken from `values`.
-Result<std::shared_ptr<ChunkedArray>> TakeCC(const ChunkedArray& values,
- const ChunkedArray& indices,
- const TakeOptions& options,
- ExecContext* ctx) {
- auto num_chunks = indices.num_chunks();
-
- // TakeCA concatenates a multi-chunk `values` on every call. Concatenate once
- // here and hand TakeCA a single-chunk array instead. This is skipped when
- // there are no indices chunks, so that case still never touches `values`.
- std::shared_ptr<ChunkedArray> flat_values;
- if (values.num_chunks() > 1 && num_chunks > 0) {
- std::shared_ptr<Array> concatenated;
- RETURN_NOT_OK(Concatenate(values.chunks(), default_memory_pool(), &concatenated));
- flat_values = std::make_shared<ChunkedArray>(
- std::vector<std::shared_ptr<Array>>{std::move(concatenated)});
- }
- const ChunkedArray& source = flat_values ? *flat_values : values;
-
- std::vector<std::shared_ptr<Array>> new_chunks(num_chunks);
- for (int i = 0; i < num_chunks; i++) {
- // Take with that indices chunk
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> current_chunk,
- TakeCA(source, *indices.chunk(i), options, ctx));
- // Concatenate the result to make a single array for this chunk
- RETURN_NOT_OK(
- Concatenate(current_chunk->chunks(), default_memory_pool(), &new_chunks[i]));
- }
- return std::make_shared<ChunkedArray>(std::move(new_chunks));
-}
-
-// Array values, ChunkedArray indices: one output chunk per indices chunk.
-Result<std::shared_ptr<ChunkedArray>> TakeAC(const Array& values,
- const ChunkedArray& indices,
- const TakeOptions& options,
- ExecContext* ctx) {
- std::vector<std::shared_ptr<Array>> taken_chunks;
- taken_chunks.reserve(indices.num_chunks());
- for (const auto& index_chunk : indices.chunks()) {
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> taken,
- TakeAA(values, *index_chunk, options, ctx));
- taken_chunks.push_back(std::move(taken));
- }
- return std::make_shared<ChunkedArray>(std::move(taken_chunks));
-}
-
-// RecordBatch values, Array indices: takes every column with the same indices;
-// the result has one row per index.
-Result<std::shared_ptr<RecordBatch>> TakeRA(const RecordBatch& batch,
- const Array& indices,
- const TakeOptions& options,
- ExecContext* ctx) {
- const int num_columns = batch.num_columns();
- std::vector<std::shared_ptr<Array>> taken_columns(num_columns);
- for (int col = 0; col < num_columns; ++col) {
- ARROW_ASSIGN_OR_RAISE(taken_columns[col],
- TakeAA(*batch.column(col), indices, options, ctx));
- }
- return RecordBatch::Make(batch.schema(), indices.length(), taken_columns);
-}
-
-// Table values, Array indices: takes every column with the same indices.
-Result<std::shared_ptr<Table>> TakeTA(const Table& table, const Array& indices,
- const TakeOptions& options, ExecContext* ctx) {
- const int num_columns = table.num_columns();
- std::vector<std::shared_ptr<ChunkedArray>> taken_columns;
- taken_columns.reserve(num_columns);
- for (int col = 0; col < num_columns; ++col) {
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> taken,
- TakeCA(*table.column(col), indices, options, ctx));
- taken_columns.push_back(std::move(taken));
- }
- return Table::Make(table.schema(), taken_columns);
-}
-
-// Table values, ChunkedArray indices: takes every column with the same
-// chunked indices.
-Result<std::shared_ptr<Table>> TakeTC(const Table& table, const ChunkedArray& indices,
- const TakeOptions& options, ExecContext* ctx) {
- const int num_columns = table.num_columns();
- std::vector<std::shared_ptr<ChunkedArray>> taken_columns;
- taken_columns.reserve(num_columns);
- for (int col = 0; col < num_columns; ++col) {
- ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ChunkedArray> taken,
- TakeCC(*table.column(col), indices, options, ctx));
- taken_columns.push_back(std::move(taken));
- }
- return Table::Make(table.schema(), taken_columns);
-}
-
-// Metafunction for dispatching to different Take implementations other than
-// Array-Array.
-//
-// TODO: Revamp approach to executing Take operations. In addition to being
-// overly complex dispatching, there is no parallelization.
-class TakeMetaFunction : public MetaFunction {
- public:
- TakeMetaFunction() : MetaFunction("take", Arity::Binary()) {}
-
- // Dispatches on the (values, indices) Datum kinds to the matching TakeXY
- // helper (A = Array, C = ChunkedArray, R = RecordBatch, T = Table). Any
- // other combination returns Status::NotImplemented.
- Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
- const FunctionOptions* options,
- ExecContext* ctx) const override {
- Datum::Kind index_kind = args[1].kind();
- const TakeOptions& take_opts = static_cast<const TakeOptions&>(*options);
- switch (args[0].kind()) {
- case Datum::ARRAY:
- if (index_kind == Datum::ARRAY) {
- return TakeAA(*args[0].make_array(), *args[1].make_array(), take_opts, ctx);
- } else if (index_kind == Datum::CHUNKED_ARRAY) {
- return TakeAC(*args[0].make_array(), *args[1].chunked_array(), take_opts, ctx);
- }
- break;
- case Datum::CHUNKED_ARRAY:
- if (index_kind == Datum::ARRAY) {
- return TakeCA(*args[0].chunked_array(), *args[1].make_array(), take_opts, ctx);
- } else if (index_kind == Datum::CHUNKED_ARRAY) {
- return TakeCC(*args[0].chunked_array(), *args[1].chunked_array(), take_opts,
- ctx);
- }
- break;
- case Datum::RECORD_BATCH:
- if (index_kind == Datum::ARRAY) {
- return TakeRA(*args[0].record_batch(), *args[1].make_array(), take_opts, ctx);
- }
- break;
- case Datum::TABLE:
- if (index_kind == Datum::ARRAY) {
- return TakeTA(*args[0].table(), *args[1].make_array(), take_opts, ctx);
- } else if (index_kind == Datum::CHUNKED_ARRAY) {
- return TakeTC(*args[0].table(), *args[1].chunked_array(), take_opts, ctx);
- }
- break;
- default:
- break;
- }
- // Separate the two descriptions so the message stays readable.
- return Status::NotImplemented(
- "Unsupported types for take operation: "
- "values=",
- args[0].ToString(), ", indices=", args[1].ToString());
- }
-};
-
-static InputType kTakeIndexType(match::Integer(), ValueDescr::ARRAY);
-
-// Registers the "array_take" vector function, with one kernel per family of
-// value types and every kernel accepting any integer index array, followed by
-// the "take" metafunction. Kernels are added in a fixed order.
-void RegisterVectorTake(FunctionRegistry* registry) {
- auto array_take = std::make_shared<VectorFunction>("array_take", Arity::Binary());
-
- VectorKernel kernel;
- kernel.init = InitWrapOptions<TakeOptions>;
- kernel.can_execute_chunkwise = false;
-
- // Adds `exec` as the kernel for (values_type, integer indices) -> values type.
- auto add_kernel = [&](InputType values_type, ArrayKernelExec exec) {
- kernel.exec = exec;
- kernel.signature =
- KernelSignature::Make({values_type, kTakeIndexType}, OutputType(FirstType));
- DCHECK_OK(array_take->AddKernel(kernel));
- };
-
- // All primitive types share one kernel that writes into preallocated memory
- // and dispatches on the value width internally; the other kernels below
- // handle their own memory allocation.
- add_kernel(InputType(match::Primitive(), ValueDescr::ARRAY), PrimitiveTakeExec);
-
- // Binary, String, LargeBinary, LargeString and FixedSizeBinary
- add_kernel(InputType(match::BinaryLike(), ValueDescr::ARRAY),
- GenericTakeExec<VarBinaryTakeImpl<BinaryType>>);
- add_kernel(InputType(match::LargeBinaryLike(), ValueDescr::ARRAY),
- GenericTakeExec<VarBinaryTakeImpl<LargeBinaryType>>);
- add_kernel(InputType::Array(Type::FIXED_SIZE_BINARY), GenericTakeExec<FSBTakeImpl>);
-
- add_kernel(InputType::Array(null()), NullTakeExec);
- add_kernel(InputType::Array(Type::DECIMAL), GenericTakeExec<FSBTakeImpl>);
- add_kernel(InputType::Array(Type::DICTIONARY), DictionaryTakeExec);
- add_kernel(InputType::Array(Type::EXTENSION), ExtensionTakeExec);
- add_kernel(InputType::Array(Type::LIST), GenericTakeExec<ListTakeImpl<ListType>>);
- add_kernel(InputType::Array(Type::LARGE_LIST),
- GenericTakeExec<ListTakeImpl<LargeListType>>);
- add_kernel(InputType::Array(Type::FIXED_SIZE_LIST), GenericTakeExec<FSLTakeImpl>);
- add_kernel(InputType::Array(Type::STRUCT), GenericTakeExec<StructTakeImpl>);
-
- // TODO: Reuse ListType kernel for MAP
- add_kernel(InputType::Array(Type::MAP), GenericTakeExec<ListTakeImpl<MapType>>);
-
- DCHECK_OK(registry->AddFunction(std::move(array_take)));
-
- // Chunked arrays, record batches and tables are handled by the metafunction.
- DCHECK_OK(registry->AddFunction(std::make_shared<TakeMetaFunction>()));
-}
-
-} // namespace internal
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_take_test.cc b/cpp/src/arrow/compute/kernels/vector_take_test.cc
deleted file mode 100644
index a7568df..0000000
--- a/cpp/src/arrow/compute/kernels/vector_take_test.cc
+++ /dev/null
@@ -1,844 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <algorithm>
-#include <limits>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "arrow/compute/api.h"
-#include "arrow/compute/kernels/test_util.h"
-#include "arrow/table.h"
-#include "arrow/testing/gtest_common.h"
-#include "arrow/testing/gtest_util.h"
-#include "arrow/testing/random.h"
-#include "arrow/testing/util.h"
-
-namespace arrow {
-namespace compute {
-
-using internal::checked_cast;
-using internal::checked_pointer_cast;
-using util::string_view;
-
-void AssertTakeArrays(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices,
- const std::shared_ptr<Array>& expected) {
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, Take(*values, *indices));
- ASSERT_OK(actual->ValidateFull());
- AssertArraysEqual(*expected, *actual, /*verbose=*/true);
-}
-
-Status TakeJSON(const std::shared_ptr<DataType>& type, const std::string& values,
- const std::shared_ptr<DataType>& index_type, const std::string& indices,
- std::shared_ptr<Array>* out) {
- return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices))
- .Value(out);
-}
-
-void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values,
- const std::string& indices, const std::string& expected) {
- std::shared_ptr<Array> actual;
-
- for (auto index_type : {int8(), uint32()}) {
- ASSERT_OK(TakeJSON(type, values, index_type, indices, &actual));
- ASSERT_OK(actual->ValidateFull());
- AssertArraysEqual(*ArrayFromJSON(type, expected), *actual, /*verbose=*/true);
- }
-}
-
-void AssertTakeNull(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(null(), values, indices, expected);
-}
-
-void AssertTakeBoolean(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(boolean(), values, indices, expected);
-}
-
-template <typename ValuesType, typename IndexType>
-void ValidateTakeImpl(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices,
- const std::shared_ptr<Array>& result) {
- using ValuesArrayType = typename TypeTraits<ValuesType>::ArrayType;
- using IndexArrayType = typename TypeTraits<IndexType>::ArrayType;
- auto typed_values = checked_pointer_cast<ValuesArrayType>(values);
- auto typed_result = checked_pointer_cast<ValuesArrayType>(result);
- auto typed_indices = checked_pointer_cast<IndexArrayType>(indices);
- for (int64_t i = 0; i < indices->length(); ++i) {
- if (typed_indices->IsNull(i) || typed_values->IsNull(typed_indices->Value(i))) {
- ASSERT_TRUE(result->IsNull(i)) << i;
- } else {
- ASSERT_FALSE(result->IsNull(i)) << i;
- ASSERT_EQ(typed_result->GetView(i), typed_values->GetView(typed_indices->Value(i)))
- << i;
- }
- }
-}
-
-template <typename ValuesType>
-void ValidateTake(const std::shared_ptr<Array>& values,
- const std::shared_ptr<Array>& indices) {
- ASSERT_OK_AND_ASSIGN(Datum out, Take(values, indices));
- auto taken = out.make_array();
- ASSERT_OK(taken->ValidateFull());
- ASSERT_EQ(indices->length(), taken->length());
- switch (indices->type_id()) {
- case Type::INT8:
- ValidateTakeImpl<ValuesType, Int8Type>(values, indices, taken);
- break;
- case Type::INT16:
- ValidateTakeImpl<ValuesType, Int16Type>(values, indices, taken);
- break;
- case Type::INT32:
- ValidateTakeImpl<ValuesType, Int32Type>(values, indices, taken);
- break;
- case Type::INT64:
- ValidateTakeImpl<ValuesType, Int64Type>(values, indices, taken);
- break;
- case Type::UINT8:
- ValidateTakeImpl<ValuesType, UInt8Type>(values, indices, taken);
- break;
- case Type::UINT16:
- ValidateTakeImpl<ValuesType, UInt16Type>(values, indices, taken);
- break;
- case Type::UINT32:
- ValidateTakeImpl<ValuesType, UInt32Type>(values, indices, taken);
- break;
- case Type::UINT64:
- ValidateTakeImpl<ValuesType, UInt64Type>(values, indices, taken);
- break;
- default:
- FAIL() << "Invalid index type";
- break;
- }
-}
-
-template <typename T>
-T GetMaxIndex(int64_t values_length) {
- int64_t max_index = values_length - 1;
- if (max_index > static_cast<int64_t>(std::numeric_limits<T>::max())) {
- max_index = std::numeric_limits<T>::max();
- }
- return static_cast<T>(max_index);
-}
-
-template <>
-uint64_t GetMaxIndex(int64_t values_length) {
- return static_cast<uint64_t>(values_length - 1);
-}
-
-template <typename ValuesType, typename IndexType>
-void CheckTakeRandom(const std::shared_ptr<Array>& values, int64_t indices_length,
- double null_probability, random::RandomArrayGenerator* rand) {
- using IndexCType = typename IndexType::c_type;
- IndexCType max_index = GetMaxIndex<IndexCType>(values->length());
- auto indices = rand->Numeric<IndexType>(indices_length, static_cast<IndexCType>(0),
- max_index, null_probability);
- auto indices_no_nulls = rand->Numeric<IndexType>(
- indices_length, static_cast<IndexCType>(0), max_index, /*null_probability=*/0.0);
- ValidateTake<ValuesType>(values, indices);
- ValidateTake<ValuesType>(values, indices_no_nulls);
- // Sliced indices array
- if (indices_length >= 2) {
- indices = indices->Slice(1, indices_length - 2);
- indices_no_nulls = indices_no_nulls->Slice(1, indices_length - 2);
- ValidateTake<ValuesType>(values, indices);
- ValidateTake<ValuesType>(values, indices_no_nulls);
- }
-}
-
-template <typename ValuesType, typename DataGenerator>
-void DoRandomTakeTests(DataGenerator&& generate_values) {
- auto rand = random::RandomArrayGenerator(kRandomSeed);
- for (const int64_t length : {1, 16, 59}) {
- for (const int64_t indices_length : {0, 5, 30}) {
- for (const auto null_probability : {0.0, 0.05, 0.25, 0.95, 1.0}) {
- auto values = generate_values(length, null_probability, &rand);
- CheckTakeRandom<ValuesType, Int8Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, Int16Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, Int32Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, Int64Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, UInt8Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, UInt16Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, UInt32Type>(values, indices_length, null_probability,
- &rand);
- CheckTakeRandom<ValuesType, UInt64Type>(values, indices_length, null_probability,
- &rand);
- // Sliced values array
- if (length > 2) {
- values = values->Slice(1, length - 2);
- CheckTakeRandom<ValuesType, UInt64Type>(values, indices_length,
- null_probability, &rand);
- }
- }
- }
- }
-}
-
-template <typename ArrowType>
-class TestTakeKernel : public ::testing::Test {};
-
-TEST(TestTakeKernel, TakeNull) {
- AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
-
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr));
- ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr));
-}
-
-TEST(TestTakeKernel, InvalidIndexType) {
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]", float32(),
- "[0.0, 1.0, 0.1]", &arr));
-}
-
-TEST(TestTakeKernel, TakeBoolean) {
- AssertTakeBoolean("[7, 8, 9]", "[]", "[]");
- AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]");
- AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
- AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false, true]");
-
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr));
- ASSERT_RAISES(IndexError,
- TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr));
-}
-
-TEST(TestTakeKernel, TakeBooleanRandom) {
- DoRandomTakeTests<BooleanType>(
- [](int64_t length, double null_probability, random::RandomArrayGenerator* rng) {
- return rng->Boolean(length, 0.5, null_probability);
- });
-}
-
-template <typename ArrowType>
-class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {
- protected:
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(type_singleton(), values, indices, expected);
- }
-
- std::shared_ptr<DataType> type_singleton() {
- return TypeTraits<ArrowType>::type_singleton();
- }
-};
-
-TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes);
-TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
- this->AssertTake("[7, 8, 9]", "[]", "[]");
- this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
- this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
- this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");
- this->AssertTake("[null, 8, 9]", "[]", "[]");
- this->AssertTake("[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]");
-
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(), "[0, 9, 0]", &arr));
- ASSERT_RAISES(IndexError, TakeJSON(this->type_singleton(), "[7, 8, 9]", int8(),
- "[0, -1, 0]", &arr));
-}
-
-TYPED_TEST(TestTakeKernelWithNumeric, TakeRandomNumeric) {
- DoRandomTakeTests<TypeParam>(
- [](int64_t length, double null_probability, random::RandomArrayGenerator* rng) {
- return rng->Numeric<TypeParam>(length, 0, 127, null_probability);
- });
-}
-
-template <typename TypeClass>
-class TestTakeKernelWithString : public TestTakeKernel<TypeClass> {
- public:
- std::shared_ptr<DataType> value_type() {
- return TypeTraits<TypeClass>::type_singleton();
- }
-
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(value_type(), values, indices, expected);
- }
-
- void AssertTakeDictionary(const std::string& dictionary_values,
- const std::string& dictionary_indices,
- const std::string& indices,
- const std::string& expected_indices) {
- auto dict = ArrayFromJSON(value_type(), dictionary_values);
- auto type = dictionary(int8(), value_type());
- ASSERT_OK_AND_ASSIGN(auto values,
- DictionaryArray::FromArrays(
- type, ArrayFromJSON(int8(), dictionary_indices), dict));
- ASSERT_OK_AND_ASSIGN(
- auto expected,
- DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict));
- auto take_indices = ArrayFromJSON(int8(), indices);
- AssertTakeArrays(values, take_indices, expected);
- }
-};
-
-TYPED_TEST_SUITE(TestTakeKernelWithString, TestingStringTypes);
-
-TYPED_TEST(TestTakeKernelWithString, TakeString) {
- this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
- this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
- this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])");
-
- std::shared_ptr<DataType> type = this->value_type();
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(type, R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr));
- ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])", int64(),
- "[2, 5]", &arr));
-}
-
-TEST(TestTakeKernelString, Random) {
- DoRandomTakeTests<StringType>(
- [](int64_t length, double null_probability, random::RandomArrayGenerator* rng) {
- return rng->String(length, 0, 32, null_probability);
- });
- DoRandomTakeTests<LargeStringType>(
- [](int64_t length, double null_probability, random::RandomArrayGenerator* rng) {
- return rng->LargeString(length, 0, 32, null_probability);
- });
-}
-
-TEST(TestTakeKernelFixedSizeBinary, Random) {
- DoRandomTakeTests<FixedSizeBinaryType>([](int64_t length, double null_probability,
- random::RandomArrayGenerator* rng) {
- const int32_t value_size = 16;
- int64_t data_nbytes = length * value_size;
- std::shared_ptr<Buffer> data = *AllocateBuffer(data_nbytes);
- random_bytes(data_nbytes, /*seed=*/0, data->mutable_data());
- auto validity = rng->Boolean(length, 1 - null_probability);
-
- // Assemble the data for a FixedSizeBinaryArray
- auto values_data = std::make_shared<ArrayData>(fixed_size_binary(value_size), length);
- values_data->buffers = {validity->data()->buffers[1], data};
- return MakeArray(values_data);
- });
-}
-
-TYPED_TEST(TestTakeKernelWithString, TakeDictionary) {
- auto dict = R"(["a", "b", "c", "d", "e"])";
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
- this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]");
- this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]");
-}
-
-class TestTakeKernelFSB : public TestTakeKernel<FixedSizeBinaryType> {
- public:
- std::shared_ptr<DataType> value_type() { return fixed_size_binary(3); }
-
- void AssertTake(const std::string& values, const std::string& indices,
- const std::string& expected) {
- CheckTake(value_type(), values, indices, expected);
- }
-};
-
-TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) {
- this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])");
- this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]");
- this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null, "bbb", "aaa"])");
-
- std::shared_ptr<DataType> type = this->value_type();
- std::shared_ptr<Array> arr;
- ASSERT_RAISES(IndexError,
- TakeJSON(type, R"(["aaa", "bbb", "ccc"])", int8(), "[0, 9, 0]", &arr));
- ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd", "eee"])",
- int64(), "[2, 5]", &arr));
-}
-
-class TestTakeKernelWithList : public TestTakeKernel<ListType> {};
-
-TEST_F(TestTakeKernelWithList, TakeListInt32) {
- std::string list_json = "[[], [1,2], null, [3]]";
- CheckTake(list(int32()), list_json, "[]", "[]");
- CheckTake(list(int32()), list_json, "[3, 2, 1]", "[[3], null, [1,2]]");
- CheckTake(list(int32()), list_json, "[null, 3, 0]", "[null, [3], []]");
- CheckTake(list(int32()), list_json, "[null, null]", "[null, null]");
- CheckTake(list(int32()), list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]");
- CheckTake(list(int32()), list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(list(int32()), list_json, "[0, 0, 0, 0, 0, 0, 1]",
- "[[], [], [], [], [], [], [1, 2]]");
-}
-
-TEST_F(TestTakeKernelWithList, TakeListListInt32) {
- std::string list_json = R"([
- [],
- [[1], [2, null, 2], []],
- null,
- [[3, null], null]
- ])";
- auto type = list(list(int32()));
- CheckTake(type, list_json, "[]", "[]");
- CheckTake(type, list_json, "[3, 2, 1]", R"([
- [[3, null], null],
- null,
- [[1], [2, null, 2], []]
- ])");
- CheckTake(type, list_json, "[null, 3, 0]", R"([
- null,
- [[3, null], null],
- []
- ])");
- CheckTake(type, list_json, "[null, null]", "[null, null]");
- CheckTake(type, list_json, "[3, 0, 0, 3]",
- "[[[3, null], null], [], [], [[3, null], null]]");
- CheckTake(type, list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]",
- "[[], [], [], [], [], [], [[1], [2, null, 2], []]]");
-}
-
-class TestTakeKernelWithLargeList : public TestTakeKernel<LargeListType> {};
-
-TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) {
- std::string list_json = "[[], [1,2], null, [3]]";
- CheckTake(large_list(int32()), list_json, "[]", "[]");
- CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]");
-}
-
-class TestTakeKernelWithFixedSizeList : public TestTakeKernel<FixedSizeListType> {};
-
-TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) {
- std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
- CheckTake(fixed_size_list(int32(), 3), list_json, "[]", "[]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]",
- "[[7, 8, null], [4, 5, 6], [1, null, 3]]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]",
- "[null, [4, 5, 6], null]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[null, null]", "[null, null]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]",
- "[[7, 8, null], null, null, [7, 8, null]]");
- CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json);
- CheckTake(
- fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]",
- "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]");
-}
-
-class TestTakeKernelWithMap : public TestTakeKernel<MapType> {};
-
-TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) {
- std::string map_json = R"([
- [["joe", 0], ["mark", null]],
- null,
- [["cap", 8]],
- []
- ])";
- CheckTake(map(utf8(), int32()), map_json, "[]", "[]");
- CheckTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]",
- "[[], null, [], null, []]");
- CheckTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([
- [["cap", 8]],
- null,
- null
- ])");
- CheckTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([
- [["cap", 8]],
- null,
- [["joe", 0], ["mark", null]]
- ])");
- CheckTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json);
- CheckTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([
- [["joe", 0], ["mark", null]],
- [["joe", 0], ["mark", null]],
- [["joe", 0], ["mark", null]],
- [["joe", 0], ["mark", null]],
- [["joe", 0], ["mark", null]],
- [["joe", 0], ["mark", null]],
- []
- ])");
-}
-
-class TestTakeKernelWithStruct : public TestTakeKernel<StructType> {};
-
-TEST_F(TestTakeKernelWithStruct, TakeStruct) {
- auto struct_type = struct_({field("a", int32()), field("b", utf8())});
- auto struct_json = R"([
- null,
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- {"a": 4, "b": "eh"}
- ])";
- CheckTake(struct_type, struct_json, "[]", "[]");
- CheckTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- {"a": 4, "b": "eh"}
- ])");
- CheckTake(struct_type, struct_json, "[3, 1, 0]", R"([
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- null
- ])");
- CheckTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json);
- CheckTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
- null,
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"}
- ])");
-}
-
-class TestTakeKernelWithUnion : public TestTakeKernel<UnionType> {};
-
-// TODO: Restore Union take functionality
-TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) {
- for (auto union_ : UnionTypeFactories()) {
- auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5});
- auto union_json = R"([
- null,
- [2, 222],
- [5, "hello"],
- [5, "eh"],
- null,
- [2, 111]
- ])";
- CheckTake(union_type, union_json, "[]", "[]");
- CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
- [5, "eh"],
- [2, 222],
- [5, "eh"],
- [2, 222],
- [5, "eh"]
- ])");
- CheckTake(union_type, union_json, "[4, 2, 1]", R"([
- null,
- [5, "hello"],
- [2, 222]
- ])");
- CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5]", union_json);
- CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
- null,
- [5, "hello"],
- [5, "hello"],
- [5, "hello"],
- [5, "hello"],
- [5, "hello"],
- [5, "hello"]
- ])");
- }
-}
-
-class TestPermutationsWithTake : public TestBase {
- protected:
- void DoTake(const Int16Array& values, const Int16Array& indices,
- std::shared_ptr<Int16Array>* out) {
- ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> boxed_out, Take(values, indices));
- ASSERT_OK(boxed_out->ValidateFull());
- *out = checked_pointer_cast<Int16Array>(std::move(boxed_out));
- }
-
- std::shared_ptr<Int16Array> DoTake(const Int16Array& values,
- const Int16Array& indices) {
- std::shared_ptr<Int16Array> out;
- DoTake(values, indices, &out);
- return out;
- }
-
- std::shared_ptr<Int16Array> DoTakeN(uint64_t n, std::shared_ptr<Int16Array> array) {
- auto power_of_2 = array;
- array = Identity(array->length());
- while (n != 0) {
- if (n & 1) {
- array = DoTake(*array, *power_of_2);
- }
- power_of_2 = DoTake(*power_of_2, *power_of_2);
- n >>= 1;
- }
- return array;
- }
-
- template <typename Rng>
- void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>* shuffled) {
- auto byte_length = array.length() * sizeof(int16_t);
- ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(0, byte_length));
- auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data());
- std::shuffle(mutable_data, mutable_data + array.length(), gen);
- shuffled->reset(new Int16Array(array.length(), data));
- }
-
- template <typename Rng>
- std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) {
- std::shared_ptr<Int16Array> out;
- Shuffle(array, gen, &out);
- return out;
- }
-
- void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) {
- Int16Builder identity_builder;
- ASSERT_OK(identity_builder.Resize(length));
- for (int16_t i = 0; i < length; ++i) {
- identity_builder.UnsafeAppend(i);
- }
- ASSERT_OK(identity_builder.Finish(identity));
- }
-
- std::shared_ptr<Int16Array> Identity(int64_t length) {
- std::shared_ptr<Int16Array> out;
- Identity(length, &out);
- return out;
- }
-
- std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>& permutation) {
- auto length = static_cast<int16_t>(permutation->length());
-
- std::vector<bool> cycle_lengths(length + 1, false);
- auto permutation_to_the_i = permutation;
- for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) {
- cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i);
- permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i);
- }
-
- uint64_t cycle_to_identity_length = 1;
- for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) {
- if (!cycle_lengths[cycle_length]) {
- continue;
- }
- if (cycle_to_identity_length % cycle_length == 0) {
- continue;
- }
- if (cycle_to_identity_length >
- std::numeric_limits<uint64_t>::max() / cycle_length) {
- // overflow, can't compute Inverse
- return nullptr;
- }
- cycle_to_identity_length *= cycle_length;
- }
-
- return DoTakeN(cycle_to_identity_length - 1, permutation);
- }
-
- bool HasTrivialCycle(const Int16Array& permutation) {
- for (int64_t i = 0; i < permutation.length(); ++i) {
- if (permutation.Value(i) == static_cast<int16_t>(i)) {
- return true;
- }
- }
- return false;
- }
-};
-
-TEST_F(TestPermutationsWithTake, InvertPermutation) {
- for (auto seed : std::vector<random::SeedType>({0, kRandomSeed, kRandomSeed * 2 - 1})) {
- std::default_random_engine gen(seed);
- for (int16_t length = 0; length < 1 << 10; ++length) {
- auto identity = Identity(length);
- auto permutation = Shuffle(*identity, gen);
- auto inverse = Inverse(permutation);
- if (inverse == nullptr) {
- break;
- }
- ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity));
- }
- }
-}
-
-class TestTakeKernelWithRecordBatch : public TestTakeKernel<RecordBatch> {
- public:
- void AssertTake(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
- const std::string& indices, const std::string& expected_batch) {
- std::shared_ptr<RecordBatch> actual;
-
- for (auto index_type : {int8(), uint32()}) {
- ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &actual));
- ASSERT_OK(actual->ValidateFull());
- ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual);
- }
- }
-
- Status TakeJSON(const std::shared_ptr<Schema>& schm, const std::string& batch_json,
- const std::shared_ptr<DataType>& index_type, const std::string& indices,
- std::shared_ptr<RecordBatch>* out) {
- auto batch = RecordBatchFromJSON(schm, batch_json);
- ARROW_ASSIGN_OR_RAISE(Datum result,
- Take(Datum(batch), Datum(ArrayFromJSON(index_type, indices))));
- *out = result.record_batch();
- return Status::OK();
- }
-};
-
-TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) {
- std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
- auto schm = schema(fields);
-
- auto struct_json = R"([
- {"a": null, "b": "yo"},
- {"a": 1, "b": ""},
- {"a": 2, "b": "hello"},
- {"a": 4, "b": "eh"}
- ])";
- this->AssertTake(schm, struct_json, "[]", "[]");
- this->AssertTake(schm, struct_json, "[3, 1, 3, 1, 3]", R"([
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- {"a": 4, "b": "eh"}
- ])");
- this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([
- {"a": 4, "b": "eh"},
- {"a": 1, "b": ""},
- {"a": null, "b": "yo"}
- ])");
- this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json);
- this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
- {"a": null, "b": "yo"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"},
- {"a": 2, "b": "hello"}
- ])");
-}
-
-class TestTakeKernelWithChunkedArray : public TestTakeKernel<ChunkedArray> {
- public:
- void AssertTake(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values, const std::string& indices,
- const std::vector<std::string>& expected) {
- std::shared_ptr<ChunkedArray> actual;
- ASSERT_OK(this->TakeWithArray(type, values, indices, &actual));
- ASSERT_OK(actual->ValidateFull());
- AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
- }
-
- void AssertChunkedTake(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- const std::vector<std::string>& expected) {
- std::shared_ptr<ChunkedArray> actual;
- ASSERT_OK(this->TakeWithChunkedArray(type, values, indices, &actual));
- ASSERT_OK(actual->ValidateFull());
- AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual);
- }
-
- Status TakeWithArray(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values, const std::string& indices,
- std::shared_ptr<ChunkedArray>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
- ArrayFromJSON(int8(), indices)));
- *out = result.chunked_array();
- return Status::OK();
- }
-
- Status TakeWithChunkedArray(const std::shared_ptr<DataType>& type,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- std::shared_ptr<ChunkedArray>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values),
- ChunkedArrayFromJSON(int8(), indices)));
- *out = result.chunked_array();
- return Status::OK();
- }
-};
-
-TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) {
- this->AssertTake(int8(), {"[]"}, "[]", {"[]"});
- this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"});
-
- this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"});
- this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"},
- {"[7, 8, 7]", "[]", "[9]"});
- this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"});
-
- std::shared_ptr<ChunkedArray> arr;
- ASSERT_RAISES(IndexError,
- this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]", &arr));
- ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8, 9]"},
- {"[0, 1, 0]", "[5, 1]"}, &arr));
-}
-
-class TestTakeKernelWithTable : public TestTakeKernel<Table> {
- public:
- void AssertTake(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& table_json, const std::string& filter,
- const std::vector<std::string>& expected_table) {
- std::shared_ptr<Table> actual;
-
- ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual));
- ASSERT_OK(actual->ValidateFull());
- ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
- }
-
- void AssertChunkedTake(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& table_json,
- const std::vector<std::string>& filter,
- const std::vector<std::string>& expected_table) {
- std::shared_ptr<Table> actual;
-
- ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual));
- ASSERT_OK(actual->ValidateFull());
- ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual);
- }
-
- Status TakeWithArray(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& values, const std::string& indices,
- std::shared_ptr<Table>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)),
- Datum(ArrayFromJSON(int8(), indices))));
- *out = result.table();
- return Status::OK();
- }
-
- Status TakeWithChunkedArray(const std::shared_ptr<Schema>& schm,
- const std::vector<std::string>& values,
- const std::vector<std::string>& indices,
- std::shared_ptr<Table>* out) {
- ARROW_ASSIGN_OR_RAISE(Datum result,
- Take(Datum(TableFromJSON(schm, values)),
- Datum(ChunkedArrayFromJSON(int8(), indices))));
- *out = result.table();
- return Status::OK();
- }
-};
-
-TEST_F(TestTakeKernelWithTable, TakeTable) {
- std::vector<std::shared_ptr<Field>> fields = {field("a", int32()), field("b", utf8())};
- auto schm = schema(fields);
-
- std::vector<std::string> table_json = {
- "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]",
- "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"};
-
- this->AssertTake(schm, table_json, "[]", {"[]"});
- std::vector<std::string> expected_310 = {
- "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"};
- this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310);
- this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json);
-}
-
-} // namespace compute
-} // namespace arrow
diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc
index ebae60a..416a391 100644
--- a/cpp/src/arrow/compute/registry.cc
+++ b/cpp/src/arrow/compute/registry.cc
@@ -110,10 +110,9 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
RegisterScalarAggregateBasic(registry.get());
// Vector functions
- RegisterVectorFilter(registry.get());
RegisterVectorHash(registry.get());
+ RegisterVectorSelection(registry.get());
RegisterVectorSort(registry.get());
- RegisterVectorTake(registry.get());
return registry;
}
diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h
index 515b17b..93598f7 100644
--- a/cpp/src/arrow/compute/registry_internal.h
+++ b/cpp/src/arrow/compute/registry_internal.h
@@ -34,10 +34,9 @@ void RegisterScalarStringAscii(FunctionRegistry* registry);
void RegisterScalarValidity(FunctionRegistry* registry);
// Vector functions
-void RegisterVectorFilter(FunctionRegistry* registry);
void RegisterVectorHash(FunctionRegistry* registry);
+void RegisterVectorSelection(FunctionRegistry* registry);
void RegisterVectorSort(FunctionRegistry* registry);
-void RegisterVectorTake(FunctionRegistry* registry);
// Aggregate functions
void RegisterScalarAggregateBasic(FunctionRegistry* registry);
diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc
index ed645d7..9a46fc2 100644
--- a/cpp/src/arrow/dataset/filter.cc
+++ b/cpp/src/arrow/dataset/filter.cc
@@ -1228,7 +1228,8 @@ Result<std::shared_ptr<RecordBatch>> TreeEvaluator::Filter(
auto selection_array = selection.make_array();
compute::ExecContext ctx(pool);
ARROW_ASSIGN_OR_RAISE(Datum filtered,
- compute::Filter(batch, selection_array, {}, &ctx));
+ compute::Filter(batch, selection_array,
+ compute::FilterOptions::Defaults(), &ctx));
return filtered.record_batch();
}
diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc
index 6146cb8..6a1249d 100644
--- a/cpp/src/arrow/testing/random.cc
+++ b/cpp/src/arrow/testing/random.cc
@@ -76,7 +76,8 @@ struct GenerateOptions {
double probability_;
};
-std::shared_ptr<Array> RandomArrayGenerator::Boolean(int64_t size, double probability,
+std::shared_ptr<Array> RandomArrayGenerator::Boolean(int64_t size,
+ double true_probability,
double null_probability) {
// The boolean generator does not care about the value distribution since it
// only calls the GenerateBitmap method.
@@ -84,7 +85,13 @@ std::shared_ptr<Array> RandomArrayGenerator::Boolean(int64_t size, double probab
BufferVector buffers{2};
// Need 2 distinct generators such that probabilities are not shared.
- GenOpt value_gen(seed(), 0, 1, probability);
+
+ // The "GenerateBitmap" function is written to generate validity bitmaps
+ // parameterized by the null probability, which is the probability of 0. For
+ // boolean data, the true probability is the probability of 1, so to use
+ // GenerateBitmap we must provide the probability of false instead.
+ GenOpt value_gen(seed(), 0, 1, 1 - true_probability);
+
GenOpt null_gen(seed(), 0, 1, null_probability);
int64_t null_count = 0;
diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h
index 0b4e7e3..7f2bf8e 100644
--- a/cpp/src/arrow/testing/random.h
+++ b/cpp/src/arrow/testing/random.h
@@ -45,11 +45,11 @@ class ARROW_EXPORT RandomArrayGenerator {
/// \brief Generates a random BooleanArray
///
/// \param[in] size the size of the array to generate
- /// \param[in] probability the estimated number of active bits
+ /// \param[in] true_probability the probability of a value being 1 / bit-set
/// \param[in] null_probability the probability of a row being null
///
/// \return a generated Array
- std::shared_ptr<Array> Boolean(int64_t size, double probability,
+ std::shared_ptr<Array> Boolean(int64_t size, double true_probability,
double null_probability = 0);
/// \brief Generates a random UInt8Array
diff --git a/cpp/src/arrow/util/bit_block_counter.cc b/cpp/src/arrow/util/bit_block_counter.cc
index 648b46b..a550ecc 100644
--- a/cpp/src/arrow/util/bit_block_counter.cc
+++ b/cpp/src/arrow/util/bit_block_counter.cc
@@ -125,11 +125,11 @@ OptionalBitBlockCounter::OptionalBitBlockCounter(
: OptionalBitBlockCounter(validity_bitmap ? validity_bitmap->data() : nullptr, offset,
length) {}
-BitBlockCount BinaryBitBlockCounter::NextAndWord() {
+template <template <typename T> class Op>
+BitBlockCount BinaryBitBlockCounter::NextWord() {
if (!bits_remaining_) {
return {0, 0};
}
-
// When the offset is > 0, we need there to be a word beyond the last aligned
// word in the bitmap for the bit shifting logic.
const int64_t bits_required_to_use_words =
@@ -139,8 +139,8 @@ BitBlockCount BinaryBitBlockCounter::NextAndWord() {
const int16_t run_length = static_cast<int16_t>(std::min(bits_remaining_, kWordBits));
int16_t popcount = 0;
for (int64_t i = 0; i < run_length; ++i) {
- if (BitUtil::GetBit(left_bitmap_, left_offset_ + i) &&
- BitUtil::GetBit(right_bitmap_, right_offset_ + i)) {
+ if (Op<bool>::Call(BitUtil::GetBit(left_bitmap_, left_offset_ + i),
+ BitUtil::GetBit(right_bitmap_, right_offset_ + i))) {
++popcount;
}
}
@@ -154,13 +154,14 @@ BitBlockCount BinaryBitBlockCounter::NextAndWord() {
int64_t popcount = 0;
if (left_offset_ == 0 && right_offset_ == 0) {
- popcount = BitUtil::PopCount(LoadWord(left_bitmap_) & LoadWord(right_bitmap_));
+ popcount = BitUtil::PopCount(
+ Op<uint64_t>::Call(LoadWord(left_bitmap_), LoadWord(right_bitmap_)));
} else {
auto left_word =
ShiftWord(LoadWord(left_bitmap_), LoadWord(left_bitmap_ + 8), left_offset_);
auto right_word =
ShiftWord(LoadWord(right_bitmap_), LoadWord(right_bitmap_ + 8), right_offset_);
- popcount = BitUtil::PopCount(left_word & right_word);
+ popcount = BitUtil::PopCount(Op<uint64_t>::Call(left_word, right_word));
}
left_bitmap_ += kWordBits / 8;
right_bitmap_ += kWordBits / 8;
@@ -168,5 +169,17 @@ BitBlockCount BinaryBitBlockCounter::NextAndWord() {
return {64, static_cast<int16_t>(popcount)};
}
+// The public Next*Word methods differ only in how the two bitmaps are
+// combined: each one instantiates the private NextWord template with the
+// matching functor from namespace detail.
+
+// Popcount of "left & right" over the next run of bits.
+BitBlockCount BinaryBitBlockCounter::NextAndWord() {
+ return NextWord<detail::BitBlockAnd>();
+}
+
+// Popcount of "left | right" over the next run of bits.
+BitBlockCount BinaryBitBlockCounter::NextOrWord() {
+ return NextWord<detail::BitBlockOr>();
+}
+
+// Popcount of "left | ~right" over the next run of bits.
+BitBlockCount BinaryBitBlockCounter::NextOrNotWord() {
+ return NextWord<detail::BitBlockOrNot>();
+}
+
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/util/bit_block_counter.h b/cpp/src/arrow/util/bit_block_counter.h
index d6318e7..3ee777f 100644
--- a/cpp/src/arrow/util/bit_block_counter.h
+++ b/cpp/src/arrow/util/bit_block_counter.h
@@ -33,11 +33,50 @@ class Buffer;
namespace internal {
+namespace detail {
+
+// Bitwise functors that BinaryBitBlockCounter::NextWord is templated on.
+// Each is used with uint64_t to combine whole 64-bit words and with bool to
+// combine single bits (in the counter's bit-at-a-time path and in the unit
+// tests, which compute the expected popcount bit by bit).
+
+// Computes "left & right".
+template <typename T>
+struct BitBlockAnd {
+ static T Call(T left, T right) { return left & right; }
+};
+
+// Single-bit form, written with the logical operator.
+template <>
+struct BitBlockAnd<bool> {
+ static bool Call(bool left, bool right) { return left && right; }
+};
+
+// Computes "left | right".
+template <typename T>
+struct BitBlockOr {
+ static T Call(T left, T right) { return left | right; }
+};
+
+// Single-bit form, written with the logical operator.
+template <>
+struct BitBlockOr<bool> {
+ static bool Call(bool left, bool right) { return left || right; }
+};
+
+// Computes "left | ~right".
+template <typename T>
+struct BitBlockOrNot {
+ static T Call(T left, T right) { return left | ~right; }
+};
+
+// This specialization is needed for correctness, not just style: `~` first
+// promotes a bool to int (~false == -1, ~true == -2), so the generic form
+// would be nonzero, i.e. true, for every pair of inputs.
+template <>
+struct BitBlockOrNot<bool> {
+ static bool Call(bool left, bool right) { return left || !right; }
+};
+
+} // namespace detail
+
/// \brief Return value from bit block counters: the total number of bits and
/// the number of set bits.
struct BitBlockCount {
int16_t length;
int16_t popcount;
+
+ /// \brief True if no bit in the block is set. Also true for an empty
+ /// (length == 0) block.
+ bool NoneSet() const { return this->popcount == 0; }
+ /// \brief True if every bit in the block is set. Also true for an empty
+ /// (length == 0) block, so check for the end of the data before relying
+ /// on this.
+ bool AllSet() const { return this->length == this->popcount; }
};
/// \brief A class that scans through a true/false bitmap to compute popcounts
@@ -104,6 +143,22 @@ class ARROW_EXPORT OptionalBitBlockCounter {
}
}
+ // Like NextBlock, but returns a word-sized block even when there is no
+ // validity bitmap (without a bitmap, NextBlock instead returns blocks of up
+ // to INT16_MAX bits).
+ //
+ // With a bitmap, this forwards to BitBlockCounter::NextWord. Without one,
+ // every position counts as valid: the block spans min(64, length_ -
+ // position_) bits and its popcount equals its length, so {0, 0} is returned
+ // once the whole range has been consumed.
+ BitBlockCount NextWord() {
+ static constexpr int64_t kWordSize = 64;
+ if (has_bitmap_) {
+ BitBlockCount block = counter_.NextWord();
+ position_ += block.length;
+ return block;
+ } else {
+ int16_t block_size = static_cast<int16_t>(std::min(kWordSize, length_ - position_));
+ position_ += block_size;
+ // All values are non-null
+ return {block_size, block_size};
+ }
+ }
+
private:
BitBlockCounter counter_;
int64_t position_;
@@ -132,7 +187,16 @@ class ARROW_EXPORT BinaryBitBlockCounter {
/// blocks in subsequent invocations.
BitBlockCount NextAndWord();
+ /// \brief Computes "x | y" block for each available run of bits.
+ BitBlockCount NextOrWord();
+
+ /// \brief Computes "x | ~y" block for each available run of bits.
+ BitBlockCount NextOrNotWord();
+
private:
+ template <template <typename T> class Op>
+ BitBlockCount NextWord();
+
const uint8_t* left_bitmap_;
int64_t left_offset_;
const uint8_t* right_bitmap_;
diff --git a/cpp/src/arrow/util/bit_block_counter_test.cc b/cpp/src/arrow/util/bit_block_counter_test.cc
index a634bfe..f1a0f35 100644
--- a/cpp/src/arrow/util/bit_block_counter_test.cc
+++ b/cpp/src/arrow/util/bit_block_counter_test.cc
@@ -205,7 +205,8 @@ TEST_F(TestBitBlockCounter, FourWordsRandomData) {
}
}
-TEST(TestBinaryBitBlockCounter, NextAndWord) {
+template <template <typename T> class Op, typename NextWordFunc>
+void CheckBinaryBitBlockOp(NextWordFunc&& get_next_word) {
const int64_t nbytes = 1024;
auto left = *AllocateBuffer(nbytes);
auto right = *AllocateBuffer(nbytes);
@@ -218,12 +219,12 @@ TEST(TestBinaryBitBlockCounter, NextAndWord) {
overlap_length);
int64_t position = 0;
do {
- BitBlockCount block = counter.NextAndWord();
+ BitBlockCount block = get_next_word(&counter);
int expected_popcount = 0;
for (int j = 0; j < block.length; ++j) {
- expected_popcount +=
- static_cast<int>(BitUtil::GetBit(left->data(), position + left_offset + j) &&
- BitUtil::GetBit(right->data(), position + right_offset + j));
+ expected_popcount += static_cast<int>(
+ Op<bool>::Call(BitUtil::GetBit(left->data(), position + left_offset + j),
+ BitUtil::GetBit(right->data(), position + right_offset + j)));
}
ASSERT_EQ(block.popcount, expected_popcount);
position += block.length;
@@ -231,7 +232,7 @@ TEST(TestBinaryBitBlockCounter, NextAndWord) {
// We made it through all the data
ASSERT_EQ(position, overlap_length);
- BitBlockCount block = counter.NextAndWord();
+ BitBlockCount block = get_next_word(&counter);
ASSERT_EQ(block.length, 0);
ASSERT_EQ(block.popcount, 0);
};
@@ -243,8 +244,23 @@ TEST(TestBinaryBitBlockCounter, NextAndWord) {
}
}
-TEST(TestOptionalBitBlockCounter, Basics) {
- const int64_t nbytes = 1024;
+// Each test pairs a BinaryBitBlockCounter method with the detail:: functor
+// that CheckBinaryBitBlockOp applies bit by bit to compute the expected
+// popcount of every returned block.
+TEST(TestBinaryBitBlockCounter, NextAndWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockAnd>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextAndWord(); });
+}
+
+TEST(TestBinaryBitBlockCounter, NextOrWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockOr>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextOrWord(); });
+}
+
+TEST(TestBinaryBitBlockCounter, NextOrNotWord) {
+ CheckBinaryBitBlockOp<detail::BitBlockOrNot>(
+ [](BinaryBitBlockCounter* counter) { return counter->NextOrNotWord(); });
+}
+
+TEST(TestOptionalBitBlockCounter, NextBlock) {
+ const int64_t nbytes = 5000;
auto bitmap = *AllocateBitmap(nbytes * 8);
random_bytes(nbytes, 0, bitmap->mutable_data());
@@ -264,6 +280,44 @@ TEST(TestOptionalBitBlockCounter, Basics) {
BitBlockCount optional_block = optional_counter.NextBlock();
ASSERT_EQ(optional_block.length, 0);
ASSERT_EQ(optional_block.popcount, 0);
+
+ OptionalBitBlockCounter optional_counter_no_bitmap(nullptr, 0, nbytes * 8);
+ BitBlockCount no_bitmap_block = optional_counter_no_bitmap.NextBlock();
+
+ int16_t max_length = std::numeric_limits<int16_t>::max();
+ ASSERT_EQ(no_bitmap_block.length, max_length);
+ ASSERT_EQ(no_bitmap_block.popcount, max_length);
+ no_bitmap_block = optional_counter_no_bitmap.NextBlock();
+ ASSERT_EQ(no_bitmap_block.length, nbytes * 8 - max_length);
+ ASSERT_EQ(no_bitmap_block.popcount, no_bitmap_block.length);
+}
+
+// OptionalBitBlockCounter::NextWord must agree with BitBlockCounter::NextWord
+// when a bitmap is supplied; without a bitmap it must produce blocks of the
+// same lengths with every bit counted as set.
+TEST(TestOptionalBitBlockCounter, NextWord) {
+ const int64_t nbytes = 5000;
+ auto bitmap = *AllocateBitmap(nbytes * 8);
+ random_bytes(nbytes, 0, bitmap->mutable_data());
+
+ OptionalBitBlockCounter optional_counter(bitmap, 0, nbytes * 8);
+ OptionalBitBlockCounter optional_counter_no_bitmap(nullptr, 0, nbytes * 8);
+ BitBlockCounter bit_counter(bitmap->data(), 0, nbytes * 8);
+
+ while (true) {
+ BitBlockCount block = bit_counter.NextWord();
+ BitBlockCount no_bitmap_block = optional_counter_no_bitmap.NextWord();
+ BitBlockCount optional_block = optional_counter.NextWord();
+ ASSERT_EQ(optional_block.length, block.length);
+ ASSERT_EQ(optional_block.popcount, block.popcount);
+
+ ASSERT_EQ(no_bitmap_block.length, block.length);
+ ASSERT_EQ(no_bitmap_block.popcount, block.length);
+ // A zero-length block marks the end of the data for all three counters.
+ if (block.length == 0) {
+ break;
+ }
+ }
+
+ // An exhausted counter keeps returning empty blocks.
+ BitBlockCount optional_block = optional_counter.NextWord();
+ ASSERT_EQ(optional_block.length, 0);
+ ASSERT_EQ(optional_block.popcount, 0);
}
} // namespace internal
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 8a8abfd..d8084f0 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1568,6 +1568,8 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
cdef cppclass CFilterOptions \
" arrow::compute::FilterOptions"(CFunctionOptions):
+ CFilterOptions()
+ CFilterOptions(CFilterNullSelectionBehavior null_selection)
CFilterNullSelectionBehavior null_selection_behavior
cdef cppclass CTakeOptions \
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index fd262e1..a1607e8 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -158,7 +158,7 @@ def test_filter(ty, values):
# non-boolean dtype
mask = pa.array([0, 1, 0, 1, 0])
- with pytest.raises(NotImplementedError, match="no kernel matching"):
+ with pytest.raises(NotImplementedError):
arr.filter(mask)
# wrong length
@@ -229,8 +229,7 @@ def test_filter_errors():
for obj in [arr, batch, table]:
# non-boolean dtype
mask = pa.array([0, 1, 0, 1, 0])
- with pytest.raises(NotImplementedError,
- match="no kernel matching input types"):
+ with pytest.raises(NotImplementedError):
obj.filter(mask)
# wrong length