Posted to github@arrow.apache.org by "liukun4515 (via GitHub)" <gi...@apache.org> on 2023/04/15 09:01:22 UTC

[GitHub] [arrow-rs] liukun4515 commented on a diff in pull request #4088: optimize cast for same decimal type and same scale

liukun4515 commented on code in PR #4088:
URL: https://github.com/apache/arrow-rs/pull/4088#discussion_r1167447570


##########
arrow-cast/src/cast.rs:
##########
@@ -2081,29 +2081,102 @@ impl DecimalCast for i256 {
     }
 }
 
-fn cast_decimal_to_decimal<I, O>(
-    array: &PrimitiveArray<I>,
-    input_scale: i8,
+fn cast_decimal_to_decimal_error<O>(
     output_precision: u8,
     output_scale: i8,
-    cast_options: &CastOptions,
-) -> Result<ArrayRef, ArrowError>
+) -> impl Fn(<O as ArrowPrimitiveType>::Native) -> ArrowError
 where
-    I: DecimalType,
     O: DecimalType,
-    I::Native: DecimalCast + ArrowNativeTypeOp,
     O::Native: DecimalCast + ArrowNativeTypeOp,
 {
-    let error = |x| {
+    move |x: O::Native| {
         ArrowError::CastError(format!(
             "Cannot cast to {}({}, {}). Overflowing on {:?}",
             O::PREFIX,
             output_precision,
             output_scale,
             x
         ))
+    }
+}
+
+fn cast_decimal_to_decimal_same_type<T>(
+    array: &PrimitiveArray<T>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError>
+where
+    T: DecimalType,
+    T::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let error = cast_decimal_to_decimal_error::<T>(output_precision, output_scale);
+
+    let array: PrimitiveArray<T> = if input_scale == output_scale {
+        // the scale doesn't change, so the native values don't need to be changed
+        array.clone()
+    } else if input_scale > output_scale {
+        let div = T::Native::from_decimal(10_i128)
+            .unwrap()
+            .pow_checked((input_scale - output_scale) as u32)?;
+
+        let half = div.div_wrapping(T::Native::from_usize(2).unwrap());
+        let half_neg = half.neg_wrapping();
+
+        let f = |x: T::Native| {
+            // div is >= 10 and so this cannot overflow
+            let d = x.div_wrapping(div);
+            let r = x.mod_wrapping(div);
+
+            // Round result
+            let adjusted = match x >= T::Native::ZERO {
+                true if r >= half => d.add_wrapping(T::Native::ONE),
+                false if r <= half_neg => d.sub_wrapping(T::Native::ONE),
+                _ => d,
+            };
+            Some(adjusted)
+        };
+
+        match cast_options.safe {
+            true => array.unary_opt(f),
+            false => array.try_unary(|x| f(x).ok_or_else(|| error(x)))?,

Review Comment:
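   Sorry, I didn't get your point.
   The result of `f(x)` is always `Some(...)`: `f` returns `Some(adjusted)` on every path, so `ok_or_else(|| error(x))` can never produce an error here.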
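A minimal sketch of the rounding step in the `input_scale > output_scale` branch makes the point concrete. It assumes `i128` in place of the generic `T::Native`, and it uses the std `wrapping_*`/`checked_pow` methods in place of the `ArrowNativeTypeOp` ones. `make_rescale` is a hypothetical helper, not an arrow-rs API. The only fallible step is computing `10^delta`, and that happens once, before any per-value work, as `pow_checked(...)?` does in the diff. The per-value closure returns `Some` for every input, including `i128::MIN` and `i128::MAX`.

```rust
/// Hypothetical helper (not an arrow-rs API): returns a closure that lowers the
/// scale of an i128 decimal by `delta` digits, rounding half away from zero.
/// Returns `None` only if `delta` is 0 (not the scale-down branch) or if
/// `10^delta` does not fit in i128.
fn make_rescale(delta: u32) -> Option<impl Fn(i128) -> Option<i128>> {
    if delta == 0 {
        // Mirrors `input_scale > output_scale`: the divisor must be at least 10.
        return None;
    }
    // The only fallible step, done once up front (like `pow_checked(..)?`).
    let div = 10_i128.checked_pow(delta)?;
    let half = div / 2;
    let half_neg = -half;

    Some(move |x: i128| {
        // div >= 10, so neither the division nor the remainder can overflow.
        let d = x.wrapping_div(div);
        let r = x.wrapping_rem(div);

        // Round half away from zero; |d| <= i128::MAX / 10, so +/- 1 cannot overflow.
        let adjusted = match x >= 0 {
            true if r >= half => d.wrapping_add(1),
            false if r <= half_neg => d.wrapping_sub(1),
            _ => d,
        };
        // Every path reaches here: the closure never returns None.
        Some(adjusted)
    })
}

fn main() {
    let f = make_rescale(1).expect("10^1 fits in i128");
    // 123.45 (scale 2) -> 123.5 (scale 1), and its negative mirror image
    println!("{:?} {:?}", f(12_345), f(-12_345));
    // Even the extremes produce Some, so an `ok_or_else(|| error(x))` branch is never taken.
    for x in [i128::MIN, -1, 0, 1, i128::MAX] {
        assert!(f(x).is_some());
    }
}
```

Given this, the `safe == false` arm behaves like the `safe == true` arm for the rescaling itself. Any overflow error would have to come from a separate precision check rather than from `f`.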

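Separately from the rounding logic, the diff hoists the overflow-error closure into a factory, `cast_decimal_to_decimal_error`. The factory takes the output precision and scale and returns `impl Fn(O::Native) -> ArrowError`, and `move` lets the returned closure own those captured values. The sketch below shows the same pattern using std types only. `CastError`, `overflow_error`, and `scale_up_by_100` are hypothetical names, and the `"Decimal128"` string is an assumed stand-in for `O::PREFIX`.

```rust
use std::fmt::Debug;

/// Hypothetical stand-in for `ArrowError::CastError`; not the real arrow type.
#[derive(Debug, PartialEq)]
enum CastError {
    Overflow(String),
}

/// Hypothetical factory in the spirit of `cast_decimal_to_decimal_error`:
/// captures the target type's prefix, precision and scale by value (`move`)
/// and returns a closure that formats the overflow message for one value.
fn overflow_error<T: Debug>(
    prefix: &'static str,
    precision: u8,
    scale: i8,
) -> impl Fn(T) -> CastError {
    move |x: T| {
        CastError::Overflow(format!(
            "Cannot cast to {}({}, {}). Overflowing on {:?}",
            prefix, precision, scale, x
        ))
    }
}

/// Hypothetical fallible per-value path, analogous to the `try_unary` arm:
/// multiplies every value by 100 and reports the first overflow via `err`.
fn scale_up_by_100(
    values: &[i128],
    err: &impl Fn(i128) -> CastError,
) -> Result<Vec<i128>, CastError> {
    values
        .iter()
        .map(|&x| x.checked_mul(100).ok_or_else(|| err(x)))
        .collect()
}

fn main() {
    let err = overflow_error::<i128>("Decimal128", 38, 2);
    println!("{:?}", scale_up_by_100(&[1, -2], &err));
    println!("{:?}", scale_up_by_100(&[i128::MAX], &err));
}
```

The closure captures only `Copy` values, so it can be built once per array and reused for every element. The message is formatted (and allocated) only when an overflow actually occurs.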

