You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/29 15:00:59 UTC

[arrow] branch main updated: GH-36345: [C++] Prefer TypeError over Invalid in IsIn and IndexIn kernels (#36358)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2455bc07e0 GH-36345: [C++] Prefer TypeError over Invalid in IsIn and IndexIn kernels (#36358)
2455bc07e0 is described below

commit 2455bc07e09cd5341d1fabdb293afbd07682f0b2
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Thu Jun 29 23:00:52 2023 +0800

    GH-36345: [C++] Prefer TypeError over Invalid in IsIn and IndexIn kernels (#36358)
    
    ### Rationale for this change
    
    `TypeError` should be returned when `values` and `value_set` have incompatible types in IsIn and IndexIn.
    `Invalid` is still returned if the types are compatible but casting the values fails (for example because of overflow or truncation).
    
    ### What changes are included in this PR?
    
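    When the types of `values` and `value_set` are incompatible, return TypeError instead of Invalid. This covers mixing timezone-aware and timezone-naive timestamps, implicitly casting a non-binary type to a string type, and pairs of types for which no cast is implemented.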
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    
    * Closes: #36345
    
    Authored-by: Jin Shang <sh...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/kernels/scalar_set_lookup.cc | 44 ++++++++++++++--------
 .../compute/kernels/scalar_set_lookup_test.cc      | 42 +++++++++++----------
 2 files changed, 51 insertions(+), 35 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
index 803dfbde9c..00d391653d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
@@ -209,7 +209,7 @@ struct InitStateVisitor {
       const auto& ty1 = checked_cast<const TimestampType&>(*arg_type);
       const auto& ty2 = checked_cast<const TimestampType&>(*options.value_set.type());
       if (ty1.timezone().empty() ^ ty2.timezone().empty()) {
-        return Status::Invalid(
+        return Status::TypeError(
             "Cannot compare timestamp with timezone to timestamp without timezone, got: ",
             ty1, " and ", ty2);
       }
@@ -218,8 +218,8 @@ struct InitStateVisitor {
       // This is a bit of a hack, but don't implicitly cast from a non-binary
       // type to string, since most types support casting to string and that
       // may lead to surprises. However, we do want most other implicit casts.
-      return Status::Invalid("Array type didn't match type of values set: ", *arg_type,
-                             " vs ", *options.value_set.type());
+      return Status::TypeError("Array type doesn't match type of values set: ", *arg_type,
+                               " vs ", *options.value_set.type());
     }
 
     if (!options.value_set.is_arraylike()) {
@@ -236,12 +236,12 @@ struct InitStateVisitor {
         if ((options.value_set.type()->id() == Type::STRING ||
              options.value_set.type()->id() == Type::LARGE_STRING) &&
             !is_base_binary_like(arg_type.id())) {
-          return Status::Invalid("Array type didn't match type of values set: ",
-                                 *arg_type, " vs ", *options.value_set.type());
+          return Status::TypeError("Array type doesn't match type of values set: ",
+                                   *arg_type, " vs ", *options.value_set.type());
         }
       } else {
-        return Status::Invalid("Array type doesn't match type of values set: ", *arg_type,
-                               " vs ", *options.value_set.type());
+        return Status::TypeError("Array type doesn't match type of values set: ",
+                                 *arg_type, " vs ", *options.value_set.type());
       }
     }
 
@@ -326,9 +326,16 @@ struct IndexInVisitor {
     const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
     if (!data.type->Equals(state.value_set_type)) {
       auto materialized_input = data.ToArrayData();
-      ARROW_ASSIGN_OR_RAISE(auto casted_input,
-                            Cast(*materialized_input, state.value_set_type,
-                                 CastOptions::Safe(), ctx->exec_context()));
+      auto cast_result = Cast(*materialized_input, state.value_set_type,
+                              CastOptions::Safe(), ctx->exec_context());
+      if (ARROW_PREDICT_FALSE(!cast_result.ok())) {
+        if (cast_result.status().IsNotImplemented()) {
+          return Status::TypeError("Array type doesn't match type of values set: ",
+                                   *data.type, " vs ", *state.value_set_type);
+        }
+        return cast_result.status();
+      }
+      auto casted_input = *cast_result;
       return ProcessIndexIn(state, *casted_input.array());
     }
     return ProcessIndexIn(state, data);
@@ -423,11 +430,18 @@ struct IsInVisitor {
     const auto& state = checked_cast<const SetLookupState<Type>&>(*ctx->state());
 
     if (!data.type->Equals(state.value_set_type)) {
-      auto materialized_data = data.ToArrayData();
-      ARROW_ASSIGN_OR_RAISE(auto casted_data,
-                            Cast(*materialized_data, state.value_set_type,
-                                 CastOptions::Safe(), ctx->exec_context()));
-      return ProcessIsIn(state, *casted_data.array());
+      auto materialized_input = data.ToArrayData();
+      auto cast_result = Cast(*materialized_input, state.value_set_type,
+                              CastOptions::Safe(), ctx->exec_context());
+      if (ARROW_PREDICT_FALSE(!cast_result.ok())) {
+        if (cast_result.status().IsNotImplemented()) {
+          return Status::TypeError("Array type doesn't match type of values set: ",
+                                   *data.type, " vs ", *state.value_set_type);
+        }
+        return cast_result.status();
+      }
+      auto casted_input = *cast_result;
+      return ProcessIsIn(state, *casted_input.array());
     }
     return ProcessIsIn(state, data);
   }
diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
index 762a4bfe5c..d1645eb8d9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc
@@ -146,17 +146,17 @@ TEST_F(TestIsInKernel, ImplicitlyCastValueSet) {
 
   // But explicitly deny implicit casts from non-binary to utf8 to
   // avoid surprises
-  ASSERT_RAISES(Invalid,
+  ASSERT_RAISES(TypeError,
                 IsIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
                      SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
-  ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
-                              SetLookupOptions(ArrayFromJSON(
-                                  utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+  ASSERT_RAISES(TypeError, IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                                SetLookupOptions(ArrayFromJSON(
+                                    utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
 
-  ASSERT_RAISES(Invalid,
+  ASSERT_RAISES(TypeError,
                 IsIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
                      SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
-  ASSERT_RAISES(Invalid,
+  ASSERT_RAISES(TypeError,
                 IsIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
                      SetLookupOptions(ArrayFromJSON(
                          large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
@@ -241,11 +241,11 @@ TEST_F(TestIsInKernel, TimeTimestamp) {
   }
 
   // Disallow mixing timezone-aware and timezone-naive values
-  ASSERT_RAISES(Invalid, IsIn(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 1, 2]"),
-                              SetLookupOptions(ArrayFromJSON(
-                                  timestamp(TimeUnit::SECOND, "UTC"), "[0, 2]"))));
+  ASSERT_RAISES(TypeError, IsIn(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 1, 2]"),
+                                SetLookupOptions(ArrayFromJSON(
+                                    timestamp(TimeUnit::SECOND, "UTC"), "[0, 2]"))));
   ASSERT_RAISES(
-      Invalid,
+      TypeError,
       IsIn(ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC"), "[0, 1, 2]"),
            SetLookupOptions(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 2]"))));
   // However, mixed timezones are allowed (underlying value is UTC)
@@ -741,11 +741,12 @@ TEST_F(TestIndexInKernel, TimeTimestamp) {
                "[0, 0, 0, 0]");
 
   // Disallow mixing timezone-aware and timezone-naive values
-  ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 1, 2]"),
-                                 SetLookupOptions(ArrayFromJSON(
-                                     timestamp(TimeUnit::SECOND, "UTC"), "[0, 2]"))));
+  ASSERT_RAISES(TypeError,
+                IndexIn(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 1, 2]"),
+                        SetLookupOptions(ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC"),
+                                                       "[0, 2]"))));
   ASSERT_RAISES(
-      Invalid,
+      TypeError,
       IndexIn(ArrayFromJSON(timestamp(TimeUnit::SECOND, "UTC"), "[0, 1, 2]"),
               SetLookupOptions(ArrayFromJSON(timestamp(TimeUnit::SECOND), "[0, 2]"))));
   // However, mixed timezones are allowed (underlying value is UTC)
@@ -866,18 +867,19 @@ TEST_F(TestIndexInKernel, ImplicitlyCastValueSet) {
                ArrayFromJSON(utf8(), R"(["aaa", "bbb"])"), "[0, 1, null, null, 1]");
   // But explicitly deny implicit casts from non-binary to utf8 to
   // avoid surprises
-  ASSERT_RAISES(Invalid,
+  ASSERT_RAISES(TypeError,
                 IndexIn(ArrayFromJSON(utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
                         SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
-  ASSERT_RAISES(Invalid, IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
-                                 SetLookupOptions(ArrayFromJSON(
-                                     utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
+  ASSERT_RAISES(TypeError,
+                IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
+                        SetLookupOptions(ArrayFromJSON(
+                            utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));
 
   ASSERT_RAISES(
-      Invalid,
+      TypeError,
       IndexIn(ArrayFromJSON(large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"),
               SetLookupOptions(ArrayFromJSON(float64(), "[1.0, 2.0]"))));
-  ASSERT_RAISES(Invalid,
+  ASSERT_RAISES(TypeError,
                 IndexIn(ArrayFromJSON(float64(), "[1.0, 2.0]"),
                         SetLookupOptions(ArrayFromJSON(
                             large_utf8(), R"(["aaa", "bbb", "ccc", null, "bbb"])"))));