You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/06 23:51:21 UTC

[arrow-datafusion] branch master updated: support decimal scalar value (#1394)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f24a79  support decimal scalar value (#1394)
7f24a79 is described below

commit 7f24a79f49565468f3a5afbd617203d1d0c9e950
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Dec 7 07:51:14 2021 +0800

    support decimal scalar value (#1394)
---
 datafusion/src/scalar.rs | 283 +++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 263 insertions(+), 20 deletions(-)

diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index c06ccb1..e9eafe1 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -43,6 +43,8 @@ pub enum ScalarValue {
     Float32(Option<f32>),
     /// 64bit float
     Float64(Option<f64>),
+    /// 128bit decimal, represented as an i128 value plus its precision and scale
+    Decimal128(Option<i128>, usize, usize),
     /// signed 8bit int
     Int8(Option<i8>),
     /// signed 16bit int
@@ -100,6 +102,10 @@ impl PartialEq for ScalarValue {
         // any newly added enum variant will require editing this list
         // or else face a compile error
         match (self, other) {
+            (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
+                v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+            }
+            (Decimal128(_, _, _), _) => false,
             (Boolean(v1), Boolean(v2)) => v1.eq(v2),
             (Boolean(_), _) => false,
             (Float32(v1), Float32(v2)) => {
@@ -171,6 +177,15 @@ impl PartialOrd for ScalarValue {
         // any newly added enum variant will require editing this list
         // or else face a compile error
         match (self, other) {
+            (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
+                if p1.eq(p2) && s1.eq(s2) {
+                    v1.partial_cmp(v2)
+                } else {
+                    // Two decimal values can only be compared if they have the same precision and scale.
+                    None
+                }
+            }
+            (Decimal128(_, _, _), _) => None,
             (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
             (Boolean(_), _) => None,
             (Float32(v1), Float32(v2)) => {
@@ -253,6 +268,11 @@ impl std::hash::Hash for ScalarValue {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
         use ScalarValue::*;
         match self {
+            Decimal128(v, p, s) => {
+                v.hash(state);
+                p.hash(state);
+                s.hash(state)
+            }
             Boolean(v) => v.hash(state),
             Float32(v) => {
                 let v = v.map(OrderedFloat);
@@ -453,6 +473,22 @@ macro_rules! eq_array_primitive {
 }
 
 impl ScalarValue {
+    /// Create a decimal Scalar from value/precision and scale.
+    pub fn try_new_decimal128(
+        value: i128,
+        precision: usize,
+        scale: usize,
+    ) -> Result<Self> {
+        // make sure the precision and scale are valid
+        // TODO: turn the max precision and min scale into named constants
+        if precision <= 38 && scale <= precision {
+            return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
+        }
+        return Err(DataFusionError::Internal(format!(
+            "Can not new a decimal type ScalarValue for precision {} and scale {}",
+            precision, scale
+        )));
+    }
     /// Getter for the `DataType` of the value
     pub fn get_datatype(&self) -> DataType {
         match self {
@@ -465,6 +501,9 @@ impl ScalarValue {
             ScalarValue::Int16(_) => DataType::Int16,
             ScalarValue::Int32(_) => DataType::Int32,
             ScalarValue::Int64(_) => DataType::Int64,
+            ScalarValue::Decimal128(_, precision, scale) => {
+                DataType::Decimal(*precision, *scale)
+            }
             ScalarValue::TimestampSecond(_) => {
                 DataType::Timestamp(TimeUnit::Second, None)
             }
@@ -513,6 +552,9 @@ impl ScalarValue {
             ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)),
             ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)),
             ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)),
+            ScalarValue::Decimal128(Some(v), precision, scale) => {
+                ScalarValue::Decimal128(Some(-v), *precision, *scale)
+            }
             _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self),
         }
     }
@@ -541,6 +583,7 @@ impl ScalarValue {
                 | ScalarValue::TimestampMicrosecond(None)
                 | ScalarValue::TimestampNanosecond(None)
                 | ScalarValue::Struct(None, _)
+                | ScalarValue::Decimal128(None, _, _) // A Decimal128 is null when its value is None, whatever its precision and scale.
         )
     }
 
@@ -590,7 +633,7 @@ impl ScalarValue {
             None => {
                 return Err(DataFusionError::Internal(
                     "Empty iterator passed to ScalarValue::iter_to_array".to_string(),
-                ))
+                ));
             }
             Some(sv) => sv.get_datatype(),
         };
@@ -706,6 +749,11 @@ impl ScalarValue {
         }
 
         let array: ArrayRef = match &data_type {
+            DataType::Decimal(precision, scale) => {
+                let decimal_array =
+                    ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
+                Arc::new(decimal_array)
+            }
             DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
             DataType::Float32 => build_array_primitive!(Float32Array, Float32),
             DataType::Float64 => build_array_primitive!(Float64Array, Float64),
@@ -831,13 +879,40 @@ impl ScalarValue {
                     "Unsupported creation of {:?} array from ScalarValue {:?}",
                     data_type,
                     scalars.peek()
-                )))
+                )));
             }
         };
 
         Ok(array)
     }
 
