You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/03 22:19:48 UTC
[arrow-rs] branch master updated: Round instead of Truncate while casting float to decimal (#3000)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 61cf6f75c Round instead of Truncate while casting float to decimal (#3000)
61cf6f75c is described below
commit 61cf6f75c4e03ca950f750cb2fdba4adee534372
Author: Wei-Ting Kuo <wa...@gmail.com>
AuthorDate: Fri Nov 4 06:19:42 2022 +0800
Round instead of Truncate while casting float to decimal (#3000)
* add .round() before casting to integer
* add more test cases
* update test cases
* add doc
* Format
Co-authored-by: Raphael Taylor-Davies <r....@googlemail.com>
---
arrow/src/compute/kernels/cast.rs | 103 +++++++++++++++++++++++++++++---------
1 file changed, 79 insertions(+), 24 deletions(-)
diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index 4c724b640..4ad8dd99e 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -297,6 +297,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
/// * Time32 and Time64: precision lost when going to higher interval
/// * Timestamp and Date{32|64}: precision lost when going to higher interval
/// * Temporal to/from backing primitive: zero-copy with data type change
+/// * Casting from `float32/float64` to `Decimal(precision, scale)` rounds to `scale` decimal places
+/// (e.g. casting 6.4999 to Decimal(10, 1) becomes 6.5). This is a breaking change from `26.0.0`.
+/// Previously the value was truncated rather than rounded (i.e. the output was 6.4 instead)
///
/// Unsupported Casts
/// * To or from `StructArray`
@@ -353,7 +356,7 @@ where
{
let mul = 10_f64.powi(scale as i32);
- unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul) as i128)
+ unary::<T, _, Decimal128Type>(array, |v| (v.as_() * mul).round() as i128)
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
}
@@ -368,9 +371,11 @@ where
{
let mul = 10_f64.powi(scale as i32);
- unary::<T, _, Decimal256Type>(array, |v| i256::from_i128((v.as_() * mul) as i128))
- .with_precision_and_scale(precision, scale)
- .map(|a| Arc::new(a) as ArrayRef)
+ unary::<T, _, Decimal256Type>(array, |v| {
+ i256::from_i128((v.as_() * mul).round() as i128)
+ })
+ .with_precision_and_scale(precision, scale)
+ .map(|a| Arc::new(a) as ArrayRef)
}
/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
@@ -3192,8 +3197,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
- Some(1.123_456_7),
- Some(1.123_456_7),
+ Some(1.123_456_4), // round down
+ Some(1.123_456_7), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
@@ -3205,8 +3210,8 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
- Some(1123456_i128),
- Some(1123456_i128),
+ Some(1123456_i128), // round down
+ Some(1123457_i128), // round up
]
);
@@ -3216,9 +3221,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
- Some(1.123_456_789_123_4),
- Some(1.123_456_789_012_345_6),
- Some(1.123_456_789_012_345_6),
+ Some(1.123_456_489_123_4), // round down
+ Some(1.123_456_789_123_4), // round up
+ Some(1.123_456_489_012_345_6), // round down
+ Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
@@ -3230,9 +3236,10 @@ mod tests {
Some(2200000_i128),
Some(4400000_i128),
None,
- Some(1123456_i128),
- Some(1123456_i128),
- Some(1123456_i128),
+ Some(1123456_i128), // round down
+ Some(1123457_i128), // round up
+ Some(1123456_i128), // round down
+ Some(1123457_i128), // round up
]
);
}
@@ -3307,8 +3314,8 @@ mod tests {
Some(2.2),
Some(4.4),
None,
- Some(1.123_456_7),
- Some(1.123_456_7),
+ Some(1.123_456_4), // round down
+ Some(1.123_456_7), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
@@ -3320,8 +3327,8 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
- Some(i256::from_i128(1123456_i128)),
- Some(i256::from_i128(1123456_i128)),
+ Some(i256::from_i128(1123456_i128)), // round down
+ Some(i256::from_i128(1123457_i128)), // round up
]
);
@@ -3331,9 +3338,10 @@ mod tests {
Some(2.2),
Some(4.4),
None,
- Some(1.123_456_789_123_4),
- Some(1.123_456_789_012_345_6),
- Some(1.123_456_789_012_345_6),
+ Some(1.123_456_489_123_4), // round down
+ Some(1.123_456_789_123_4), // round up
+ Some(1.123_456_489_012_345_6), // round down
+ Some(1.123_456_789_012_345_6), // round up
]);
let array = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
@@ -3345,9 +3353,10 @@ mod tests {
Some(i256::from_i128(2200000_i128)),
Some(i256::from_i128(4400000_i128)),
None,
- Some(i256::from_i128(1123456_i128)),
- Some(i256::from_i128(1123456_i128)),
- Some(i256::from_i128(1123456_i128)),
+ Some(i256::from_i128(1123456_i128)), // round down
+ Some(i256::from_i128(1123457_i128)), // round up
+ Some(i256::from_i128(1123456_i128)), // round down
+ Some(i256::from_i128(1123457_i128)), // round up
]
);
}
@@ -5994,4 +6003,50 @@ mod tests {
.collect::<Vec<_>>();
assert_eq!(&out, &vec!["[0, 1, 2]", "[3, 4, 5]", "[6, 7]"]);
}
+
+ #[test]
+ #[cfg(not(feature = "force_validate"))]
+ fn test_cast_f64_to_decimal128() {
+ // to reproduce https://github.com/apache/arrow-rs/issues/2997
+
+ let decimal_type = DataType::Decimal128(18, 2);
+ let array = Float64Array::from(vec![
+ Some(0.0699999999),
+ Some(0.0659999999),
+ Some(0.0650000000),
+ Some(0.0649999999),
+ ]);
+ let array = Arc::new(array) as ArrayRef;
+ generate_cast_test_case!(
+ &array,
+ Decimal128Array,
+ &decimal_type,
+ vec![
+ Some(7_i128), // round up
+ Some(7_i128), // round up
+ Some(7_i128), // round up
+ Some(6_i128), // round down
+ ]
+ );
+
+ let decimal_type = DataType::Decimal128(18, 3);
+ let array = Float64Array::from(vec![
+ Some(0.0699999999),
+ Some(0.0659999999),
+ Some(0.0650000000),
+ Some(0.0649999999),
+ ]);
+ let array = Arc::new(array) as ArrayRef;
+ generate_cast_test_case!(
+ &array,
+ Decimal128Array,
+ &decimal_type,
+ vec![
+ Some(70_i128), // round up
+ Some(66_i128), // round up
+ Some(65_i128), // round down
+ Some(65_i128), // round up
+ ]
+ );
+ }
}