You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/28 14:01:32 UTC
[arrow] branch main updated: GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 6f3bd2524c GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)
6f3bd2524c is described below
commit 6f3bd2524c2abe3a4a278fc1c62fc5c49b56cab3
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Wed Jun 28 22:01:26 2023 +0800
GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)
### Rationale for this change
This is a follow-up to https://github.com/apache/arrow/pull/36058#pullrequestreview-1488682384. Currently, the kernels only try to cast the value set to the input type, not the other way around. This causes some valid input types to be rejected.
### What changes are included in this PR?
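During preparation, the kernels first try to cast value_set to the input type. If that cast fails and the input type is castable to the value_set type, the input is cast to the value_set type before the lookup happens. Implicit casts from non-binary types to string types are still rejected.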
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Some previously rejected input types are now accepted.
* Closes: #36203
Lead-authored-by: Jin Shang <sh...@gmail.com>
Co-authored-by: Antoine Pitrou <pi...@free.fr>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/compute/kernels/scalar_set_lookup.cc | 90 +++++++++++++++++-----
.../compute/kernels/scalar_set_lookup_test.cc | 87 +++++++++++++++++----
2 files changed, 143 insertions(+), 34 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
index 2d72e75619..803dfbde9c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -31,12 +31,16 @@ namespace arrow {
using internal::checked_cast;
using internal::HashTraits;
-namespace compute {
-namespace internal {
+namespace compute::internal {
namespace {
+// This base class enables non-templated access to the value set type
+struct SetLookupStateBase : public KernelState {
+ std::shared_ptr<DataType> value_set_type;
+};
+
template <typename Type>
-struct SetLookupState : public KernelState {
+struct SetLookupState : public SetLookupStateBase {
explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {}
Status Init(const SetLookupOptions& options) {
@@ -65,6 +69,7 @@ struct SetLookupState : public KernelState {
if (!options.skip_nulls && lookup_table->GetNull() >= 0) {
null_index = memo_index_to_value_index[lookup_table->GetNull()];
}
+ value_set_type = options.value_set.type();
return Status::OK();
}
@@ -115,11 +120,12 @@ struct SetLookupState : public KernelState {
};
template <>
-struct SetLookupState<NullType> : public KernelState {
+struct SetLookupState<NullType> : public SetLookupStateBase {
explicit SetLookupState(MemoryPool*) {}
Status Init(const SetLookupOptions& options) {
value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls;
+ value_set_type = null();
return Status::OK();
}
@@ -215,16 +221,31 @@ struct InitStateVisitor {
return Status::Invalid("Array type didn't match type of values set: ", *arg_type,
" vs ", *options.value_set.type());
}
+
if (!options.value_set.is_arraylike()) {
return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
} else if (!options.value_set.type()->Equals(*arg_type)) {
- ARROW_ASSIGN_OR_RAISE(
- options.value_set,
+ auto cast_result =
Cast(options.value_set, CastOptions::Safe(arg_type.GetSharedPtr()),
- ctx->exec_context()));
+ ctx->exec_context());
+ if (cast_result.ok()) {
+ options.value_set = *cast_result;
+ } else if (CanCast(*arg_type.type, *options.value_set.type())) {
+ // Avoid casting from non binary types to string like above
+ // Otherwise, will try to cast input array to value set type during kernel exec
+ if ((options.value_set.type()->id() == Type::STRING ||
+ options.value_set.type()->id() == Type::LARGE_STRING) &&
+ !is_base_binary_like(arg_type.id())) {
+ return Status::Invalid("Array type didn't match type of values set: ",
+ *arg_type, " vs ", *options.value_set.type());
+ }
+ } else {
+ return Status::Invalid("Array type doesn't match type of values set: ", *arg_type,
+ " vs ", *options.value_set.type());
+ }
}
- RETURN_NOT_OK(VisitTypeInline(*arg_type, this));
+ RETURN_NOT_OK(VisitTypeInline(*options.value_set.type(), this));
return std::move(result);
}
};
@@ -263,15 +284,12 @@ struct IndexInVisitor {
}
template <typename Type>
- Status ProcessIndexIn() {
+ Status ProcessIndexIn(const SetLookupState<Type>& state, const ArraySpan& input) {
using T = typename GetViewType<Type>::T;
-
- const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
-
FirstTimeBitmapWriter bitmap_writer(out_bitmap, out->offset, out->length);
int32_t* out_data = out->GetValues<int32_t>(1);
VisitArraySpanInline<Type>(
- data,
+ input,
[&](T v) {
int32_t index = state.lookup_table->Get(v);
if (index != -1) {
@@ -303,6 +321,19 @@ struct IndexInVisitor {
return Status::OK();
}
+ template <typename Type>
+ Status ProcessIndexIn() {
+ const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+ if (!data.type->Equals(state.value_set_type)) {
+ auto materialized_input = data.ToArrayData();
+ ARROW_ASSIGN_OR_RAISE(auto casted_input,
+ Cast(*materialized_input, state.value_set_type,
+ CastOptions::Safe(), ctx->exec_context()));
+ return ProcessIndexIn(state, *casted_input.array());
+ }
+ return ProcessIndexIn(state, data);
+ }
+
template <typename Type>
enable_if_boolean<Type, Status> Visit(const Type&) {
return ProcessIndexIn<BooleanType>();
@@ -331,7 +362,10 @@ struct IndexInVisitor {
return ProcessIndexIn<MonthDayNanoIntervalType>();
}
- Status Execute() { return VisitTypeInline(*data.type, this); }
+ Status Execute() {
+ const auto& state = checked_cast<const SetLookupStateBase&>(*ctx->state());
+ return VisitTypeInline(*state.value_set_type, this);
+ }
};
Status ExecIndexIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
@@ -359,13 +393,11 @@ struct IsInVisitor {
}
template <typename Type>
- Status ProcessIsIn() {
+ Status ProcessIsIn(const SetLookupState<Type>& state, const ArraySpan& input) {
using T = typename GetViewType<Type>::T;
- const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
-
FirstTimeBitmapWriter writer(out->buffers[1].data, out->offset, out->length);
VisitArraySpanInline<Type>(
- this->data,
+ input,
[&](T v) {
if (state.lookup_table->Get(v) != -1) {
writer.Set();
@@ -386,6 +418,20 @@ struct IsInVisitor {
return Status::OK();
}
+ template <typename Type>
+ Status ProcessIsIn() {
+ const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+
+ if (!data.type->Equals(state.value_set_type)) {
+ auto materialized_data = data.ToArrayData();
+ ARROW_ASSIGN_OR_RAISE(auto casted_data,
+ Cast(*materialized_data, state.value_set_type,
+ CastOptions::Safe(), ctx->exec_context()));
+ return ProcessIsIn(state, *casted_data.array());
+ }
+ return ProcessIsIn(state, data);
+ }
+
template <typename Type>
enable_if_boolean<Type, Status> Visit(const Type&) {
return ProcessIsIn<BooleanType>();
@@ -413,7 +459,10 @@ struct IsInVisitor {
return ProcessIsIn<MonthDayNanoIntervalType>();
}
- Status Execute() { return VisitTypeInline(*data.type, this); }
+ Status Execute() {
+ const auto& state = checked_cast<const SetLookupStateBase&>(*ctx->state());
+ return VisitTypeInline(*state.value_set_type, this);
+ }
};
Status ExecIsIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
@@ -566,6 +615,5 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) {
}
}
-} // namespace internal
-} // namespace compute
+} // namespace compute::internal
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
index ba724d3549..762a4bfe5c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -126,25 +126,40 @@ TEST_F(TestIsInKernel, ImplicitlyCastValueSet) {
"true, false, true, false]"));
AssertArraysEqual(*expected, *out.make_array());
- // fails; value_set cannot be cast to int8
- opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")};
- ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts));
+ // value_set cannot be casted to int8, but int8 is castable to float
+ CheckIsIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
+ "[false, true, false, false, false, true, false, false, false]");
// Allow implicit casts between binary types...
- CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
- "[true, true, false, false, true]");
+ "[true, false, false, false, true]");
+ CheckIsIn(ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(binary(), R"(["aa", "bbb"])"),
+ "[false, true, false, false, true]");
CheckIsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"),
"[true, true, false, false, true]");
+ CheckIsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"),
+ "[true, true, false, false, true]");
+
// But explicitly deny implicit casts from non-binary to utf8 to
// avoid surprises
ASSERT_RAISES(Invalid,
IsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+ ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+ SetLookupOptions(ArrayFromJSON(
+ utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+
ASSERT_RAISES(Invalid,
IsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+ ASSERT_RAISES(Invalid,
+ IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+ SetLookupOptions(ArrayFromJSON(
+ large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
}
template <typename Type>
@@ -253,11 +268,12 @@ TEST_F(TestIsInKernel, TimeDuration) {
"[true, false, false, true, true]", /*skip_nulls=*/true);
}
- // Different units, invalid cast
- ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
- ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
+ // Different units, cast value_set to values will fail, then cast values to value_set
+ CheckIsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
+ ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
+ "[false, false, true]");
- // Different units, valid cast
+ // Different units, cast value_set to values
CheckIsIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[true, false, true]");
}
@@ -779,11 +795,12 @@ TEST_F(TestIndexInKernel, TimeDuration) {
CheckIndexIn(duration(TimeUnit::SECOND), "[null, null, null, null]", "[null]",
"[0, 0, 0, 0]");
- // Different units, invalid cast
- ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
- ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
+ // Different units, cast value_set to values will fail, then cast values to value_set
+ CheckIndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
+ ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
+ "[null, null, 2]");
- // Different units, valid cast
+ // Different units, cast value_set to values
CheckIndexIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[0, null, 1]");
}
@@ -822,6 +839,50 @@ TEST_F(TestIndexInKernel, Boolean) {
CheckIndexIn(boolean(), "[null, null, null, null]", "[null]", "[0, 0, 0, 0]");
}
+TEST_F(TestIndexInKernel, ImplicitlyCastValueSet) {
+ auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]");
+
+ SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")};
+ ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("index_in", {input}, &opts));
+
+ auto expected = ArrayFromJSON(int32(), ("[null, null, 0, 1, null,"
+ "2, null, 3, null]"));
+ AssertArraysEqual(*expected, *out.make_array());
+
+ // Although value_set cannot be cast to int8, but int8 is castable to float
+ CheckIndexIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
+ "[null, 0, null, null, null, 3, null, null, null]");
+
+ // Allow implicit casts between binary types...
+ CheckIndexIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
+ "[0, null, null, null, 1]");
+ CheckIndexIn(
+ ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(binary(), R"(["aa", "bbb"])"), "[null, 1, null, null, 1]");
+ CheckIndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
+ CheckIndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
+ // But explicitly deny implicit casts from non-binary to utf8 to
+ // avoid surprises
+ ASSERT_RAISES(Invalid,
+ IndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+ ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+ SetLookupOptions(ArrayFromJSON(
+ utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+
+ ASSERT_RAISES(
+ Invalid,
+ IndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+ SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+ ASSERT_RAISES(Invalid,
+ IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+ SetLookupOptions(ArrayFromJSON(
+ large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+}
+
template <typename Type>
class TestIndexInKernelBinary : public TestIndexInKernel {};