+    fn iter_to_decimal_array(
+        scalars: impl IntoIterator<Item = ScalarValue>,
+        precision: &usize,
+        scale: &usize,
+    ) -> Result<DecimalArray> {
+        // collect the values as Option<i128>
+        let array = scalars
+            .into_iter()
+            .map(|element: ScalarValue| match element {
+                ScalarValue::Decimal128(v1, _, _) => v1,
+                _ => unreachable!(),
+            })
+            .collect::<Vec<Option<i128>>>();
+
+        // build the decimal array using the Decimal Builder
+        let mut builder = DecimalBuilder::new(array.len(), *precision, *scale);
+        array.iter().for_each(|element| match element {
+            None => {
+                builder.append_null().unwrap();
+            }
+            Some(v) => {
+                builder.append_value(*v).unwrap();
+            }
+        });
+        Ok(builder.finish())
+    }
+
     fn iter_to_array_list(
         scalars: impl IntoIterator<Item = ScalarValue>,
         data_type: &DataType,
@@ -905,9 +980,35 @@ impl ScalarValue {
         Ok(list_array)
     }
 
+    fn build_decimal_array(
+        value: &Option<i128>,
+        precision: &usize,
+        scale: &usize,
+        size: usize,
+    ) -> DecimalArray {
+        let mut builder = DecimalBuilder::new(size, *precision, *scale);
+        match value {
+            None => {
+                for _i in 0..size {
+                    builder.append_null().unwrap();
+                }
+            }
+            Some(v) => {
+                let v = *v;
+                for _i in 0..size {
+                    builder.append_value(v).unwrap();
+                }
+            }
+        };
+        builder.finish()
+    }
+
     /// Converts a scalar value into an array of `size` rows.
     pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
         match self {
+            ScalarValue::Decimal128(e, precision, scale) => {
+                Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size))
+            }
             ScalarValue::Boolean(e) => {
                 Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
             }
@@ -1061,12 +1162,15 @@ impl ScalarValue {
                     Arc::new(StructArray::from(field_values))
                 }
                 None => {
-                    let field_values: Vec<_> = fields.iter().map(|field| {
+                    let field_values: Vec<_> = fields
+                        .iter()
+                        .map(|field| {
                             let none_field = Self::try_from(field.data_type()).expect(
                                 "Failed to construct null ScalarValue from Struct field type"
                             );
                             (field.clone(), none_field.to_array_of_size(size))
-                        }).collect();
+                        })
+                        .collect();
 
                     Arc::new(StructArray::from(field_values))
                 }
@@ -1074,6 +1178,20 @@ impl ScalarValue {
         }
     }
 
+    fn get_decimal_value_from_array(
+        array: &ArrayRef,
+        index: usize,
+        precision: &usize,
+        scale: &usize,
+    ) -> ScalarValue {
+        let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+        if array.is_null(index) {
+            ScalarValue::Decimal128(None, *precision, *scale)
+        } else {
+            ScalarValue::Decimal128(Some(array.value(index)), *precision, *scale)
+        }
+    }
+
     /// Converts a value in `array` at `index` into a ScalarValue
     pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
         // handle NULL value
@@ -1082,6 +1200,9 @@ impl ScalarValue {
         }
 
         Ok(match array.data_type() {
+            DataType::Decimal(precision, scale) => {
+                ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
+            }
             DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
             DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
             DataType::Float32 => typed_cast!(array, index, Float32Array, Float32),
@@ -1162,7 +1283,7 @@ impl ScalarValue {
                         return Err(DataFusionError::Internal(format!(
                             "Index type not supported while creating scalar from dictionary: {}",
                             array.data_type(),
-                        )))
+                        )));
                     }
                 };
 
