You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/06/30 19:31:05 UTC

[GitHub] [arrow] bkietz commented on a change in pull request #10557: ARROW-13064: [C++] Implement select ('case when') function for fixed-width types

bkietz commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r661729404



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
   CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()});
 }
 
+void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs,

Review comment:
       Also worth noting: this seems very similar in function to `CheckWithDifferentShapes`. I think it'd be worthwhile to unify the two and promote the result to `test_util.h` so it can be reused by other varargs scalar kernels.

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(out_values, offset, length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length,
+                                out_values, offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+                         const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(values);
+    std::fill(out_values + offset, out_values + offset + length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+                        const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType* in_values = array.GetValues<CType>(1);
+    std::copy(in_values + offset, in_values + offset + length, out_values + offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+    // Scalar may have null value buffer
+    if (!scalar.value) return;
+    DCHECK_EQ(scalar.value->size(), width);
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, scalar.value->data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const ScalarType&>(values);
+    const auto value = scalar.value.ToBytes();
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, value.data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values,
+                const int64_t offset, const int64_t length) {
+  using Copier = CopyFixedWidth<Type>;
+  if (values.is_scalar()) {
+    const auto& scalar = *values.scalar();
+    if (out_valid) {
+      BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid);
+    }
+    Copier::CopyScalar(scalar, out_values, offset, length);
+  } else {
+    const ArrayData& array = *values.array();
+    if (out_valid) {
+      if (array.MayHaveNulls()) {
+        arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset,
+                                    length, out_valid, offset);
+      } else {
+        BitUtil::SetBitsTo(out_valid, offset, length, true);
+      }
+    }
+    Copier::CopyArray(array, out_values, offset, length);
+  }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+  using ScalarFunction::ScalarFunction;
+
+  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+    RETURN_NOT_OK(CheckArity(*values));
+    std::vector<ValueDescr> value_types;
+    for (size_t i = 0; i < values->size() - 1; i += 2) {
+      ValueDescr* cond = &(*values)[i];
+      if (cond->type->id() == Type::NA) {
+        cond->type = boolean();
+      }
+      if (cond->type->id() != Type::BOOL) {
+        return Status::Invalid("Condition arguments must be boolean, but argument ", i,
+                               " was ", cond->type->ToString());
+      }
+      value_types.push_back((*values)[i + 1]);
+    }
+    if (values->size() % 2 != 0) {
+      // Have an ELSE clause
+      value_types.push_back(values->back());
+    }
+    EnsureDictionaryDecoded(&value_types);
+    if (auto type = CommonNumeric(value_types)) {
+      ReplaceTypes(type, &value_types);
+    }
+
+    const DataType& common_values_type = *value_types.front().type;
+    auto next_type = value_types.cbegin();
+    for (size_t i = 0; i < values->size(); i += 2) {
+      if (!common_values_type.Equals(next_type->type)) {

Review comment:
       It seems to me we could restructure `case_when` to yield a less daunting interface:
   
   ```python
   case(
     when(cond_0, value_0),
     when(cond_1, value_1),
     value_else
   )
   ```
   
   Here, `when` masks slots with null wherever its condition is not true, and `case` is simply a variadic coalescing function (it takes the first non-null value).
   
   We'd be allocating a new null bitmap on each call to `when`, which is not ideal. However, the typing is far clearer (`when(cond: Boolean, value: T): T`, `case(...values: T): T`), and `case` and `when` can be unit tested independently.
   
   @pitrou, what do you think?

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(out_values, offset, length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length,
+                                out_values, offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+                         const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(values);
+    std::fill(out_values + offset, out_values + offset + length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+                        const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType* in_values = array.GetValues<CType>(1);
+    std::copy(in_values + offset, in_values + offset + length, out_values + offset);

Review comment:
       `std::copy` allows the source and destination ranges to overlap (unlike `memcpy`; the only requirement is that the destination's start not lie inside the source range), so for simple pointers like these it gets [inlined to memmove](https://godbolt.org/z/K1ov93rYb). `memmove` can be slower (or faster) than `memcpy`, depending on your libc. The difference shouldn't be very large, though, so I don't think it's necessary to avoid `std::copy`.

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
   CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()});
 }
 
+void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs,
+                  Datum expected) {
+  ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs));
+  if (datum_out.is_array()) {
+    std::shared_ptr<Array> result = datum_out.make_array();
+    ASSERT_OK(result->ValidateFull());
+    std::shared_ptr<Array> expected_ = expected.make_array();
+    AssertArraysEqual(*expected_, *result, /*verbose=*/true);
+
+    for (int64_t i = 0; i < result->length(); i++) {
+      // Check scalar
+      ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i));
+      std::vector<Datum> inputs_scalar;
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          auto array = input.make_array();
+          ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i));
+          inputs_scalar.push_back(input_scalar);
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(scalar_out.is_scalar());
+      AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true);
+
+      // Check slice
+      inputs_scalar.clear();
+      auto expected_array = expected_->Slice(i);
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          inputs_scalar.push_back(input.make_array()->Slice(i));
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(array_out.is_array());
+      AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true);
+    }
+  } else {
+    const std::shared_ptr<Scalar>& result = datum_out.scalar();
+    const std::shared_ptr<Scalar>& expected_ = expected.scalar();
+    AssertScalarsEqual(*expected_, *result, /*verbose=*/true);
+  }
+}
+
+template <typename Type>
+class TestCaseWhenNumeric : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes);
+
+void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1,
+                        const std::string& value2) {
+  auto scalar_true = ScalarFromJSON(boolean(), "true");
+  auto scalar_false = ScalarFromJSON(boolean(), "false");
+  auto scalar_null = ScalarFromJSON(boolean(), "null");
+  auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]");
+  auto value_null = ScalarFromJSON(type, "null");
+  auto scalar1 = ScalarFromJSON(type, value1);
+  auto scalar2 = ScalarFromJSON(type, value2);
+  auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+  std::stringstream builder;
+  builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']';
+  auto values1 = ArrayFromJSON(type, builder.str());
+  builder.str("");
+  builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']';
+  auto values2 = ArrayFromJSON(type, builder.str());
+  // N.B. all-scalar cases are checked in CheckCaseWhen
+  // Only an else array
+  CheckVarArgs("case_when", {values1}, values1);
+  // No else clause, scalar cond, array values
+  CheckVarArgs("case_when", {scalar_true, values1}, values1);
+  CheckVarArgs("case_when", {scalar_false, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_null, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1);
+  CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1);
+  // No else clause, array cond, scalar values
+  builder.str("");
+  builder << '[' << value1 << ", null, null, null]";
+  CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, builder.str()));
+  CheckVarArgs("case_when", {cond1, value_null}, values_null);
+  builder.str("");
+  builder << '[' << value1 << ", null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2},
+               ArrayFromJSON(type, builder.str()));
+  // No else clause, array cond, array values
+  builder.str("");
+  builder << "[null, null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2},
+               ArrayFromJSON(type, builder.str()));
+  // Else clauses/mixed scalar and array
+  builder.str("");
+  builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1},
+               ArrayFromJSON(type, builder.str()));
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1},
+               ArrayFromJSON(type, builder.str()));
+}
+
+TYPED_TEST(TestCaseWhenNumeric, FixedSize) {
+  auto type = default_type_instance<TypeParam>();
+  CheckCaseWhenCases(type, "10", "42");

Review comment:
       I agree; there's too much indirection here to see what's actually being tested. Please inline some of this.

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -316,5 +318,165 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
   CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()});
 }
 
