You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/28 13:58:55 UTC

[arrow] branch main updated: GH-36309: [C++] Add ability to cast between scalars of list-like types (#36310)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e35911629 GH-36309: [C++] Add ability to cast between scalars of list-like types (#36310)
3e35911629 is described below

commit 3e35911629c2e5dfb58164e1f98cbbc232be1cd2
Author: Felipe Oliveira Carvalho <fe...@gmail.com>
AuthorDate: Wed Jun 28 10:58:47 2023 -0300

    GH-36309: [C++] Add ability to cast between scalars of list-like types (#36310)
    
    ### Rationale for this change
    
    Makes it easier to work generically with different list types.
    
    ### What changes are included in this PR?
    
    Additional cast implementations between scalars of list-like types.
    
    ### Are these changes tested?
    
    Yes, with new unit tests.
    * Closes: #36309
    
    Authored-by: Felipe Oliveira Carvalho <fe...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/scalar.cc      | 43 +++++++++++++++++++++++++-
 cpp/src/arrow/scalar_test.cc | 73 +++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 107 insertions(+), 9 deletions(-)

diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index 0537ddafe2..b2ad1ad519 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -1127,6 +1127,32 @@ Status CastImpl(const StructScalar& from, StringScalar* to) {
   return Status::OK();
 }
 
+// casts between variable-length and fixed-length list types; only the length is validated, then from.value is reused as-is
+template <typename ToScalar>
+enable_if_list_type<typename ToScalar::TypeClass, Status> CastImpl(
+    const BaseListScalar& from, ToScalar* to) {
+  if constexpr (sizeof(typename ToScalar::TypeClass::offset_type) < sizeof(int64_t)) {  // target offsets narrower than int64_t
+    if (from.value->length() >
+        std::numeric_limits<typename ToScalar::TypeClass::offset_type>::max()) {
+      return Status::Invalid(from.type->ToString(), " too large to cast to ",
+                             to->type->ToString());
+    }
+  }
+
+  if constexpr (is_fixed_size_list_type<typename ToScalar::TypeClass>::value) {  // fixed-size target: length must match exactly
+    const auto& fixed_size_list_type = checked_cast<const FixedSizeListType&>(*to->type);
+    if (from.value->length() != fixed_size_list_type.list_size()) {
+      return Status::Invalid("Cannot cast ", from.type->ToString(), " of length ",
+                             from.value->length(), " to fixed size list of length ",
+                             fixed_size_list_type.list_size());
+    }
+  }
+
+  DCHECK_EQ(from.is_valid, to->is_valid);  // validity is only checked here, never assigned
+  to->value = from.value;  // NOTE(review): no element-type check in this function — confirm it is enforced elsewhere
+  return Status::OK();
+}
+
 // list based types (list, large list and map (fixed sized list too)) to string
 Status CastImpl(const BaseListScalar& from, StringScalar* to) {
   std::stringstream ss;
@@ -1183,12 +1209,27 @@ struct FromTypeVisitor : CastImplVisitor {
 
   // identity cast only for parameter free types
   template <typename T1 = ToType>
-  typename std::enable_if<TypeTraits<T1>::is_parameter_free, Status>::type Visit(
+  typename std::enable_if_t<TypeTraits<T1>::is_parameter_free, Status> Visit(
       const ToType&) {
     checked_cast<ToScalar*>(out_)->value = checked_cast<const ToScalar&>(from_).value;
     return Status::OK();
   }
 
+  Status CastFromListLike(const BaseListType& base_list_type) {  // base_list_type is unused; the cast reads from_ and writes out_
+    return CastImpl(checked_cast<const BaseListScalar&>(from_),
+                    checked_cast<ToScalar*>(out_));
+  }
+
+  Status Visit(const ListType& list_type) { return CastFromListLike(list_type); }  // list source
+
+  Status Visit(const LargeListType& large_list_type) {  // large list source
+    return CastFromListLike(large_list_type);
+  }
+
+  Status Visit(const FixedSizeListType& fixed_size_list_type) {  // fixed-size list source
+    return CastFromListLike(fixed_size_list_type);
+  }
+
   Status Visit(const NullType&) { return NotImplemented(); }
   Status Visit(const DictionaryType&) { return NotImplemented(); }
   Status Visit(const ExtensionType&) { return NotImplemented(); }
diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc
index 1c7572a1c8..b66857717f 100644
--- a/cpp/src/arrow/scalar_test.cc
+++ b/cpp/src/arrow/scalar_test.cc
@@ -334,13 +334,12 @@ class TestRealScalar : public ::testing::Test {
     ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
   }
 
-  void TestListOf() {
-    auto ty = list(type_);
-
-    ListScalar list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), ty);
-    ListScalar list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), ty);
-    ListScalar list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
-    ListScalar list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
+  template <typename ListScalarClass>
+  void TestListOf(const std::shared_ptr<DataType>& list_ty) {
+    ListScalarClass list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), list_ty);
+    ListScalarClass list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), list_ty);
+    ListScalarClass list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), list_ty);
+    ListScalarClass list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), list_ty);
 
     EqualOptions options = EqualOptions::Defaults().atol(0.05);
     ASSERT_TRUE(list_val.Equals(list_val, options));