@@ -1194,11 +1315,28 @@ impl ScalarValue {
                 return Err(DataFusionError::NotImplemented(format!(
                     "Can't create a scalar from array of type \"{:?}\"",
                     other
-                )))
+                )));
             }
         })
     }
 
+    fn eq_array_decimal(
+        array: &ArrayRef,
+        index: usize,
+        value: &Option<i128>,
+        precision: usize,
+        scale: usize,
+    ) -> bool {
+        let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+        if array.precision() != precision || array.scale() != scale {
+            return false;
+        }
+        match value {
+            None => array.is_null(index),
+            Some(v) => !array.is_null(index) && array.value(index) == *v,
+        }
+    }
+
     /// Compares a single row of array @ index for equality with self,
     /// in an optimized fashion.
     ///
@@ -1222,6 +1360,9 @@ impl ScalarValue {
         }
 
         match self {
+            ScalarValue::Decimal128(v, precision, scale) => {
+                ScalarValue::eq_array_decimal(array, index, v, *precision, *scale)
+            }
             ScalarValue::Boolean(val) => {
                 eq_array_primitive!(array, index, BooleanArray, val)
             }
@@ -1458,6 +1599,9 @@ impl TryFrom<&DataType> for ScalarValue {
             DataType::UInt16 => ScalarValue::UInt16(None),
             DataType::UInt32 => ScalarValue::UInt32(None),
             DataType::UInt64 => ScalarValue::UInt64(None),
+            DataType::Decimal(precision, scale) => {
+                ScalarValue::Decimal128(None, *precision, *scale)
+            }
             DataType::Utf8 => ScalarValue::Utf8(None),
             DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
             DataType::Date32 => ScalarValue::Date32(None),
@@ -1487,7 +1631,7 @@ impl TryFrom<&DataType> for ScalarValue {
                 return Err(DataFusionError::NotImplemented(format!(
                     "Can't create a scalar from data_type \"{:?}\"",
                     datatype
-                )))
+                )));
             }
         })
     }
