You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/28 14:01:32 UTC

[arrow] branch main updated: GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f3bd2524c GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)
6f3bd2524c is described below

commit 6f3bd2524c2abe3a4a278fc1c62fc5c49b56cab3
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Wed Jun 28 22:01:26 2023 +0800

    GH-36203: [C++] Support casting in both ways for is_in and index_in (#36204)
    
    
    
    ### Rationale for this change
    
    This is a follow-up to https://github.com/apache/arrow/pull/36058#pullrequestreview-1488682384. Currently the kernels only try to cast the value set to the input type, not the other way around. This causes some valid input types to be rejected.
    
    ### What changes are included in this PR?
    
    The kernels will first try to cast value_set to the input type during preparation. If that fails, they will instead try to cast the input to the value_set type before the lookup happens.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    Some previously rejected input types will now be valid.
    
    * Closes: #36203
    
    Lead-authored-by: Jin Shang <sh...@gmail.com>
    Co-authored-by: Antoine Pitrou <pi...@free.fr>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/kernels/scalar_set_lookup.cc | 90 +++++++++++++++++-----
 .../compute/kernels/scalar_set_lookup_test.cc      | 87 +++++++++++++++++----
 2 files changed, 143 insertions(+), 34 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
index 2d72e75619..803dfbde9c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -31,12 +31,16 @@ namespace arrow {
 using internal::checked_cast;
 using internal::HashTraits;
 
-namespace compute {
-namespace internal {
+namespace compute::internal {
 namespace {
 
+// Non-templated base: lets Execute() read the value set type without knowing Type
+struct SetLookupStateBase : public KernelState {
+  std::shared_ptr<DataType> value_set_type;  // set by Init(); may differ from input type
+};
+
 template <typename Type>
-struct SetLookupState : public KernelState {
+struct SetLookupState : public SetLookupStateBase {
   explicit SetLookupState(MemoryPool* pool) : memory_pool(pool) {}
 
   Status Init(const SetLookupOptions& options) {
@@ -65,6 +69,7 @@ struct SetLookupState : public KernelState {
     if (!options.skip_nulls && lookup_table->GetNull() >= 0) {
       null_index = memo_index_to_value_index[lookup_table->GetNull()];
     }
+    value_set_type = options.value_set.type();
     return Status::OK();
   }
 
@@ -115,11 +120,12 @@ struct SetLookupState : public KernelState {
 };
 
 template <>
-struct SetLookupState<NullType> : public KernelState {
+struct SetLookupState<NullType> : public SetLookupStateBase {
   explicit SetLookupState(MemoryPool*) {}
 
   Status Init(const SetLookupOptions& options) {
     value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls;
+    value_set_type = null();
     return Status::OK();
   }
 
@@ -215,16 +221,31 @@ struct InitStateVisitor {
       return Status::Invalid("Array type didn't match type of values set: ", *arg_type,
                              " vs ", *options.value_set.type());
     }
+
     if (!options.value_set.is_arraylike()) {
       return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
     } else if (!options.value_set.type()->Equals(*arg_type)) {
-      ARROW_ASSIGN_OR_RAISE(
-          options.value_set,
+      auto cast_result =
           Cast(options.value_set, CastOptions::Safe(arg_type.GetSharedPtr()),
-               ctx->exec_context()));
+               ctx->exec_context());
+      if (cast_result.ok()) {
+        options.value_set = *cast_result;
+      } else if (CanCast(*arg_type.type, *options.value_set.type())) {
+        // Avoid casting from non binary types to string like above
+        // Otherwise, will try to cast input array to value set type during kernel exec
+        if ((options.value_set.type()->id() == Type::STRING ||
+             options.value_set.type()->id() == Type::LARGE_STRING) &&
+            !is_base_binary_like(arg_type.id())) {
+          return Status::Invalid("Array type didn't match type of values set: ",
+                                 *arg_type, " vs ", *options.value_set.type());
+        }
+      } else {
+        return Status::Invalid("Array type doesn't match type of values set: ", *arg_type,
+                               " vs ", *options.value_set.type());
+      }
     }
 
-    RETURN_NOT_OK(VisitTypeInline(*arg_type, this));
+    RETURN_NOT_OK(VisitTypeInline(*options.value_set.type(), this));
     return std::move(result);
   }
 };
@@ -263,15 +284,12 @@ struct IndexInVisitor {
   }
 
   template <typename Type>
-  Status ProcessIndexIn() {
+  Status ProcessIndexIn(const SetLookupState<Type>& state, const ArraySpan& input) {
     using T = typename GetViewType<Type>::T;
-
-    const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
-
     FirstTimeBitmapWriter bitmap_writer(out_bitmap, out->offset, out->length);
     int32_t* out_data = out->GetValues<int32_t>(1);
     VisitArraySpanInline<Type>(
-        data,
+        input,
         [&](T v) {
           int32_t index = state.lookup_table->Get(v);
           if (index != -1) {
@@ -303,6 +321,19 @@ struct IndexInVisitor {
     return Status::OK();
   }
 
+  template <typename Type>
+  Status ProcessIndexIn() {  // Type is the value set type chosen by Execute()
+    const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+    if (!data.type->Equals(state.value_set_type)) {  // value set kept its type at Init
+      auto materialized_input = data.ToArrayData();  // materialize span for Cast()
+      ARROW_ASSIGN_OR_RAISE(auto casted_input,
+                            Cast(*materialized_input, state.value_set_type,
+                                 CastOptions::Safe(), ctx->exec_context()));
+      return ProcessIndexIn(state, *casted_input.array());
+    }
+    return ProcessIndexIn(state, data);  // input already matches the value set type
+  }
+
   template <typename Type>
   enable_if_boolean<Type, Status> Visit(const Type&) {
     return ProcessIndexIn<BooleanType>();
@@ -331,7 +362,10 @@ struct IndexInVisitor {
     return ProcessIndexIn<MonthDayNanoIntervalType>();
   }
 
-  Status Execute() { return VisitTypeInline(*data.type, this); }
+  Status Execute() {  // dispatch on the value set type, not the input type
+    const auto& state = checked_cast<const SetLookupStateBase&>(*ctx->state());
+    return VisitTypeInline(*state.value_set_type, this);
+  }
 };
 
 Status ExecIndexIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
@@ -359,13 +393,11 @@ struct IsInVisitor {
   }
 
   template <typename Type>
-  Status ProcessIsIn() {
+  Status ProcessIsIn(const SetLookupState<Type>& state, const ArraySpan& input) {
     using T = typename GetViewType<Type>::T;
-    const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
-
     FirstTimeBitmapWriter writer(out->buffers[1].data, out->offset, out->length);
     VisitArraySpanInline<Type>(
-        this->data,
+        input,
         [&](T v) {
           if (state.lookup_table->Get(v) != -1) {
             writer.Set();
@@ -386,6 +418,20 @@ struct IsInVisitor {
     return Status::OK();
   }
 
+  template <typename Type>
+  Status ProcessIsIn() {  // Type is the value set type chosen by Execute()
+    const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
+    // If the value set kept its own type at Init, cast the input to it first.
+    if (!data.type->Equals(state.value_set_type)) {
+      auto materialized_data = data.ToArrayData();  // materialize span for Cast()
+      ARROW_ASSIGN_OR_RAISE(auto casted_data,
+                            Cast(*materialized_data, state.value_set_type,
+                                 CastOptions::Safe(), ctx->exec_context()));
+      return ProcessIsIn(state, *casted_data.array());
+    }
+    return ProcessIsIn(state, data);
+  }
+
   template <typename Type>
   enable_if_boolean<Type, Status> Visit(const Type&) {
     return ProcessIsIn<BooleanType>();
@@ -413,7 +459,10 @@ struct IsInVisitor {
     return ProcessIsIn<MonthDayNanoIntervalType>();
   }
 
-  Status Execute() { return VisitTypeInline(*data.type, this); }
+  Status Execute() {  // dispatch on the value set type, not the input type
+    const auto& state = checked_cast<const SetLookupStateBase&>(*ctx->state());
+    return VisitTypeInline(*state.value_set_type, this);
+  }
 };
 
 Status ExecIsIn(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
@@ -566,6 +615,5 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) {
   }
 }
 
-}  // namespace internal
-}  // namespace compute
+}  // namespace compute::internal
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
index ba724d3549..762a4bfe5c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -126,25 +126,40 @@ TEST_F(TestIsInKernel, ImplicitlyCastValueSet) {
                                             "true, false, true, false]"));
   AssertArraysEqual(*expected, *out.make_array());
 
-  // fails; value_set cannot be cast to int8
-  opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")};
-  ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts));
+  // value_set cannot be casted to int8, but int8 is castable to float
+  CheckIsIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
+            "[false, true, false, false, false, true, false, false, false]");
 
   // Allow implicit casts between binary types...
-  CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+  CheckIsIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
             ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
-            "[true, true, false, false, true]");
+            "[true, false, false, false, true]");
+  CheckIsIn(ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+            ArrayFromJSON(binary(), R"(["aa", "bbb"])"),
+            "[false, true, false, false, true]");
   CheckIsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
             ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"),
             "[true, true, false, false, true]");
+  CheckIsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+            ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"),
+            "[true, true, false, false, true]");
+
   // But explicitly deny implicit casts from non-binary to utf8 to
   // avoid surprises
   ASSERT_RAISES(Invalid,
                 IsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
                      SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+  ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                              SetLookupOptions(ArrayFromJSON(
+                                  utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+
   ASSERT_RAISES(Invalid,
                 IsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
                      SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+  ASSERT_RAISES(Invalid,
+                IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                     SetLookupOptions(ArrayFromJSON(
+                         large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
 }
 
 template <typename Type>
@@ -253,11 +268,12 @@ TEST_F(TestIsInKernel, TimeDuration) {
               "[true, false, false, true, true]", /*skip_nulls=*/true);
   }
 
-  // Different units, invalid cast
-  ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
-                              ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
+  // Different units, cast value_set to values will fail, then cast values to value_set
+  CheckIsIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
+            ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
+            "[false, false, true]");
 
-  // Different units, valid cast
+  // Different units, cast value_set to values
   CheckIsIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
             ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[true, false, true]");
 }
@@ -779,11 +795,12 @@ TEST_F(TestIndexInKernel, TimeDuration) {
   CheckIndexIn(duration(TimeUnit::SECOND), "[null, null, null, null]", "[null]",
                "[0, 0, 0, 0]");
 
-  // Different units, invalid cast
-  ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
-                                 ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 2]")));
+  // Different units, cast value_set to values will fail, then cast values to value_set
+  CheckIndexIn(ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 2]"),
+               ArrayFromJSON(duration(TimeUnit::MILLI), "[1, 2, 2000]"),
+               "[null, null, 2]");
 
-  // Different units, valid cast
+  // Different units, cast value_set to values
   CheckIndexIn(ArrayFromJSON(duration(TimeUnit::MILLI), "[0, 1, 2000]"),
                ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 2]"), "[0, null, 1]");
 }
@@ -822,6 +839,50 @@ TEST_F(TestIndexInKernel, Boolean) {
   CheckIndexIn(boolean(), "[null, null, null, null]", "[null]", "[0, 0, 0, 0]");
 }
 
+TEST_F(TestIndexInKernel, ImplicitlyCastValueSet) {  // casts in both directions
+  auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]");
+
+  SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")};
+  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("index_in", {input}, &opts));
+
+  auto expected = ArrayFromJSON(int32(), ("[null, null, 0, 1, null,"
+                                          "2, null, 3, null]"));
+  AssertArraysEqual(*expected, *out.make_array());
+
+  // value_set cannot be cast to int8, but the int8 input can be cast to float32
+  CheckIndexIn(input, ArrayFromJSON(float32(), "[1.0, 2.5, 3.1, 5.0]"),
+               "[null, 0, null, null, null, 3, null, null, null]");
+
+  // Allow implicit casts between binary types...
+  CheckIndexIn(ArrayFromJSON(binary(), R"(["aaa", "bb", "ccc", null, "bbb"])"),
+               ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb"])"),
+               "[0, null, null, null, 1]");
+  CheckIndexIn(
+      ArrayFromJSON(fixed_size_binary(3), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+      ArrayFromJSON(binary(), R"(["aa", "bbb"])"), "[null, 1, null, null, 1]");
+  CheckIndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+               ArrayFromJSON(large_utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
+  CheckIndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+               ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
+  // But explicitly deny implicit casts from non-binary to utf8 to
+  // avoid surprises
+  ASSERT_RAISES(Invalid,
+                IndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+                        SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+  ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                                 SetLookupOptions(ArrayFromJSON(
+                                     utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+
+  ASSERT_RAISES(
+      Invalid,
+      IndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
+              SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
+  ASSERT_RAISES(Invalid,
+                IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                        SetLookupOptions(ArrayFromJSON(
+                            large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+}
+
 template <typename Type>
 class TestIndexInKernelBinary : public TestIndexInKernel {};