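You are viewing a plain-text version of this content. The canonical version is the original mailing-list archive page; its hyperlink (formerly attached to the word "here") was not preserved in this plain-text rendering.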
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/06 23:51:21 UTC
[arrow-datafusion] branch master updated: support decimal scalar value (#1394)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7f24a79 support decimal scalar value (#1394)
7f24a79 is described below
commit 7f24a79f49565468f3a5afbd617203d1d0c9e950
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Dec 7 07:51:14 2021 +0800
support decimal scalar value (#1394)
---
datafusion/src/scalar.rs | 283 +++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 263 insertions(+), 20 deletions(-)
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index c06ccb1..e9eafe1 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -43,6 +43,8 @@ pub enum ScalarValue {
Float32(Option<f32>),
/// 64bit float
Float64(Option<f64>),
+ /// 128bit decimal, using the i128 to represent the decimal
+ Decimal128(Option<i128>, usize, usize),
/// signed 8bit int
Int8(Option<i8>),
/// signed 16bit int
@@ -100,6 +102,10 @@ impl PartialEq for ScalarValue {
// any newly added enum variant will require editing this list
// or else face a compile error
match (self, other) {
+ (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
+ v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+ }
+ (Decimal128(_, _, _), _) => false,
(Boolean(v1), Boolean(v2)) => v1.eq(v2),
(Boolean(_), _) => false,
(Float32(v1), Float32(v2)) => {
@@ -171,6 +177,15 @@ impl PartialOrd for ScalarValue {
// any newly added enum variant will require editing this list
// or else face a compile error
match (self, other) {
+ (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
+ if p1.eq(p2) && s1.eq(s2) {
+ v1.partial_cmp(v2)
+ } else {
+ // Two decimal values can be compared if they have the same precision and scale.
+ None
+ }
+ }
+ (Decimal128(_, _, _), _) => None,
(Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
(Boolean(_), _) => None,
(Float32(v1), Float32(v2)) => {
@@ -253,6 +268,11 @@ impl std::hash::Hash for ScalarValue {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
use ScalarValue::*;
match self {
+ Decimal128(v, p, s) => {
+ v.hash(state);
+ p.hash(state);
+ s.hash(state)
+ }
Boolean(v) => v.hash(state),
Float32(v) => {
let v = v.map(OrderedFloat);
@@ -453,6 +473,22 @@ macro_rules! eq_array_primitive {
}
impl ScalarValue {
+ /// Create a decimal Scalar from value/precision and scale.
+ pub fn try_new_decimal128(
+ value: i128,
+ precision: usize,
+ scale: usize,
+ ) -> Result<Self> {
+ // make sure the precision and scale is valid
+ // TODO const the max precision and min scale
+ if precision <= 38 && scale <= precision {
+ return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
+ }
+ return Err(DataFusionError::Internal(format!(
+ "Can not new a decimal type ScalarValue for precision {} and scale {}",
+ precision, scale
+ )));
+ }
/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
@@ -465,6 +501,9 @@ impl ScalarValue {
ScalarValue::Int16(_) => DataType::Int16,
ScalarValue::Int32(_) => DataType::Int32,
ScalarValue::Int64(_) => DataType::Int64,
+ ScalarValue::Decimal128(_, precision, scale) => {
+ DataType::Decimal(*precision, *scale)
+ }
ScalarValue::TimestampSecond(_) => {
DataType::Timestamp(TimeUnit::Second, None)
}
@@ -513,6 +552,9 @@ impl ScalarValue {
ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)),
ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)),
ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)),
+ ScalarValue::Decimal128(Some(v), precision, scale) => {
+ ScalarValue::Decimal128(Some(-v), *precision, *scale)
+ }
_ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self),
}
}
@@ -541,6 +583,7 @@ impl ScalarValue {
| ScalarValue::TimestampMicrosecond(None)
| ScalarValue::TimestampNanosecond(None)
| ScalarValue::Struct(None, _)
+ | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null.
)
}
@@ -590,7 +633,7 @@ impl ScalarValue {
None => {
return Err(DataFusionError::Internal(
"Empty iterator passed to ScalarValue::iter_to_array".to_string(),
- ))
+ ));
}
Some(sv) => sv.get_datatype(),
};
@@ -706,6 +749,11 @@ impl ScalarValue {
}
let array: ArrayRef = match &data_type {
+ DataType::Decimal(precision, scale) => {
+ let decimal_array =
+ ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
+ Arc::new(decimal_array)
+ }
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
@@ -831,13 +879,40 @@ impl ScalarValue {
"Unsupported creation of {:?} array from ScalarValue {:?}",
data_type,
scalars.peek()
- )))
+ )));
}
};
Ok(array)
}
+ fn iter_to_decimal_array(
+ scalars: impl IntoIterator<Item = ScalarValue>,
+ precision: &usize,
+ scale: &usize,
+ ) -> Result<DecimalArray> {
+ // collect the value as Option<i128>
+ let array = scalars
+ .into_iter()
+ .map(|element: ScalarValue| match element {
+ ScalarValue::Decimal128(v1, _, _) => v1,
+ _ => unreachable!(),
+ })
+ .collect::<Vec<Option<i128>>>();
+
+ // build the decimal array using the Decimal Builder
+ let mut builder = DecimalBuilder::new(array.len(), *precision, *scale);
+ array.iter().for_each(|element| match element {
+ None => {
+ builder.append_null().unwrap();
+ }
+ Some(v) => {
+ builder.append_value(*v).unwrap();
+ }
+ });
+ Ok(builder.finish())
+ }
+
fn iter_to_array_list(
scalars: impl IntoIterator<Item = ScalarValue>,
data_type: &DataType,
@@ -905,9 +980,35 @@ impl ScalarValue {
Ok(list_array)
}
+ fn build_decimal_array(
+ value: &Option<i128>,
+ precision: &usize,
+ scale: &usize,
+ size: usize,
+ ) -> DecimalArray {
+ let mut builder = DecimalBuilder::new(size, *precision, *scale);
+ match value {
+ None => {
+ for _i in 0..size {
+ builder.append_null().unwrap();
+ }
+ }
+ Some(v) => {
+ let v = *v;
+ for _i in 0..size {
+ builder.append_value(v).unwrap();
+ }
+ }
+ };
+ builder.finish()
+ }
+
/// Converts a scalar value into an array of `size` rows.
pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
match self {
+ ScalarValue::Decimal128(e, precision, scale) => {
+ Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size))
+ }
ScalarValue::Boolean(e) => {
Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
}
@@ -1061,12 +1162,15 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
None => {
- let field_values: Vec<_> = fields.iter().map(|field| {
+ let field_values: Vec<_> = fields
+ .iter()
+ .map(|field| {
let none_field = Self::try_from(field.data_type()).expect(
"Failed to construct null ScalarValue from Struct field type"
);
(field.clone(), none_field.to_array_of_size(size))
- }).collect();
+ })
+ .collect();
Arc::new(StructArray::from(field_values))
}
@@ -1074,6 +1178,20 @@ impl ScalarValue {
}
}
+ fn get_decimal_value_from_array(
+ array: &ArrayRef,
+ index: usize,
+ precision: &usize,
+ scale: &usize,
+ ) -> ScalarValue {
+ let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ if array.is_null(index) {
+ ScalarValue::Decimal128(None, *precision, *scale)
+ } else {
+ ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale)
+ }
+ }
+
/// Converts a value in `array` at `index` into a ScalarValue
pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
// handle NULL value
@@ -1082,6 +1200,9 @@ impl ScalarValue {
}
Ok(match array.data_type() {
+ DataType::Decimal(precision, scale) => {
+ ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
+ }
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
DataType::Float32 => typed_cast!(array, index, Float32Array, Float32),
@@ -1162,7 +1283,7 @@ impl ScalarValue {
return Err(DataFusionError::Internal(format!(
"Index type not supported while creating scalar from dictionary: {}",
array.data_type(),
- )))
+ )));
}
};
@@ -1194,11 +1315,28 @@ impl ScalarValue {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from array of type \"{:?}\"",
other
- )))
+ )));
}
})
}
+ fn eq_array_decimal(
+ array: &ArrayRef,
+ index: usize,
+ value: &Option<i128>,
+ precision: usize,
+ scale: usize,
+ ) -> bool {
+ let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ if array.precision() != precision || array.scale() != scale {
+ return false;
+ }
+ match value {
+ None => array.is_null(index),
+ Some(v) => !array.is_null(index) && array.value(index) == *v,
+ }
+ }
+
/// Compares a single row of array @ index for equality with self,
/// in an optimized fashion.
///
@@ -1222,6 +1360,9 @@ impl ScalarValue {
}
match self {
+ ScalarValue::Decimal128(v, precision, scale) => {
+ ScalarValue::eq_array_decimal(array, index, v, *precision, *scale)
+ }
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)
}
@@ -1458,6 +1599,9 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::UInt16 => ScalarValue::UInt16(None),
DataType::UInt32 => ScalarValue::UInt32(None),
DataType::UInt64 => ScalarValue::UInt64(None),
+ DataType::Decimal(precision, scale) => {
+ ScalarValue::Decimal128(None, *precision, *scale)
+ }
DataType::Utf8 => ScalarValue::Utf8(None),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
DataType::Date32 => ScalarValue::Date32(None),
@@ -1487,7 +1631,7 @@ impl TryFrom<&DataType> for ScalarValue {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from data_type \"{:?}\"",
datatype
- )))
+ )));
}
})
}
@@ -1505,6 +1649,9 @@ macro_rules! format_option {
impl fmt::Display for ScalarValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
+ ScalarValue::Decimal128(v, p, s) => {
+ write!(f, "{}", format!("{:?},{:?},{:?}", v, p, s))?;
+ }
ScalarValue::Boolean(e) => format_option!(f, e)?,
ScalarValue::Float32(e) => format_option!(f, e)?,
ScalarValue::Float64(e) => format_option!(f, e)?,
@@ -1579,6 +1726,7 @@ impl fmt::Display for ScalarValue {
impl fmt::Debug for ScalarValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
+ ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self),
ScalarValue::Boolean(_) => write!(f, "Boolean({})", self),
ScalarValue::Float32(_) => write!(f, "Float32({})", self),
ScalarValue::Float64(_) => write!(f, "Float64({})", self),
@@ -1677,6 +1825,101 @@ mod tests {
use super::*;
#[test]
+ fn scalar_decimal_test() {
+ let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1);
+ assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype());
+ assert!(!decimal_value.is_null());
+ let neg_decimal_value = decimal_value.arithmetic_negate();
+ match neg_decimal_value {
+ ScalarValue::Decimal128(v, _, _) => {
+ assert_eq!(-123, v.unwrap());
+ }
+ _ => {
+ unreachable!();
+ }
+ }
+
+ // decimal scalar to array
+ let array = decimal_value.to_array();
+ let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ assert_eq!(1, array.len());
+ assert_eq!(DataType::Decimal(10, 1), array.data_type().clone());
+ assert_eq!(123i128, array.value(0));
+
+ // decimal scalar to array with size
+ let array = decimal_value.to_array_of_size(10);
+ let array_decimal = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ assert_eq!(10, array.len());
+ assert_eq!(DataType::Decimal(10, 1), array.data_type().clone());
+ assert_eq!(123i128, array_decimal.value(0));
+ assert_eq!(123i128, array_decimal.value(9));
+ // test eq array
+ assert!(decimal_value.eq_array(&array, 1));
+ assert!(decimal_value.eq_array(&array, 5));
+ // test try from array
+ assert_eq!(
+ decimal_value,
+ ScalarValue::try_from_array(&array, 5).unwrap()
+ );
+
+ assert_eq!(
+ decimal_value,
+ ScalarValue::try_new_decimal128(123, 10, 1).unwrap()
+ );
+
+ // test compare
+ let left = ScalarValue::Decimal128(Some(123), 10, 2);
+ let right = ScalarValue::Decimal128(Some(124), 10, 2);
+ assert!(!left.eq(&right));
+ let result = left < right;
+ assert!(result);
+ let result = left <= right;
+ assert!(result);
+ let right = ScalarValue::Decimal128(Some(124), 10, 3);
+ // make sure that two decimals with diff datatype can't be compared.
+ let result = left.partial_cmp(&right);
+ assert_eq!(None, result);
+
+ let decimal_vec = vec![
+ ScalarValue::Decimal128(Some(1), 10, 2),
+ ScalarValue::Decimal128(Some(2), 10, 2),
+ ScalarValue::Decimal128(Some(3), 10, 2),
+ ];
+ // convert the vec to decimal array and check the result
+ let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap();
+ assert_eq!(3, array.len());
+ assert_eq!(DataType::Decimal(10, 2), array.data_type().clone());
+
+ let decimal_vec = vec![
+ ScalarValue::Decimal128(Some(1), 10, 2),
+ ScalarValue::Decimal128(Some(2), 10, 2),
+ ScalarValue::Decimal128(Some(3), 10, 2),
+ ScalarValue::Decimal128(None, 10, 2),
+ ];
+ let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap();
+ assert_eq!(4, array.len());
+ assert_eq!(DataType::Decimal(10, 2), array.data_type().clone());
+
+ assert!(ScalarValue::try_new_decimal128(1, 10, 2)
+ .unwrap()
+ .eq_array(&array, 0));
+ assert!(ScalarValue::try_new_decimal128(2, 10, 2)
+ .unwrap()
+ .eq_array(&array, 1));
+ assert!(ScalarValue::try_new_decimal128(3, 10, 2)
+ .unwrap()
+ .eq_array(&array, 2));
+ assert_eq!(
+ ScalarValue::Decimal128(None, 10, 2),
+ ScalarValue::try_from_array(&array, 3).unwrap()
+ );
+ assert_eq!(
+ ScalarValue::Decimal128(None, 10, 2),
+ ScalarValue::try_from_array(&array, 4).unwrap()
+ );
+ }
+
+ #[test]
fn scalar_value_to_array_u64() {
let value = ScalarValue::UInt64(Some(13u64));
let array = value.to_array();
@@ -1909,7 +2152,7 @@ mod tests {
// Since ScalarValues are used in a non trivial number of places,
// making it larger means significant more memory consumption
// per distinct value.
- assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
+ assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
}
#[test]
@@ -2088,11 +2331,11 @@ mod tests {
assert_eq!(
List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)),
Some(Ordering::Equal)
);
@@ -2100,11 +2343,11 @@ mod tests {
assert_eq!(
List(
Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)),
Some(Ordering::Greater)
);
@@ -2112,11 +2355,11 @@ mod tests {
assert_eq!(
List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)),
Some(Ordering::Less)
);
@@ -2125,11 +2368,11 @@ mod tests {
assert_eq!(
List(
Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])),
- Box::new(DataType::Int64)
+ Box::new(DataType::Int64),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
- Box::new(DataType::Int32)
+ Box::new(DataType::Int32),
)),
None
);
@@ -2137,11 +2380,11 @@ mod tests {
assert_eq!(
ScalarValue::from(vec![
("A", ScalarValue::from(1.0)),
- ("B", ScalarValue::from("Z"))
+ ("B", ScalarValue::from("Z")),
])
.partial_cmp(&ScalarValue::from(vec![
("A", ScalarValue::from(2.0)),
- ("B", ScalarValue::from("A"))
+ ("B", ScalarValue::from("A")),
])),
Some(Ordering::Less)
);
@@ -2150,11 +2393,11 @@ mod tests {
assert_eq!(
ScalarValue::from(vec![
("A", ScalarValue::from(1.0)),
- ("B", ScalarValue::from("Z"))
+ ("B", ScalarValue::from("Z")),
])
.partial_cmp(&ScalarValue::from(vec![
("a", ScalarValue::from(2.0)),
- ("b", ScalarValue::from("A"))
+ ("b", ScalarValue::from("A")),
])),
None
);