@@ -1505,6 +1649,9 @@ macro_rules! format_option {
 impl fmt::Display for ScalarValue {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
+            ScalarValue::Decimal128(v, p, s) => {
+                write!(f, "{}", format!("{:?},{:?},{:?}", v, p, s))?;
+            }
             ScalarValue::Boolean(e) => format_option!(f, e)?,
             ScalarValue::Float32(e) => format_option!(f, e)?,
             ScalarValue::Float64(e) => format_option!(f, e)?,
@@ -1579,6 +1726,7 @@ impl fmt::Display for ScalarValue {
 impl fmt::Debug for ScalarValue {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
+            ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self),
             ScalarValue::Boolean(_) => write!(f, "Boolean({})", self),
             ScalarValue::Float32(_) => write!(f, "Float32({})", self),
             ScalarValue::Float64(_) => write!(f, "Float64({})", self),
@@ -1677,6 +1825,101 @@ mod tests {
     use super::*;
 
     #[test]
+    fn scalar_decimal_test() {
+        let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1);
+        assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype());
+        assert!(!decimal_value.is_null());
+        let neg_decimal_value = decimal_value.arithmetic_negate();
+        match neg_decimal_value {
+            ScalarValue::Decimal128(v, _, _) => {
+                assert_eq!(-123, v.unwrap());
+            }
+            _ => {
+                unreachable!();
+            }
+        }
+
+        // decimal scalar to array
+        let array = decimal_value.to_array();
+        let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1, array.len());
+        assert_eq!(DataType::Decimal(10, 1), array.data_type().clone());
+        assert_eq!(123i128, array.value(0));
+
+        // decimal scalar to array with size
+        let array = decimal_value.to_array_of_size(10);
+        let array_decimal = array.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(10, array.len());
+        assert_eq!(DataType::Decimal(10, 1), array.data_type().clone());
+        assert_eq!(123i128, array_decimal.value(0));
+        assert_eq!(123i128, array_decimal.value(9));
+        // test eq array
+        assert!(decimal_value.eq_array(&array, 1));
+        assert!(decimal_value.eq_array(&array, 5));
+        // test try from array
+        assert_eq!(
+            decimal_value,
+            ScalarValue::try_from_array(&array, 5).unwrap()
+        );
+
+        assert_eq!(
+            decimal_value,
+            ScalarValue::try_new_decimal128(123, 10, 1).unwrap()
+        );
+
+        // test compare
+        let left = ScalarValue::Decimal128(Some(123), 10, 2);
+        let right = ScalarValue::Decimal128(Some(124), 10, 2);
+        assert!(!left.eq(&right));
+        let result = left < right;
+        assert!(result);
+        let result = left <= right;
+        assert!(result);
+        let right = ScalarValue::Decimal128(Some(124), 10, 3);
+        // make sure that two decimals with different data types can't be compared.
+        let result = left.partial_cmp(&right);
+        assert_eq!(None, result);
+
+        let decimal_vec = vec![
+            ScalarValue::Decimal128(Some(1), 10, 2),
+            ScalarValue::Decimal128(Some(2), 10, 2),
+            ScalarValue::Decimal128(Some(3), 10, 2),
+        ];
+        // convert the vec to decimal array and check the result
+        let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap();
+        assert_eq!(3, array.len());
+        assert_eq!(DataType::Decimal(10, 2), array.data_type().clone());
+
+        let decimal_vec = vec![
+            ScalarValue::Decimal128(Some(1), 10, 2),
+            ScalarValue::Decimal128(Some(2), 10, 2),
+            ScalarValue::Decimal128(Some(3), 10, 2),
+            ScalarValue::Decimal128(None, 10, 2),
+        ];
+        let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap();
+        assert_eq!(4, array.len());
+        assert_eq!(DataType::Decimal(10, 2), array.data_type().clone());
+
+        assert!(ScalarValue::try_new_decimal128(1, 10, 2)
+            .unwrap()
+            .eq_array(&array, 0));
+        assert!(ScalarValue::try_new_decimal128(2, 10, 2)
+            .unwrap()
+            .eq_array(&array, 1));
+        assert!(ScalarValue::try_new_decimal128(3, 10, 2)
+            .unwrap()
+            .eq_array(&array, 2));
+        assert_eq!(
+            ScalarValue::Decimal128(None, 10, 2),
+            ScalarValue::try_from_array(&array, 3).unwrap()
+        );
+        assert_eq!(
+            ScalarValue::Decimal128(None, 10, 2),
+            ScalarValue::try_from_array(&array, 4).unwrap()
+        );
+    }
+
+    #[test]
     fn scalar_value_to_array_u64() {
         let value = ScalarValue::UInt64(Some(13u64));
         let array = value.to_array();
@@ -1909,7 +2152,7 @@ mod tests {
         // Since ScalarValues are used in a non trivial number of places,
         // making it larger means significant more memory consumption
         // per distinct value.
-        assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
+        assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
     }
 
     #[test]
@@ -2088,11 +2331,11 @@ mod tests {
         assert_eq!(
             List(
                 Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )
             .partial_cmp(&List(
                 Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )),
             Some(Ordering::Equal)
         );
@@ -2100,11 +2343,11 @@ mod tests {
         assert_eq!(
             List(
                 Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )
             .partial_cmp(&List(
                 Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )),
             Some(Ordering::Greater)
         );
@@ -2112,11 +2355,11 @@ mod tests {
         assert_eq!(
             List(
                 Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )
             .partial_cmp(&List(
                 Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )),
             Some(Ordering::Less)
         );
@@ -2125,11 +2368,11 @@ mod tests {
         assert_eq!(
             List(
                 Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])),
-                Box::new(DataType::Int64)
+                Box::new(DataType::Int64),
             )
             .partial_cmp(&List(
                 Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
-                Box::new(DataType::Int32)
+                Box::new(DataType::Int32),
             )),
             None
         );
@@ -2137,11 +2380,11 @@ mod tests {
         assert_eq!(
             ScalarValue::from(vec![
                 ("A", ScalarValue::from(1.0)),
-                ("B", ScalarValue::from("Z"))
+                ("B", ScalarValue::from("Z")),
             ])
             .partial_cmp(&ScalarValue::from(vec![
                 ("A", ScalarValue::from(2.0)),
-                ("B", ScalarValue::from("A"))
+                ("B", ScalarValue::from("A")),
             ])),
             Some(Ordering::Less)
         );
@@ -2150,11 +2393,11 @@ mod tests {
         assert_eq!(
             ScalarValue::from(vec![
                 ("A", ScalarValue::from(1.0)),
-                ("B", ScalarValue::from("Z"))
+                ("B", ScalarValue::from("Z")),
             ])
             .partial_cmp(&ScalarValue::from(vec![
                 ("a", ScalarValue::from(2.0)),
-                ("b", ScalarValue::from("A"))
+                ("b", ScalarValue::from("A")),
             ])),
             None
         );