You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/14 11:25:34 UTC

[GitHub] [arrow] pitrou commented on a diff in pull request #12660: ARROW-15965: [C++][Python] Add Scalar constructor of RoundToMultipleOptions to Python

pitrou commented on code in PR #12660:
URL: https://github.com/apache/arrow/pull/12660#discussion_r850336110


##########
python/pyarrow/tests/test_compute.py:
##########
@@ -1567,8 +1568,10 @@ def test_round_to_multiple():
         assert pc.round_to_multiple(values, multiple,
                                     "half_towards_infinity") == result
 
-    with pytest.raises(pa.ArrowInvalid, match="multiple must be positive"):
-        pc.round_to_multiple(values, multiple=-2)
+    for multiple in [-2, pa.scalar(-10.4)]:
+        with pytest.raises(pa.ArrowInvalid,
+                           match="multiple must be nonnegative"):
+            pc.round_to_multiple(values, multiple=multiple)

Review Comment:
   Can you also add a test checking that a `TypeError` is raised when `multiple` is not convertible to a scalar?



##########
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc:
##########
@@ -1453,20 +1447,20 @@ struct RoundToMultiple<ArrowType, kRoundMode, enable_if_decimal<ArrowType>> {
   bool has_halfway_point;
 
   explicit RoundToMultiple(const State& state, const DataType& out_ty)
-      : ty(checked_cast<const ArrowType&>(out_ty)) {
-    const auto& options = state.options;
-    DCHECK(options.multiple);
-    DCHECK(options.multiple->is_valid);
-    DCHECK(options.multiple->type->Equals(out_ty));
-    multiple = UnboxScalar<ArrowType>::Unbox(*options.multiple);
-    half_multiple = multiple;
-    half_multiple /= 2;
-    neg_half_multiple = -half_multiple;
-    has_halfway_point = multiple.low_bits() % 2 == 0;
-  }
+      : ty(checked_cast<const ArrowType&>(out_ty)),
+        multiple(UnboxScalar<ArrowType>::Unbox(*state.options.multiple)),
+        half_multiple(multiple / 2),
+        neg_half_multiple(-half_multiple),
+        has_halfway_point(multiple.low_bits() % 2 == 0) {}
 
   template <typename T = ArrowType, typename CType = typename TypeTraits<T>::CType>
   enable_if_decimal_value<CType> Call(KernelContext* ctx, CType arg, Status* st) const {
+    // Return zeros if `multiple` option is zero.

Review Comment:
   I'm not sure why returning zeros is the expected result when `multiple` is zero.



##########
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc:
##########
@@ -1200,70 +1201,71 @@ template <>
 struct RoundOptionsWrapper<RoundToMultipleOptions>
     : public OptionsWrapper<RoundToMultipleOptions> {
   using OptionsType = RoundToMultipleOptions;
-  using State = RoundOptionsWrapper<OptionsType>;
   using OptionsWrapper::OptionsWrapper;
 
   static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                    const KernelInitArgs& args) {
-    std::unique_ptr<State> state;
-    if (auto options = static_cast<const OptionsType*>(args.options)) {
-      state = ::arrow::internal::make_unique<State>(*options);
-    } else {
+    auto options = static_cast<const OptionsType*>(args.options);
+    if (!options) {
       return Status::Invalid(
           "Attempted to initialize KernelState from null FunctionOptions");
     }
 
-    auto options = Get(*state);
-    const auto& type = *args.inputs[0].type;
-    if (!options.multiple || !options.multiple->is_valid) {
+    const auto& multiple = options->multiple;
+    if (!multiple || !multiple->is_valid) {
       return Status::Invalid("Rounding multiple must be non-null and valid");
     }
-    if (is_floating(type.id())) {
-      switch (options.multiple->type->id()) {
-        case Type::FLOAT: {
-          if (UnboxScalar<FloatType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DOUBLE: {
-          if (UnboxScalar<DoubleType>::Unbox(*options.multiple) < 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::HALF_FLOAT:
-          return Status::NotImplemented("Half-float values are not supported");
-        default:
-          return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                                 *options.multiple->type);
-      }
+
+    // Ensure the rounding multiple option matches the kernel's output type.
+    // The output type is not available here so we use the following rule:
+    // If `multiple` is neither a floating-point nor a decimal type, then
+    // cast to float64, else cast to the kernel's input type.
+    std::shared_ptr<Scalar> resolved_multiple;
+    const auto& to_type =
+        (!is_floating(multiple->type->id()) && !is_decimal(multiple->type->id()))
+            ? float64()
+            : args.inputs[0].type;
+    bool is_casted = false;
+    if (!multiple->type->Equals(to_type)) {
+      ARROW_ASSIGN_OR_RAISE(
+          auto casted_multiple,
+          Cast(Datum(multiple), to_type, CastOptions::Safe(), ctx->exec_context()));
+      resolved_multiple = casted_multiple.scalar();
+      is_casted = true;
     } else {
-      DCHECK(is_decimal(type.id()));
-      if (!type.Equals(*options.multiple->type)) {
-        return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                               *options.multiple->type);
-      }
-      switch (options.multiple->type->id()) {
-        case Type::DECIMAL128: {
-          if (UnboxScalar<Decimal128Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        case Type::DECIMAL256: {
-          if (UnboxScalar<Decimal256Type>::Unbox(*options.multiple) <= 0) {
-            return Status::Invalid("Rounding multiple must be positive");
-          }
-          break;
-        }
-        default:
-          // This shouldn't happen
-          return Status::Invalid("Rounding multiple must be a ", type, " scalar, not ",
-                                 *options.multiple->type);
-      }
+      resolved_multiple = multiple;
     }
-    return std::move(state);
+
+    // NOTE: The positive value check can be simplified by using a comparison kernel.
+    bool is_negative = false;
+    switch (resolved_multiple->type->id()) {

Review Comment:
   An interesting way of writing this would be to use `VisitScalarInline`:
   ```c++
   bool IsNegative(const Scalar& scalar) {
     struct IsNegativeVisitor {
       bool result = false;
   
       template <typename... Ts>
       Status Visit(const NumericScalar<Ts...>& scalar) {
         result = scalar.value < 0;
         return Status::OK();
       }
       template <typename... Ts>
       Status Visit(const DecimalScalar<Ts...>& scalar) {
         result = scalar.value < 0;
         return Status::OK();
       }
       Status Visit(const Scalar& scalar) {
         return Status::OK();
       }
     };
     IsNegativeVisitor visitor{};
     std::ignore = VisitScalarInline(scalar, &visitor);
     return visitor.result;
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org