You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/11/23 18:33:52 UTC
[arrow-rs] branch master updated: Support decimal negative scale (#3152)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 78ab0ef3f Support decimal negative scale (#3152)
78ab0ef3f is described below
commit 78ab0ef3f6f422fd4b79a29504f0274220aaf74b
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Nov 23 10:33:47 2022 -0800
Support decimal negative scale (#3152)
* Support decimal negative scale
* Fix casting from numeric to negative scale decimal
* Fix clippy
---
arrow-array/src/array/primitive_array.rs | 15 +++--
arrow-array/src/types.rs | 33 +++++-----
arrow-cast/src/cast.rs | 102 ++++++++++++++++++++++++++-----
arrow-csv/src/reader.rs | 6 +-
arrow-data/src/decimal.rs | 6 +-
arrow-schema/src/datatype.rs | 4 +-
arrow-select/src/take.rs | 6 +-
arrow/benches/cast_kernels.rs | 4 +-
arrow/src/datatypes/ffi.rs | 4 +-
arrow/tests/array_transform.rs | 2 +-
10 files changed, 131 insertions(+), 51 deletions(-)
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index f34c899e2..bd68b9698 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -1003,7 +1003,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
pub fn with_precision_and_scale(
self,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<Self, ArrowError>
where
Self: Sized,
@@ -1024,7 +1024,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
fn validate_precision_scale(
&self,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<(), ArrowError> {
if precision == 0 {
return Err(ArrowError::InvalidArgumentError(format!(
@@ -1046,7 +1046,14 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
T::MAX_SCALE
)));
}
- if scale > precision {
+ if scale < -T::MAX_SCALE {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "scale {} is smaller than min {}",
+ scale,
+ -Decimal128Type::MAX_SCALE
+ )));
+ }
+ if scale > 0 && scale as u8 > precision {
return Err(ArrowError::InvalidArgumentError(format!(
"scale {} is greater than precision {}",
scale, precision
@@ -1102,7 +1109,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
}
/// Returns the decimal scale of this array
- pub fn scale(&self) -> u8 {
+ pub fn scale(&self) -> i8 {
match T::BYTE_LENGTH {
16 => {
if let DataType::Decimal128(_, s) = self.data().data_type() {
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index dd4d1ba42..40d262e8e 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -491,15 +491,15 @@ pub trait DecimalType:
{
const BYTE_LENGTH: usize;
const MAX_PRECISION: u8;
- const MAX_SCALE: u8;
- const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType;
+ const MAX_SCALE: i8;
+ const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
const DEFAULT_TYPE: DataType;
/// "Decimal128" or "Decimal256", for use in error messages
const PREFIX: &'static str;
/// Formats the decimal value with the provided precision and scale
- fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String;
+ fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String;
/// Validates that `value` contains no more than `precision` decimal digits
fn validate_decimal_precision(
@@ -515,14 +515,14 @@ pub struct Decimal128Type {}
impl DecimalType for Decimal128Type {
const BYTE_LENGTH: usize = 16;
const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
- const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE;
- const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128;
+ const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
+ const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal128";
- fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
- format_decimal_str(&value.to_string(), precision as usize, scale as usize)
+ fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
+ format_decimal_str(&value.to_string(), precision as usize, scale)
}
fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
@@ -543,14 +543,14 @@ pub struct Decimal256Type {}
impl DecimalType for Decimal256Type {
const BYTE_LENGTH: usize = 32;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
- const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE;
- const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256;
+ const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
+ const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const PREFIX: &'static str = "Decimal256";
- fn format_decimal(value: Self::Native, precision: u8, scale: u8) -> String {
- format_decimal_str(&value.to_string(), precision as usize, scale as usize)
+ fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String {
+ format_decimal_str(&value.to_string(), precision as usize, scale)
}
fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
@@ -564,7 +564,7 @@ impl ArrowPrimitiveType for Decimal256Type {
const DATA_TYPE: DataType = <Self as DecimalType>::DEFAULT_TYPE;
}
-fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String {
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
Some(stripped) => ("-", stripped),
None => ("", value_str),
@@ -574,13 +574,16 @@ fn format_decimal_str(value_str: &str, precision: usize, scale: usize) -> String
if scale == 0 {
value_str.to_string()
- } else if rest.len() > scale {
+ } else if scale < 0 {
+ let padding = value_str.len() + scale.unsigned_abs() as usize;
+ format!("{:0<width$}", value_str, width = padding)
+ } else if rest.len() > scale as usize {
// Decimal separator is in the middle of the string
- let (whole, decimal) = value_str.split_at(value_str.len() - scale);
+ let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
format!("{}.{}", whole, decimal)
} else {
// String has to be padded
- format!("{}0.{:0>width$}", sign, rest, width = scale)
+ format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
}
}
diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 3bf97cf7a..61be2171b 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -319,7 +319,7 @@ fn cast_integer_to_decimal<
>(
array: &PrimitiveArray<T>,
precision: u8,
- scale: u8,
+ scale: i8,
base: M,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
@@ -327,7 +327,7 @@ where
<T as ArrowPrimitiveType>::Native: AsPrimitive<M>,
M: ArrowNativeTypeOp,
{
- let mul: M = base.pow_checked(scale as u32).map_err(|_| {
+ let mul_or_div: M = base.pow_checked(scale.unsigned_abs() as u32).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). The scale causes overflow.",
D::PREFIX,
@@ -336,14 +336,26 @@ where
))
})?;
- if cast_options.safe {
+ if scale < 0 {
+ if cast_options.safe {
+ array
+ .unary_opt::<_, D>(|v| v.as_().div_checked(mul_or_div).ok())
+ .with_precision_and_scale(precision, scale)
+ .map(|a| Arc::new(a) as ArrayRef)
+ } else {
+ array
+ .try_unary::<_, D, _>(|v| v.as_().div_checked(mul_or_div))
+ .and_then(|a| a.with_precision_and_scale(precision, scale))
+ .map(|a| Arc::new(a) as ArrayRef)
+ }
+ } else if cast_options.safe {
array
- .unary_opt::<_, D>(|v| v.as_().mul_checked(mul).ok())
+ .unary_opt::<_, D>(|v| v.as_().mul_checked(mul_or_div).ok())
.with_precision_and_scale(precision, scale)
.map(|a| Arc::new(a) as ArrayRef)
} else {
array
- .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul))
+ .try_unary::<_, D, _>(|v| v.as_().mul_checked(mul_or_div))
.and_then(|a| a.with_precision_and_scale(precision, scale))
.map(|a| Arc::new(a) as ArrayRef)
}
@@ -352,7 +364,7 @@ where
fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
- scale: u8,
+ scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
@@ -391,7 +403,7 @@ where
fn cast_floating_point_to_decimal256<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
precision: u8,
- scale: u8,
+ scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
@@ -437,7 +449,7 @@ fn cast_reinterpret_arrays<
fn cast_decimal_to_integer<D, T>(
array: &ArrayRef,
base: D::Native,
- scale: u8,
+ scale: i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError>
where
@@ -1921,9 +1933,9 @@ fn cast_decimal_to_decimal_with_option<
const BYTE_WIDTH2: usize,
>(
array: &ArrayRef,
- input_scale: &u8,
+ input_scale: &i8,
output_precision: &u8,
- output_scale: &u8,
+ output_scale: &i8,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
if cast_options.safe {
@@ -1947,9 +1959,9 @@ fn cast_decimal_to_decimal_with_option<
/// the array values when cast failures happen.
fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
- input_scale: &u8,
+ input_scale: &i8,
output_precision: &u8,
- output_scale: &u8,
+ output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
@@ -2062,9 +2074,9 @@ fn cast_decimal_to_decimal_safe<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usi
/// cast failure happens.
fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
array: &ArrayRef,
- input_scale: &u8,
+ input_scale: &i8,
output_precision: &u8,
- output_scale: &u8,
+ output_scale: &i8,
) -> Result<ArrayRef, ArrowError> {
if input_scale > output_scale {
// For example, input_scale is 4 and output_scale is 3;
@@ -3540,7 +3552,7 @@ mod tests {
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<Decimal128Array, ArrowError> {
array
.into_iter()
@@ -3551,7 +3563,7 @@ mod tests {
fn create_decimal256_array(
array: Vec<Option<i256>>,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<Decimal256Array, ArrowError> {
array
.into_iter()
@@ -7206,4 +7218,62 @@ mod tests {
err
);
}
+
+ #[test]
+ fn test_cast_decimal128_to_decimal128_negative_scale() {
+ let input_type = DataType::Decimal128(20, 0);
+ let output_type = DataType::Decimal128(20, -1);
+ assert!(can_cast_types(&input_type, &output_type));
+ let array = vec![Some(1123456), Some(2123456), Some(3123456), None];
+ let input_decimal_array = create_decimal_array(array, 20, 0).unwrap();
+ let array = Arc::new(input_decimal_array) as ArrayRef;
+ generate_cast_test_case!(
+ &array,
+ Decimal128Array,
+ &output_type,
+ vec![
+ Some(112345_i128),
+ Some(212345_i128),
+ Some(312345_i128),
+ None
+ ]
+ );
+
+ let casted_array = cast(&array, &output_type).unwrap();
+ let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+ assert_eq!("1123450", decimal_arr.value_as_string(0));
+ assert_eq!("2123450", decimal_arr.value_as_string(1));
+ assert_eq!("3123450", decimal_arr.value_as_string(2));
+ }
+
+ #[test]
+ fn test_cast_numeric_to_decimal128_negative() {
+ let decimal_type = DataType::Decimal128(38, -1);
+ let array = Arc::new(Int32Array::from(vec![
+ Some(1123456),
+ Some(2123456),
+ Some(3123456),
+ ])) as ArrayRef;
+
+ let casted_array = cast(&array, &decimal_type).unwrap();
+ let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+ assert_eq!("1123450", decimal_arr.value_as_string(0));
+ assert_eq!("2123450", decimal_arr.value_as_string(1));
+ assert_eq!("3123450", decimal_arr.value_as_string(2));
+
+ let array = Arc::new(Float32Array::from(vec![
+ Some(1123.456),
+ Some(2123.456),
+ Some(3123.456),
+ ])) as ArrayRef;
+
+ let casted_array = cast(&array, &decimal_type).unwrap();
+ let decimal_arr = as_primitive_array::<Decimal128Type>(&casted_array);
+
+ assert_eq!("1120", decimal_arr.value_as_string(0));
+ assert_eq!("2120", decimal_arr.value_as_string(1));
+ assert_eq!("3120", decimal_arr.value_as_string(2));
+ }
}
diff --git a/arrow-csv/src/reader.rs b/arrow-csv/src/reader.rs
index 4200e9329..6432fb1b8 100644
--- a/arrow-csv/src/reader.rs
+++ b/arrow-csv/src/reader.rs
@@ -721,7 +721,7 @@ fn build_decimal_array(
rows: &[StringRecord],
col_idx: usize,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
for row in rows {
@@ -762,13 +762,13 @@ fn build_decimal_array(
fn parse_decimal_with_parameter(
s: &str,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Result<i128, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
let mut base = 1;
- let scale_usize = usize::from(scale);
+ let scale_usize = usize::from(scale as u8);
// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs
index a6a087749..7011c4085 100644
--- a/arrow-data/src/decimal.rs
+++ b/arrow-data/src/decimal.rs
@@ -728,17 +728,17 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const DECIMAL128_MAX_PRECISION: u8 = 38;
/// The maximum scale for [arrow_schema::DataType::Decimal128] values
-pub const DECIMAL128_MAX_SCALE: u8 = 38;
+pub const DECIMAL128_MAX_SCALE: i8 = 38;
/// The maximum precision for [arrow_schema::DataType::Decimal256] values
pub const DECIMAL256_MAX_PRECISION: u8 = 76;
/// The maximum scale for [arrow_schema::DataType::Decimal256] values
-pub const DECIMAL256_MAX_SCALE: u8 = 76;
+pub const DECIMAL256_MAX_SCALE: i8 = 76;
/// The default scale for [arrow_schema::DataType::Decimal128] and
/// [arrow_schema::DataType::Decimal256] values
-pub const DECIMAL_DEFAULT_SCALE: u8 = 10;
+pub const DECIMAL_DEFAULT_SCALE: i8 = 10;
/// Validates that the specified `i128` value can be properly
/// interpreted as a Decimal number with precision `precision`
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index cf85902e4..f74e2a24b 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -190,14 +190,14 @@ pub enum DataType {
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
- Decimal128(u8, u8),
+ Decimal128(u8, i8),
/// Exact 256-bit width decimal value with precision and scale
///
/// * precision is the total number of digits
/// * scale is the number of digits past the decimal
///
/// For example the number 123.45 has precision 5 and scale 2.
- Decimal256(u8, u8),
+ Decimal256(u8, i8),
/// A Map is a logical nested type that is represented as
///
/// `List<entries: Struct<key: K, value: V>>`
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index d498ae487..857b6e323 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -914,7 +914,7 @@ mod tests {
options: Option<TakeOptions>,
expected_data: Vec<Option<i128>>,
precision: &u8,
- scale: &u8,
+ scale: &i8,
) -> Result<(), ArrowError> {
let output = data
.into_iter()
@@ -1032,7 +1032,7 @@ mod tests {
fn test_take_decimal128_non_null_indices() {
let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
let precision: u8 = 10;
- let scale: u8 = 5;
+ let scale: i8 = 5;
test_take_decimal_arrays(
vec![None, Some(3), Some(5), Some(2), Some(3), None],
&index,
@@ -1048,7 +1048,7 @@ mod tests {
fn test_take_decimal128() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
let precision: u8 = 10;
- let scale: u8 = 5;
+ let scale: i8 = 5;
test_take_decimal_arrays(
vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
&index,
diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs
index e93c78608..7ef4d1d7e 100644
--- a/arrow/benches/cast_kernels.rs
+++ b/arrow/benches/cast_kernels.rs
@@ -84,7 +84,7 @@ fn build_utf8_date_time_array(size: usize, with_nulls: bool) -> ArrayRef {
Arc::new(builder.finish())
}
-fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
+fn build_decimal128_array(size: usize, precision: u8, scale: i8) -> ArrayRef {
let mut rng = seedable_rng();
let mut builder = Decimal128Builder::with_capacity(size);
@@ -99,7 +99,7 @@ fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
)
}
-fn build_decimal256_array(size: usize, precision: u8, scale: u8) -> ArrayRef {
+fn build_decimal256_array(size: usize, precision: u8, scale: i8) -> ArrayRef {
let mut rng = seedable_rng();
let mut builder = Decimal256Builder::with_capacity(size);
let mut bytes = [0; 32];
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index ef303dfdd..41addf24f 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -103,7 +103,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
"The decimal type requires an integer precision".to_string(),
)
})?;
- let parsed_scale = scale.parse::<u8>().map_err(|_| {
+ let parsed_scale = scale.parse::<i8>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer scale".to_string(),
)
@@ -119,7 +119,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
"The decimal type requires an integer precision".to_string(),
)
})?;
- let parsed_scale = scale.parse::<u8>().map_err(|_| {
+ let parsed_scale = scale.parse::<i8>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer scale".to_string(),
)
diff --git a/arrow/tests/array_transform.rs b/arrow/tests/array_transform.rs
index 42f9ab277..3c08a592d 100644
--- a/arrow/tests/array_transform.rs
+++ b/arrow/tests/array_transform.rs
@@ -31,7 +31,7 @@ use std::sync::Arc;
fn create_decimal_array(
array: Vec<Option<i128>>,
precision: u8,
- scale: u8,
+ scale: i8,
) -> Decimal128Array {
array
.into_iter()