You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/07/01 14:37:45 UTC

[GitHub] [arrow] bkietz commented on a change in pull request #10557: ARROW-13064: [C++] Implement select ('case when') function for fixed-width types

bkietz commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r662348413



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(out_values, offset, length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length,
+                                out_values, offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+                         const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(values);
+    std::fill(out_values + offset, out_values + offset + length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+                        const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType* in_values = array.GetValues<CType>(1);
+    std::copy(in_values + offset, in_values + offset + length, out_values + offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+    // Scalar may have null value buffer
+    if (!scalar.value) return;
+    DCHECK_EQ(scalar.value->size(), width);
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, scalar.value->data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const ScalarType&>(values);
+    const auto value = scalar.value.ToBytes();
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, value.data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values,
+                const int64_t offset, const int64_t length) {
+  using Copier = CopyFixedWidth<Type>;
+  if (values.is_scalar()) {
+    const auto& scalar = *values.scalar();
+    if (out_valid) {
+      BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid);
+    }
+    Copier::CopyScalar(scalar, out_values, offset, length);
+  } else {
+    const ArrayData& array = *values.array();
+    if (out_valid) {
+      if (array.MayHaveNulls()) {
+        arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset,
+                                    length, out_valid, offset);
+      } else {
+        BitUtil::SetBitsTo(out_valid, offset, length, true);
+      }
+    }
+    Copier::CopyArray(array, out_values, offset, length);
+  }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+  using ScalarFunction::ScalarFunction;
+
+  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+    RETURN_NOT_OK(CheckArity(*values));
+    std::vector<ValueDescr> value_types;
+    for (size_t i = 0; i < values->size() - 1; i += 2) {
+      ValueDescr* cond = &(*values)[i];
+      if (cond->type->id() == Type::NA) {
+        cond->type = boolean();
+      }
+      if (cond->type->id() != Type::BOOL) {
+        return Status::Invalid("Condition arguments must be boolean, but argument ", i,
+                               " was ", cond->type->ToString());
+      }
+      value_types.push_back((*values)[i + 1]);
+    }
+    if (values->size() % 2 != 0) {
+      // Have an ELSE clause
+      value_types.push_back(values->back());
+    }
+    EnsureDictionaryDecoded(&value_types);
+    if (auto type = CommonNumeric(value_types)) {
+      ReplaceTypes(type, &value_types);
+    }
+
+    const DataType& common_values_type = *value_types.front().type;
+    auto next_type = value_types.cbegin();
+    for (size_t i = 0; i < values->size(); i += 2) {
+      if (!common_values_type.Equals(next_type->type)) {

Review comment:
       I think the masking 'when' and 'first_true_in' are definitely independently useful




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org