You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by em...@apache.org on 2019/06/06 00:56:51 UTC
[arrow] branch master updated: ARROW-4990: [C++] Support
Array-Array comparison
This is an automated email from the ASF dual-hosted git repository.
emkornfield pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new a2ef7d9 ARROW-4990: [C++] Support Array-Array comparison
a2ef7d9 is described below
commit a2ef7d9ab4cda9980a2602625c593a246e282847
Author: François Saint-Jacques <fs...@gmail.com>
AuthorDate: Wed Jun 5 17:55:31 2019 -0700
ARROW-4990: [C++] Support Array-Array comparison
Comparison was previously only supported when the left argument is an array and the right argument is a scalar. This extends support to comparing two arrays, and also to the case where the left argument is a scalar and the right argument is an array.
Author: François Saint-Jacques <fs...@gmail.com>
Closes #4398 from fsaintjacques/ARROW-4990-compare-array-array and squashes the following commits:
864c67966 <François Saint-Jacques> Conform style guide
1402a98ee <François Saint-Jacques> Address review
b498eddc9 <François Saint-Jacques> mvcc again
c427313e8 <François Saint-Jacques> Make mvcc happy
56e467725 <François Saint-Jacques> autoformat
f6f5274fb <François Saint-Jacques> Supports comparison of Arrays
8643411d5 <François Saint-Jacques> Add binary operation support to PropagateNulls
53703acdd <François Saint-Jacques> Add length() to Datum.
---
cpp/src/arrow/compute/kernel.h | 18 +++
cpp/src/arrow/compute/kernels/compare.cc | 106 +++++++++---
cpp/src/arrow/compute/kernels/filter-benchmark.cc | 25 +++
cpp/src/arrow/compute/kernels/filter-test.cc | 187 ++++++++++++++++++++--
cpp/src/arrow/compute/kernels/filter.cc | 28 +++-
cpp/src/arrow/compute/kernels/filter.h | 31 +++-
cpp/src/arrow/compute/kernels/util-internal.cc | 20 +--
cpp/src/arrow/compute/kernels/util-internal.h | 12 ++
8 files changed, 374 insertions(+), 53 deletions(-)
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index aba659e..c8f2e94 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -64,6 +64,10 @@ struct Datum;
static inline bool CollectionEquals(const std::vector<Datum>& left,
const std::vector<Datum>& right);
+// Datum variants may have a length. This special value indicates that the
+// current variant does not have a length.
+constexpr int64_t kUnknownLength = -1;
+
/// \class Datum
/// \brief Variant type for various Arrow C++ data structures
struct ARROW_EXPORT Datum {
@@ -201,6 +205,20 @@ struct ARROW_EXPORT Datum {
return NULLPTR;
}
+  /// \brief The value length of the variant, if any
+  ///
+  /// \return the (chunked) array length, 1 for a scalar, else kUnknownLength
+  int64_t length() const {
+    if (this->kind() == Datum::ARRAY) {
+      return util::get<std::shared_ptr<ArrayData>>(this->value)->length;
+    } else if (this->kind() == Datum::CHUNKED_ARRAY) {
+      return util::get<std::shared_ptr<ChunkedArray>>(this->value)->length();
+    } else if (this->kind() == Datum::SCALAR) {
+      return 1;
+    }
+    return kUnknownLength;
+  }
+
bool Equals(const Datum& other) const {
if (this->kind() != other.kind()) return false;
diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc
index f27a449..040793f 100644
--- a/cpp/src/arrow/compute/kernels/compare.cc
+++ b/cpp/src/arrow/compute/kernels/compare.cc
@@ -32,17 +32,44 @@ class FunctionContext;
struct Datum;
template <typename ArrowType, CompareOperator Op,
- typename ArrayType = typename TypeTraits<ArrowType>::ArrayType,
typename ScalarType = typename TypeTraits<ArrowType>::ScalarType,
typename T = typename TypeTraits<ArrowType>::CType>
-static Status CompareArrayScalar(const ArrayData& input, const ScalarType& scalar,
- uint8_t* bitmap) {
+static Status CompareArrayScalar(const ArrayData& array, const ScalarType& scalar,
+ uint8_t* output_bitmap) {
+ const T* left = array.GetValues<T>(1);
const T right = scalar.value;
- const T* values = input.GetValues<T>(1);
- size_t i = 0;
- internal::GenerateBitsUnrolled(bitmap, 0, input.length, [values, right, &i]() -> bool {
- return Comparator<T, Op>::Compare(values[i++], right);
+ internal::GenerateBitsUnrolled(
+ output_bitmap, 0, array.length,
+ [&left, right]() -> bool { return Comparator<T, Op>::Compare(*left++, right); });
+
+ return Status::OK();
+}
+
+template <typename ArrowType, CompareOperator Op,
+ typename ScalarType = typename TypeTraits<ArrowType>::ScalarType,
+ typename T = typename TypeTraits<ArrowType>::CType>
+static Status CompareScalarArray(const ScalarType& scalar, const ArrayData& array,
+ uint8_t* output_bitmap) {
+ const T left = scalar.value;
+ const T* right = array.GetValues<T>(1);
+
+ internal::GenerateBitsUnrolled(
+ output_bitmap, 0, array.length,
+ [left, &right]() -> bool { return Comparator<T, Op>::Compare(left, *right++); });
+
+ return Status::OK();
+}
+
+template <typename ArrowType, CompareOperator Op,
+ typename T = typename TypeTraits<ArrowType>::CType>
+static Status CompareArrayArray(const ArrayData& lhs, const ArrayData& rhs,
+ uint8_t* output_bitmap) {
+ const T* left = lhs.GetValues<T>(1);
+ const T* right = rhs.GetValues<T>(1);
+
+ internal::GenerateBitsUnrolled(output_bitmap, 0, lhs.length, [&left, &right]() -> bool {
+ return Comparator<T, Op>::Compare(*left++, *right++);
});
return Status::OK();
@@ -50,31 +77,68 @@ static Status CompareArrayScalar(const ArrayData& input, const ScalarType& scala
template <typename ArrowType, CompareOperator Op>
class CompareFunction final : public FilterFunction {
- using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
public:
explicit CompareFunction(FunctionContext* ctx) : ctx_(ctx) {}
- Status Filter(const ArrayData& input, const Scalar& scalar, ArrayData* output) const {
+ Status Filter(const ArrayData& array, const Scalar& scalar, ArrayData* output) const {
// Caller must cast
- DCHECK(input.type->Equals(scalar.type));
+ DCHECK(array.type->Equals(scalar.type));
// Output must be a boolean array
DCHECK(output->type->Equals(boolean()));
// Output must be of same length
- DCHECK_EQ(output->length, input.length);
+ DCHECK_EQ(output->length, array.length);
// Scalar is null, all comparisons are null.
if (!scalar.is_valid) {
- return detail::SetAllNulls(ctx_, input, output);
+ return detail::SetAllNulls(ctx_, array, output);
}
// Copy null_bitmap
- RETURN_NOT_OK(detail::PropagateNulls(ctx_, input, output));
+ RETURN_NOT_OK(detail::PropagateNulls(ctx_, array, output));
uint8_t* bitmap_result = output->buffers[1]->mutable_data();
return CompareArrayScalar<ArrowType, Op>(
- input, static_cast<const ScalarType&>(scalar), bitmap_result);
+ array, static_cast<const ScalarType&>(scalar), bitmap_result);
+ }
+
+ Status Filter(const Scalar& scalar, const ArrayData& array, ArrayData* output) const {
+ // Caller must cast
+ DCHECK(array.type->Equals(scalar.type));
+ // Output must be a boolean array
+ DCHECK(output->type->Equals(boolean()));
+ // Output must be of same length
+ DCHECK_EQ(output->length, array.length);
+
+ // Scalar is null, all comparisons are null.
+ if (!scalar.is_valid) {
+ return detail::SetAllNulls(ctx_, array, output);
+ }
+
+ // Copy null_bitmap
+ RETURN_NOT_OK(detail::PropagateNulls(ctx_, array, output));
+
+ uint8_t* bitmap_result = output->buffers[1]->mutable_data();
+ return CompareScalarArray<ArrowType, Op>(static_cast<const ScalarType&>(scalar),
+ array, bitmap_result);
+ }
+
+ Status Filter(const ArrayData& lhs, const ArrayData& rhs, ArrayData* output) const {
+ // Caller must cast
+ DCHECK(lhs.type->Equals(rhs.type));
+ // Output must be a boolean array
+ DCHECK(output->type->Equals(boolean()));
+ // Inputs must be of same length
+ DCHECK_EQ(lhs.length, rhs.length);
+ // Output must be of same length as inputs
+ DCHECK_EQ(output->length, lhs.length);
+
+ // Copy null_bitmap
+ RETURN_NOT_OK(detail::AssignNullIntersection(ctx_, lhs, rhs, output));
+
+ uint8_t* bitmap_result = output->buffers[1]->mutable_data();
+ return CompareArrayArray<ArrowType, Op>(lhs, rhs, bitmap_result);
}
private:
@@ -152,13 +216,9 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right,
struct CompareOptions options, Datum* out) {
DCHECK(out);
- DCHECK_EQ(left.kind(), Datum::ARRAY);
- DCHECK_EQ(right.kind(), Datum::SCALAR);
- DCHECK(left.type()->Equals(right.type()));
-
- auto array = left.make_array();
- auto type = array->type();
-
+ auto type = left.type();
+ DCHECK(type->Equals(right.type()));
+ // Requires that both types are equal.
auto fn = MakeCompareFilterFunction(context, *type, options);
if (fn == nullptr) {
return Status::NotImplemented("Compare not implemented for type ", type->ToString());
@@ -166,7 +226,9 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right,
FilterBinaryKernel filter_kernel(fn);
detail::PrimitiveAllocatingBinaryKernel kernel(&filter_kernel);
- out->value = ArrayData::Make(filter_kernel.out_type(), array->length());
+
+ const int64_t length = FilterBinaryKernel::out_length(left, right);
+ out->value = ArrayData::Make(filter_kernel.out_type(), length);
return kernel.Call(context, left, right, out);
}
diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc
index 24e1841..00de199 100644
--- a/cpp/src/arrow/compute/kernels/filter-benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc
@@ -51,7 +51,32 @@ static void CompareArrayScalarKernel(benchmark::State& state) {
state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t));
}
+static void CompareArrayArrayKernel(benchmark::State& state) {
+ const int64_t memory_size = state.range(0) / 4;
+ const int64_t array_size = memory_size / sizeof(int64_t);
+ const double null_percent = static_cast<double>(state.range(1)) / 100.0;
+ auto rand = random::RandomArrayGenerator(0x94378165);
+ auto lhs = std::static_pointer_cast<NumericArray<Int64Type>>(
+ rand.Int64(array_size, -100, 100, null_percent));
+ auto rhs = std::static_pointer_cast<NumericArray<Int64Type>>(
+ rand.Int64(array_size, -100, 100, null_percent));
+
+ CompareOptions ge(GREATER_EQUAL);
+
+ FunctionContext ctx;
+ for (auto _ : state) {
+ Datum out;
+ ABORT_NOT_OK(Compare(&ctx, Datum(lhs), Datum(rhs), ge, &out));
+ benchmark::DoNotOptimize(out);
+ }
+
+ state.counters["size"] = static_cast<double>(memory_size);
+ state.counters["null_percent"] = static_cast<double>(state.range(1));
+ state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t) * 2);
+}
+
BENCHMARK(CompareArrayScalarKernel)->Apply(RegressionSetArgs);
+BENCHMARK(CompareArrayArrayKernel)->Apply(RegressionSetArgs);
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc
index 9c76b9d..1c8967c 100644
--- a/cpp/src/arrow/compute/kernels/filter-test.cc
+++ b/cpp/src/arrow/compute/kernels/filter-test.cc
@@ -73,6 +73,25 @@ static void ValidateCompare(FunctionContext* ctx, CompareOptions options,
ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected);
}
+template <typename ArrowType>
+static void ValidateCompare(FunctionContext* ctx, CompareOptions options,
+ const Datum& lhs, const char* rhs_str,
+ const char* expected_str) {
+ auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str);
+ auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str);
+ ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected);
+}
+
+template <typename ArrowType>
+static void ValidateCompare(FunctionContext* ctx, CompareOptions options,
+ const char* lhs_str, const char* rhs_str,
+ const char* expected_str) {
+ auto lhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), lhs_str);
+ auto rhs = ArrayFromJSON(TypeTraits<ArrowType>::type_singleton(), rhs_str);
+ auto expected = ArrayFromJSON(TypeTraits<BooleanType>::type_singleton(), expected_str);
+ ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected);
+}
+
template <typename T>
static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) {
switch (op) {
@@ -94,17 +113,20 @@ static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) {
}
template <typename ArrowType>
-static Datum SimpleCompare(CompareOptions options, const Datum& lhs, const Datum& rhs) {
+static Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
using T = typename TypeTraits<ArrowType>::CType;
- auto array = std::static_pointer_cast<ArrayType>(lhs.make_array());
- T value = std::static_pointer_cast<ScalarType>(rhs.scalar())->value;
+ bool swap = lhs.is_array();
+ auto array = std::static_pointer_cast<ArrayType>((swap ? lhs : rhs).make_array());
+ T value = std::static_pointer_cast<ScalarType>((swap ? rhs : lhs).scalar())->value;
std::vector<bool> bitmap(array->length());
for (int64_t i = 0; i < array->length(); i++) {
- bitmap[i] = SlowCompare<T>(options.op, array->Value(i), value);
+ bitmap[i] = swap ? SlowCompare<T>(options.op, array->Value(i), value)
+ : SlowCompare<T>(options.op, value, array->Value(i));
}
std::shared_ptr<Array> result;
@@ -115,8 +137,58 @@ static Datum SimpleCompare(CompareOptions options, const Datum& lhs, const Datum
std::vector<bool> null_bitmap(array->length());
auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(),
array->length());
- for (int64_t i = 0; i < array->length(); i++, reader.Next())
+ for (int64_t i = 0; i < array->length(); i++, reader.Next()) {
null_bitmap[i] = reader.IsSet();
+ }
+ ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
+ }
+
+ return Datum(result);
+}
+
+template <typename ArrowType,
+ typename ArrayType = typename TypeTraits<ArrowType>::ArrayType>
+static std::vector<bool> NullBitmapFromArrays(const ArrayType& lhs,
+ const ArrayType& rhs) {
+ auto left_lambda = [&lhs](int64_t i) {
+ return lhs.null_count() == 0 ? true : lhs.IsValid(i);
+ };
+
+ auto right_lambda = [&rhs](int64_t i) {
+ return rhs.null_count() == 0 ? true : rhs.IsValid(i);
+ };
+
+ const int64_t length = lhs.length();
+ std::vector<bool> null_bitmap(length);
+
+ for (int64_t i = 0; i < length; i++) {
+ null_bitmap[i] = left_lambda(i) && right_lambda(i);
+ }
+
+ return null_bitmap;
+}
+
+template <typename ArrowType>
+static Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs,
+ const Datum& rhs) {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using T = typename TypeTraits<ArrowType>::CType;
+
+ auto l_array = std::static_pointer_cast<ArrayType>(lhs.make_array());
+ auto r_array = std::static_pointer_cast<ArrayType>(rhs.make_array());
+ const int64_t length = l_array->length();
+
+ std::vector<bool> bitmap(length);
+ for (int64_t i = 0; i < length; i++) {
+ bitmap[i] = SlowCompare<T>(options.op, l_array->Value(i), r_array->Value(i));
+ }
+
+ std::shared_ptr<Array> result;
+
+ if (l_array->null_count() == 0 && r_array->null_count() == 0) {
+ ArrayFromVector<BooleanType>(bitmap, &result);
+ } else {
+ std::vector<bool> null_bitmap = NullBitmapFromArrays<ArrowType>(*l_array, *r_array);
ArrayFromVector<BooleanType>(null_bitmap, bitmap, &result);
}
@@ -127,7 +199,10 @@ template <typename ArrowType>
static void ValidateCompare(FunctionContext* ctx, CompareOptions options,
const Datum& lhs, const Datum& rhs) {
Datum result;
- Datum expected = SimpleCompare<ArrowType>(options, lhs, rhs);
+
+ bool has_scalar = lhs.is_scalar() || rhs.is_scalar();
+ Datum expected = has_scalar ? SimpleScalarArrayCompare<ArrowType>(options, lhs, rhs)
+ : SimpleArrayArrayCompare<ArrowType>(options, lhs, rhs);
ValidateCompare<ArrowType>(ctx, options, lhs, rhs, expected);
}
@@ -191,6 +266,61 @@ TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) {
ValidateCompare<TypeParam>(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]");
}
+TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) {
+ using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+
+ Datum one(std::make_shared<ScalarType>(CType(1)));
+
+ CompareOptions eq(CompareOperator::EQUAL);
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, one, "[null,0,1,1]", "[null,0,1,1]");
+
+ CompareOptions neq(CompareOperator::NOT_EQUAL);
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, neq, one, "[null,0,1,1]", "[null,1,0,0]");
+
+ CompareOptions gt(CompareOperator::GREATER);
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gt, one, "[null,0,1,1]", "[null,1,0,0]");
+
+ CompareOptions gte(CompareOperator::GREATER_EQUAL);
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]");
+ ValidateCompare<TypeParam>(&this->ctx_, gte, one, "[null,0,1,1]", "[null,1,1,1]");
+
+ CompareOptions lt(CompareOperator::LESS);
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lt, one, "[null,0,1,1]", "[null,0,0,0]");
+
+ CompareOptions lte(CompareOperator::LESS_EQUAL);
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[]", "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]");
+ ValidateCompare<TypeParam>(&this->ctx_, lte, one, "[null,0,1,1]", "[null,0,1,1]");
+}
+
TYPED_TEST(TestNumericCompareKernel, TestNullScalar) {
/* Ensure that null scalar broadcast to all null results. */
using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
@@ -201,8 +331,10 @@ TYPED_TEST(TestNumericCompareKernel, TestNullScalar) {
CompareOptions eq(CompareOperator::EQUAL);
ValidateCompare<TypeParam>(&this->ctx_, eq, "[]", null, "[]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, null, "[]", "[]");
ValidateCompare<TypeParam>(&this->ctx_, eq, "[null]", null, "[null]");
- ValidateCompare<TypeParam>(&this->ctx_, eq, "[1,2,3]", null, "[null, null, null]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, null, "[null]", "[null]");
+ ValidateCompare<TypeParam>(&this->ctx_, eq, null, "[1,2,3]", "[null, null, null]");
}
TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes);
@@ -213,12 +345,43 @@ TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) {
auto rand = random::RandomArrayGenerator(0x5416447);
for (size_t i = 3; i < 13; i++) {
for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
- for (auto length_adjust : {-2, -1, 0, 1, 2}) {
- int64_t length = (1UL << i) + length_adjust;
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
auto array = Datum(rand.Numeric<TypeParam>(length, 0, 100, null_probability));
- auto zero = Datum(std::make_shared<ScalarType>(CType(50)));
- auto options = CompareOptions(GREATER);
- ValidateCompare<TypeParam>(&this->ctx_, options, array, zero);
+ auto fifty = Datum(std::make_shared<ScalarType>(CType(50)));
+ auto options = CompareOptions(op);
+ ValidateCompare<TypeParam>(&this->ctx_, options, array, fifty);
+ ValidateCompare<TypeParam>(&this->ctx_, options, fifty, array);
+ }
+ }
+ }
+}
+
+TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) {
+  /* Element-wise array/array comparison; a null on either side yields null. */
+  CompareOptions eq(CompareOperator::EQUAL);
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[]", "[]", "[]");
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[null]", "[null]", "[null]");
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[1]", "[1]", "[1]");
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[1]", "[2]", "[0]");
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[null]", "[1]", "[null]");
+  ValidateCompare<TypeParam>(&this->ctx_, eq, "[1]", "[null]", "[null]");
+
+  CompareOptions lte(CompareOperator::LESS_EQUAL);
+  ValidateCompare<TypeParam>(&this->ctx_, lte, "[1,2,3,4,5]", "[2,3,4,5,6]",
+                             "[1,1,1,1,1]");
+}
+
+TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayArray) {
+ auto rand = random::RandomArrayGenerator(0x5416447);
+ for (size_t i = 3; i < 5; i++) {
+ for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
+ for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) {
+ const int64_t length = static_cast<int64_t>(1ULL << i);
+ auto lhs = Datum(rand.Numeric<TypeParam>(length << i, 0, 100, null_probability));
+ auto rhs = Datum(rand.Numeric<TypeParam>(length << i, 0, 100, null_probability));
+ auto options = CompareOptions(op);
+ ValidateCompare<TypeParam>(&this->ctx_, options, lhs, rhs);
}
}
}
diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc
index d7fbf54..1cbf0dc 100644
--- a/cpp/src/arrow/compute/kernels/filter.cc
+++ b/cpp/src/arrow/compute/kernels/filter.cc
@@ -19,6 +19,7 @@
#include "arrow/array.h"
#include "arrow/compute/kernel.h"
+#include "arrow/util/logging.h"
namespace arrow {
@@ -30,11 +31,28 @@ std::shared_ptr<DataType> FilterBinaryKernel::out_type() const {
Status FilterBinaryKernel::Call(FunctionContext* ctx, const Datum& left,
const Datum& right, Datum* out) {
-  auto array = left.array();
-  auto scalar = right.scalar();
-  auto result = out->array();
-
-  return filter_function_->Filter(*array, *scalar, result.get());
+  DCHECK(left.type()->Equals(right.type()));
+
+  auto lk = left.kind();
+  auto rk = right.kind();
+  // Single output handle shared by every (left, right) dispatch branch below.
+  auto out_array = out->array();
+
+  if (lk == Datum::ARRAY && rk == Datum::SCALAR) {
+    auto array = left.array();
+    auto scalar = right.scalar();
+    return filter_function_->Filter(*array, *scalar, &out_array);
+  } else if (lk == Datum::SCALAR && rk == Datum::ARRAY) {
+    auto scalar = left.scalar();
+    auto array = right.array();
+    return filter_function_->Filter(*scalar, *array, &out_array);
+  } else if (lk == Datum::ARRAY && rk == Datum::ARRAY) {
+    auto lhs = left.array();
+    auto rhs = right.array();
+    return filter_function_->Filter(*lhs, *rhs, &out_array);
+  }
+
+  return Status::Invalid("Invalid datum signature for FilterBinaryKernel");
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h
index becd2d5..3b28bc9 100644
--- a/cpp/src/arrow/compute/kernels/filter.h
+++ b/cpp/src/arrow/compute/kernels/filter.h
@@ -39,9 +39,31 @@ struct Datum;
class ARROW_EXPORT FilterFunction {
public:
/// Filter an array with a scalar argument.
- virtual Status Filter(const ArrayData& input, const Scalar& scalar,
+ virtual Status Filter(const ArrayData& array, const Scalar& scalar,
ArrayData* output) const = 0;
+ Status Filter(const ArrayData& array, const Scalar& scalar,
+ std::shared_ptr<ArrayData>* output) {
+ return Filter(array, scalar, output->get());
+ }
+
+ virtual Status Filter(const Scalar& scalar, const ArrayData& array,
+ ArrayData* output) const = 0;
+
+ Status Filter(const Scalar& scalar, const ArrayData& array,
+ std::shared_ptr<ArrayData>* output) {
+ return Filter(scalar, array, output->get());
+ }
+
+ /// Filter an array with an array argument.
+ virtual Status Filter(const ArrayData& lhs, const ArrayData& rhs,
+ ArrayData* output) const = 0;
+
+ Status Filter(const ArrayData& lhs, const ArrayData& rhs,
+ std::shared_ptr<ArrayData>* output) {
+ return Filter(lhs, rhs, output->get());
+ }
+
/// By default, FilterFunction emits a result bitmap.
virtual std::shared_ptr<DataType> out_type() const { return boolean(); }
@@ -57,6 +79,13 @@ class ARROW_EXPORT FilterBinaryKernel : public BinaryKernel {
Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
Datum* out) override;
+ static int64_t out_length(const Datum& left, const Datum& right) {
+ if (left.kind() == Datum::ARRAY) return left.length();
+ if (right.kind() == Datum::ARRAY) return right.length();
+
+ return 0;
+ }
+
std::shared_ptr<DataType> out_type() const override;
private:
diff --git a/cpp/src/arrow/compute/kernels/util-internal.cc b/cpp/src/arrow/compute/kernels/util-internal.cc
index 2f94407..f29badf 100644
--- a/cpp/src/arrow/compute/kernels/util-internal.cc
+++ b/cpp/src/arrow/compute/kernels/util-internal.cc
@@ -236,6 +236,11 @@ Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* o
return Status::OK();
}
+Status PropagateNulls(FunctionContext* ctx, const ArrayData& lhs, const ArrayData& rhs,
+ ArrayData* output) {
+ return AssignNullIntersection(ctx, lhs, rhs, output);
+}
+
Status SetAllNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output) {
const int64_t length = input.length;
if (output->buffers.size() == 0) {
@@ -282,16 +287,11 @@ Status AssignNullIntersection(FunctionContext* ctx, const ArrayData& left,
Status PrimitiveAllocatingUnaryKernel::Call(FunctionContext* ctx, const Datum& input,
Datum* out) {
- std::vector<std::shared_ptr<Buffer>> data_buffers;
- const ArrayData& in_data = *input.array();
-
DCHECK_EQ(out->kind(), Datum::ARRAY);
-
ArrayData* result = out->array().get();
-
result->buffers.resize(2);
- const int64_t length = in_data.length;
+ const int64_t length = input.length();
// Allocate the value buffer
RETURN_NOT_OK(AllocateValueBuffer(ctx, *out_type(), length, &(result->buffers[1])));
return delegate_->Call(ctx, input, out);
@@ -306,17 +306,11 @@ PrimitiveAllocatingBinaryKernel::PrimitiveAllocatingBinaryKernel(BinaryKernel* d
Status PrimitiveAllocatingBinaryKernel::Call(FunctionContext* ctx, const Datum& left,
const Datum& right, Datum* out) {
- std::vector<std::shared_ptr<Buffer>> data_buffers;
- DCHECK_EQ(left.kind(), Datum::ARRAY);
- const ArrayData& left_data = *left.array();
-
DCHECK_EQ(out->kind(), Datum::ARRAY);
-
ArrayData* result = out->array().get();
-
result->buffers.resize(2);
- const int64_t length = left_data.length;
+ const int64_t length = result->length;
RETURN_NOT_OK(AllocateValueBuffer(ctx, *out_type(), length, &(result->buffers[1])));
// Allocate the value buffer
diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h
index 25a670c..efd990f 100644
--- a/cpp/src/arrow/compute/kernels/util-internal.h
+++ b/cpp/src/arrow/compute/kernels/util-internal.h
@@ -72,6 +72,18 @@ Status InvokeBinaryArrayKernel(FunctionContext* ctx, BinaryKernel* kernel,
ARROW_EXPORT
Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output);
+/// \brief Assign validity bitmap to output, copying and computing the
+/// intersection bitmap if necessary, but zero-copy if possible, so that the
+/// same value slots are valid/not-null in the output (sliced arrays).
+///
+/// \param[in] ctx the kernel FunctionContext
+/// \param[in] left the left input array
+/// \param[in] right the right input array
+/// \param[out] output the output array. Must have length set correctly.
+ARROW_EXPORT
+Status PropagateNulls(FunctionContext* ctx, const ArrayData& left, const ArrayData& right,
+ ArrayData* output);
+
/// \brief Set validity bitmap in output with all null values.
///
/// \param[in] ctx the kernel FunctionContext