+void CheckVarArgs(const std::string& name, const std::vector<Datum>& inputs,
+                  Datum expected) {
+  ASSERT_OK_AND_ASSIGN(Datum datum_out, CallFunction(name, inputs));
+  if (datum_out.is_array()) {
+    std::shared_ptr<Array> result = datum_out.make_array();
+    ASSERT_OK(result->ValidateFull());
+    std::shared_ptr<Array> expected_ = expected.make_array();
+    AssertArraysEqual(*expected_, *result, /*verbose=*/true);
+
+    for (int64_t i = 0; i < result->length(); i++) {
+      // Check scalar
+      ASSERT_OK_AND_ASSIGN(auto expected_scalar, expected_->GetScalar(i));
+      std::vector<Datum> inputs_scalar;
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          auto array = input.make_array();
+          ASSERT_OK_AND_ASSIGN(auto input_scalar, array->GetScalar(i));
+          inputs_scalar.push_back(input_scalar);
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto scalar_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(scalar_out.is_scalar());
+      AssertScalarsEqual(*expected_scalar, *scalar_out.scalar(), /*verbose=*/true);
+
+      // Check slice
+      inputs_scalar.clear();
+      auto expected_array = expected_->Slice(i);
+      for (const auto& input : inputs) {
+        if (input.is_scalar()) {
+          inputs_scalar.push_back(input);
+        } else {
+          inputs_scalar.push_back(input.make_array()->Slice(i));
+        }
+      }
+      ASSERT_OK_AND_ASSIGN(auto array_out, CallFunction(name, inputs_scalar));
+      ASSERT_TRUE(array_out.is_array());
+      AssertArraysEqual(*expected_array, *array_out.make_array(), /*verbose=*/true);
+    }
+  } else {
+    const std::shared_ptr<Scalar>& result = datum_out.scalar();
+    const std::shared_ptr<Scalar>& expected_ = expected.scalar();
+    AssertScalarsEqual(*expected_, *result, /*verbose=*/true);
+  }
+}
+
+template <typename Type>
+class TestCaseWhenNumeric : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes);
+
+void CheckCaseWhenCases(const std::shared_ptr<DataType>& type, const std::string& value1,
+                        const std::string& value2) {
+  auto scalar_true = ScalarFromJSON(boolean(), "true");
+  auto scalar_false = ScalarFromJSON(boolean(), "false");
+  auto scalar_null = ScalarFromJSON(boolean(), "null");
+  auto cond1 = ArrayFromJSON(boolean(), "[true, false, false, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]");
+  auto value_null = ScalarFromJSON(type, "null");
+  auto scalar1 = ScalarFromJSON(type, value1);
+  auto scalar2 = ScalarFromJSON(type, value2);
+  auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+  std::stringstream builder;
+  builder << "[null, " << value1 << ',' << value1 << ',' << value1 << ']';
+  auto values1 = ArrayFromJSON(type, builder.str());
+  builder.str("");
+  builder << '[' << value2 << ',' << value2 << ',' << value2 << ',' << value2 << ']';
+  auto values2 = ArrayFromJSON(type, builder.str());
+  // N.B. all-scalar cases are checked in CheckCaseWhen
+  // Only an else array
+  CheckVarArgs("case_when", {values1}, values1);
+  // No else clause, scalar cond, array values
+  CheckVarArgs("case_when", {scalar_true, values1}, values1);
+  CheckVarArgs("case_when", {scalar_false, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_null, values1}, values_null);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_null, values1}, values1);
+  CheckVarArgs("case_when", {scalar_null, values2, scalar_true, values1}, values1);
+  CheckVarArgs("case_when", {scalar_true, values1, scalar_true, values2}, values1);
+  // No else clause, array cond, scalar values
+  builder.str("");
+  builder << '[' << value1 << ", null, null, null]";
+  CheckVarArgs("case_when", {cond1, scalar1}, ArrayFromJSON(type, builder.str()));
+  CheckVarArgs("case_when", {cond1, value_null}, values_null);
+  builder.str("");
+  builder << '[' << value1 << ", null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, scalar1, cond2, scalar2},
+               ArrayFromJSON(type, builder.str()));
+  // No else clause, array cond, array values
+  builder.str("");
+  builder << "[null, null, null, " << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2},
+               ArrayFromJSON(type, builder.str()));
+  // Else clauses/mixed scalar and array
+  builder.str("");
+  builder << "[null, " << value1 << ',' << value1 << ',' << value2 << ']';
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, scalar1},
+               ArrayFromJSON(type, builder.str()));
+  CheckVarArgs("case_when", {cond1, values1, cond2, values2, values1},
+               ArrayFromJSON(type, builder.str()));
+}
+
+TYPED_TEST(TestCaseWhenNumeric, FixedSize) {
+  auto type = default_type_instance<TypeParam>();
+  CheckCaseWhenCases(type, "10", "42");
+}
+
+TEST(TestCaseWhen, Null) {
+  auto scalar = ScalarFromJSON(null(), "null");
+  auto array = ArrayFromJSON(null(), "[null, null, null, null]");
+  CheckVarArgs("case_when", {array}, array);
+  CheckVarArgs("case_when", {scalar, array}, array);
+  CheckVarArgs("case_when", {scalar, array, array}, array);
+}
+
+TEST(TestCaseWhen, Boolean) { CheckCaseWhenCases(boolean(), "true", "false"); }
+
+TEST(TestCaseWhen, DayTimeInterval) {
+  CheckCaseWhenCases(day_time_interval(), "[10, 2]", "[2, 5]");
+}
+
+TEST(TestCaseWhen, Decimal) {
+  for (const auto& type :
+       std::vector<std::shared_ptr<DataType>>{decimal128(3, 2), decimal256(3, 2)}) {
+    CheckCaseWhenCases(type, "\"1.23\"", "\"4.56\"");
+  }
+}
+
+TEST(TestCaseWhen, FixedSizeBinary) {
+  auto type = fixed_size_binary(3);
+  CheckCaseWhenCases(type, "\"aaa\"", "\"bbb\"");
+}
+
+TEST(TestCaseWhen, DispatchBest) {
+  auto Check = [](std::vector<ValueDescr> original_values,

Review comment:
       This doesn't seem significantly different from `CheckDispatchBest`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org