@@ -391,6 +390,10 @@ class TestRealScalar : public ::testing::Test {
     ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
   }
 
+  void TestListOf() { TestListOf<ListScalar>(list(type_)); }  // templated checks with ListScalar over list(type_)
+
+  void TestLargeListOf() { TestListOf<LargeListScalar>(large_list(type_)); }  // same checks with LargeListScalar over large_list(type_)
+
  protected:
   std::shared_ptr<DataType> type_;
   std::shared_ptr<Scalar> scalar_val_, scalar_other_, scalar_nan_, scalar_other_nan_,
@@ -409,6 +412,8 @@ TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); }
 
 TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }
 
+TYPED_TEST(TestRealScalar, LargeListOf) { this->TestLargeListOf(); }
+
 template <typename T>
 class TestDecimalScalar : public ::testing::Test {
  public:
@@ -1058,6 +1063,27 @@ std::shared_ptr<DataType> MakeListType<FixedSizeListType>(
   return fixed_size_list(std::move(value_type), list_size);
 }
 
+template <typename ScalarType>  // helper: cast scalar to to_type, then check the result's type, validity and values
+void CheckListCast(const ScalarType& scalar, const std::shared_ptr<DataType>& to_type) {
+  EXPECT_OK_AND_ASSIGN(auto cast_scalar, scalar.CastTo(to_type));
+  ASSERT_OK(cast_scalar->ValidateFull());
+  ASSERT_EQ(*cast_scalar->type, *to_type);
+
+  ASSERT_EQ(scalar.is_valid, cast_scalar->is_valid);
+  ASSERT_TRUE(scalar.is_valid);  // this helper only accepts valid input scalars
+  ASSERT_ARRAYS_EQUAL(*scalar.value,  // the cast result must hold the same values as the input
+                      *checked_cast<const BaseListScalar&>(*cast_scalar).value);
+}
+
+template <typename ScalarType>  // helper: CastTo(to_type) must fail with StatusCode::Invalid and a message containing expected_message
+void CheckInvalidListCast(const ScalarType& scalar,
+                          const std::shared_ptr<DataType>& to_type,
+                          std::string_view expected_message) {
+  EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(StatusCode::Invalid,
+                                           ::testing::HasSubstr(expected_message),
+                                           scalar.CastTo(to_type));
+}
+
 template <typename T>
 class TestListScalar : public ::testing::Test {
  public:
@@ -1135,6 +1161,19 @@ class TestListScalar : public ::testing::Test {
     ASSERT_NE(a0->hash(), b0->hash());
   }
 
+  void TestCast() {  // casts to list, large list and a same-length fixed-size list succeed; a length-5 fixed-size list is rejected
+    ScalarType scalar(value_);
+    CheckListCast(scalar, list(value_->type()));
+    CheckListCast(scalar, large_list(value_->type()));
+    CheckListCast(
+        scalar, fixed_size_list(value_->type(), static_cast<int32_t>(value_->length())));
+
+    CheckInvalidListCast(scalar, fixed_size_list(value_->type(), 5),  // assumes value_->length() != 5 — TODO confirm against the fixture
+                         "Cannot cast " + scalar.type->ToString() + " of length " +
+                             std::to_string(value_->length()) +
+                             " to fixed size list of length 5");
+  }
+
  protected:
   std::shared_ptr<DataType> type_;
   std::shared_ptr<Array> value_;
@@ -1148,7 +1187,9 @@ TYPED_TEST(TestListScalar, Basics) { this->TestBasics(); }
 
 TYPED_TEST(TestListScalar, ValidateErrors) { this->TestValidateErrors(); }
 
-TYPED_TEST(TestListScalar, TestHashing) { this->TestHashing(); }
+TYPED_TEST(TestListScalar, Hashing) { this->TestHashing(); }
+
+TYPED_TEST(TestListScalar, Cast) { this->TestCast(); }
 
 TEST(TestFixedSizeListScalar, ValidateErrors) {
   const auto ty = fixed_size_list(int16(), 3);
@@ -1176,6 +1217,22 @@ TEST(TestMapScalar, NullScalar) {
   CheckMakeNullScalar(map(utf8(), field("value", int8())));
 }
 
+TEST(TestMapScalar, Cast) {  // a two-entry MapScalar casts to list, large list and fixed_size_list(2), but not fixed_size_list(5)
+  auto key_value_type = struct_({field("key", utf8(), false), field("value", int8())});  // used for the map's value array and as the target lists' element type
+  auto value = ArrayFromJSON(key_value_type,
+                             R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])");
+  auto scalar = MapScalar(value);
+
+  CheckListCast(scalar, list(key_value_type));
+  CheckListCast(scalar, large_list(key_value_type));
+  CheckListCast(scalar, fixed_size_list(key_value_type, 2));
+
+  CheckInvalidListCast(scalar, fixed_size_list(key_value_type, 5),  // the two entries do not match a fixed size of 5
+                       "Cannot cast " + scalar.type->ToString() + " of length " +
+                           std::to_string(value->length()) +
+                           " to fixed size list of length 5");
+}
+
 TEST(TestStructScalar, FieldAccess) {
   StructScalar abc({MakeScalar(true), MakeNullScalar(int32()), MakeScalar("hello"),
                     MakeNullScalar(int64())},