You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/17 15:29:17 UTC

[arrow-datafusion] branch master updated: support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d31866  support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)
9d31866 is described below

commit 9d3186693b614db57143adbd81c82a60752a8bac
Author: Kun Liu <li...@apache.org>
AuthorDate: Fri Dec 17 23:29:10 2021 +0800

    support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)
    
    * support sum/avg agg for decimal
    
    * support sum/avg agg for decimal
    
    * support the avg and add test
    
    * add comments and const
---
 datafusion/src/execution/context.rs                |  59 ++++-
 datafusion/src/physical_plan/aggregates.rs         |  34 ++-
 .../physical_plan/coercion_rule/aggregate_rule.rs  |   3 +-
 .../src/physical_plan/expressions/average.rs       | 120 ++++++++--
 datafusion/src/physical_plan/expressions/sum.rs    | 259 +++++++++++++++++++--
 datafusion/src/scalar.rs                           |   8 +-
 datafusion/src/sql/utils.rs                        |   4 +-
 7 files changed, 447 insertions(+), 40 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index d7c536e..8c3df46 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1845,9 +1845,9 @@ mod tests {
     #[tokio::test]
     async fn aggregate_decimal_min() -> Result<()> {
         let mut ctx = ExecutionContext::new();
+        // the data type of c1 is decimal(10,3)
         ctx.register_table("d_table", test::table_with_decimal())
             .unwrap();
-
         let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
             .await
             .unwrap();
@@ -1858,6 +1858,10 @@ mod tests {
             "| -100.009        |",
             "+-----------------+",
         ];
+        assert_eq!(
+            &DataType::Decimal(10, 3),
+            result[0].schema().field(0).data_type()
+        );
         assert_batches_sorted_eq!(expected, &result);
         Ok(())
     }
@@ -1865,6 +1869,7 @@ mod tests {
     #[tokio::test]
     async fn aggregate_decimal_max() -> Result<()> {
         let mut ctx = ExecutionContext::new();
+        // the data type of c1 is decimal(10,3)
         ctx.register_table("d_table", test::table_with_decimal())
             .unwrap();
 
@@ -1878,6 +1883,58 @@ mod tests {
             "| 110.009         |",
             "+-----------------+",
         ];
+        assert_eq!(
+            &DataType::Decimal(10, 3),
+            result[0].schema().field(0).data_type()
+        );
+        assert_batches_sorted_eq!(expected, &result);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn aggregate_decimal_sum() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+        // the data type of c1 is decimal(10,3)
+        ctx.register_table("d_table", test::table_with_decimal())
+            .unwrap();
+        let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table")
+            .await
+            .unwrap();
+        let expected = vec![
+            "+-----------------+",
+            "| SUM(d_table.c1) |",
+            "+-----------------+",
+            "| 100.000         |",
+            "+-----------------+",
+        ];
+        assert_eq!(
+            &DataType::Decimal(20, 3),
+            result[0].schema().field(0).data_type()
+        );
+        assert_batches_sorted_eq!(expected, &result);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn aggregate_decimal_avg() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+        // the data type of c1 is decimal(10,3)
+        ctx.register_table("d_table", test::table_with_decimal())
+            .unwrap();
+        let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table")
+            .await
+            .unwrap();
+        let expected = vec![
+            "+-----------------+",
+            "| AVG(d_table.c1) |",
+            "+-----------------+",
+            "| 5.0000000       |",
+            "+-----------------+",
+        ];
+        assert_eq!(
+            &DataType::Decimal(14, 7),
+            result[0].schema().field(0).data_type()
+        );
         assert_batches_sorted_eq!(expected, &result);
         Ok(())
     }
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 50e1a82..e9f9696 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -426,7 +426,7 @@ mod tests {
                             | DataType::Int16
                             | DataType::Int32
                             | DataType::Int64 => DataType::Int64,
-                            DataType::Float32 | DataType::Float64 => data_type.clone(),
+                            DataType::Float32 | DataType::Float64 => DataType::Float64,
                             _ => data_type.clone(),
                         };
 
@@ -471,6 +471,29 @@ mod tests {
     }
 
     #[test]
+    fn test_sum_return_type() -> Result<()> {
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?;
+        assert_eq!(DataType::Int64, observed);
+
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?;
+        assert_eq!(DataType::UInt64, observed);
+
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?;
+        assert_eq!(DataType::Float64, observed);
+
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?;
+        assert_eq!(DataType::Float64, observed);
+
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?;
+        assert_eq!(DataType::Decimal(20, 5), observed);
+
+        let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?;
+        assert_eq!(DataType::Decimal(38, 5), observed);
+
+        Ok(())
+    }
+
+    #[test]
     fn test_sum_no_utf8() {
         let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
         assert!(observed.is_err());
@@ -504,6 +527,15 @@ mod tests {
 
         let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
         assert_eq!(DataType::Float64, observed);
+
+        let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?;
+        assert_eq!(DataType::Float64, observed);
+
+        let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?;
+        assert_eq!(DataType::Decimal(14, 10), observed);
+
+        let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?;
+        assert_eq!(DataType::Decimal(38, 10), observed);
         Ok(())
     }
 
diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
index d7b4375..e76e4a6 100644
--- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
+++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
@@ -193,8 +193,7 @@ mod tests {
         let input_types = vec![
             vec![DataType::Int32],
             vec![DataType::Float32],
-            // support the decimal data type
-            // vec![DataType::Decimal(20, 3)],
+            vec![DataType::Decimal(20, 3)],
         ];
         for fun in funs {
             for input_type in &input_types {
diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs
index feb568c..f092989 100644
--- a/datafusion/src/physical_plan/expressions/average.rs
+++ b/datafusion/src/physical_plan/expressions/average.rs
@@ -23,7 +23,9 @@ use std::sync::Arc;
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
-use crate::scalar::ScalarValue;
+use crate::scalar::{
+    ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
+};
 use arrow::compute;
 use arrow::datatypes::DataType;
 use arrow::{
@@ -38,11 +40,19 @@ use super::{format_state_name, sum};
 pub struct Avg {
     name: String,
     expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
 }
 
 /// function return type of an average
 pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
     match arg_type {
+        DataType::Decimal(precision, scale) => {
+            // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+            let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
+            let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
+            Ok(DataType::Decimal(new_precision, new_scale))
+        }
         DataType::Int8
         | DataType::Int16
         | DataType::Int32
@@ -73,6 +83,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
             | DataType::Int64
             | DataType::Float32
             | DataType::Float64
+            | DataType::Decimal(_, _)
     )
 }
 
@@ -83,14 +94,15 @@ impl Avg {
         name: impl Into<String>,
         data_type: DataType,
     ) -> Self {
-        // Average is always Float64, but Avg::new() has a data_type
-        // parameter to keep a consistent signature with the other
-        // Aggregate expressions.
-        assert_eq!(data_type, DataType::Float64);
-
+        // The result of avg only supports the Float64 and Decimal data types.
+        assert!(matches!(
+            data_type,
+            DataType::Float64 | DataType::Decimal(_, _)
+        ));
         Self {
             name: name.into(),
             expr,
+            data_type,
         }
     }
 }
@@ -102,7 +114,14 @@ impl AggregateExpr for Avg {
     }
 
     fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, DataType::Float64, true))
+        Ok(Field::new(&self.name, self.data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(AvgAccumulator::try_new(
+            // avg is f64 or decimal
+            &self.data_type,
+        )?))
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -114,19 +133,12 @@ impl AggregateExpr for Avg {
             ),
             Field::new(
                 &format_state_name(&self.name, "sum"),
-                DataType::Float64,
+                self.data_type.clone(),
                 true,
             ),
         ])
     }
 
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(AvgAccumulator::try_new(
-            // avg is f64
-            &DataType::Float64,
-        )?))
-    }
-
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
         vec![self.expr.clone()]
     }
@@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator {
             ScalarValue::Float64(e) => {
                 Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
             }
+            ScalarValue::Decimal128(value, precision, scale) => {
+                Ok(match value {
+                    None => ScalarValue::Decimal128(None, precision, scale),
+                    // TODO: add a check for overflowing the precision
+                    Some(v) => ScalarValue::Decimal128(
+                        Some(v / self.count as i128),
+                        precision,
+                        scale,
+                    ),
+                })
+            }
             _ => Err(DataFusionError::Internal(
                 "Sum should be f64 on average".to_string(),
             )),
@@ -221,6 +244,73 @@ mod tests {
     use arrow::{array::*, datatypes::*};
 
     #[test]
+    fn test_avg_return_data_type() -> Result<()> {
+        let data_type = DataType::Decimal(10, 5);
+        let result_type = avg_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(14, 9), result_type);
+
+        let data_type = DataType::Decimal(36, 10);
+        let result_type = avg_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(38, 14), result_type);
+        Ok(())
+    }
+
+    #[test]
+    fn avg_decimal() -> Result<()> {
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+        for i in 1..7 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Avg,
+            ScalarValue::Decimal128(Some(35000), 14, 4),
+            DataType::Decimal(14, 4)
+        )
+    }
+
+    #[test]
+    fn avg_decimal_with_nulls() -> Result<()> {
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            if i == 2 {
+                decimal_builder.append_null()?;
+            } else {
+                decimal_builder.append_value(i)?;
+            }
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Avg,
+            ScalarValue::Decimal128(Some(32500), 14, 4),
+            DataType::Decimal(14, 4)
+        )
+    }
+
+    #[test]
+    fn avg_decimal_all_nulls() -> Result<()> {
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for _i in 1..6 {
+            decimal_builder.append_null()?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Avg,
+            ScalarValue::Decimal128(None, 14, 4),
+            DataType::Decimal(14, 4)
+        )
+    }
+
+    #[test]
     fn avg_i32() -> Result<()> {
         let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
         generic_test_op!(
diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs
index c570aef..027736d 100644
--- a/datafusion/src/physical_plan/expressions/sum.rs
+++ b/datafusion/src/physical_plan/expressions/sum.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
-use crate::scalar::ScalarValue;
+use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
 use arrow::compute;
 use arrow::datatypes::DataType;
 use arrow::{
@@ -35,6 +35,8 @@ use arrow::{
 };
 
 use super::format_state_name;
+use crate::arrow::array::Array;
+use arrow::array::DecimalArray;
 
 /// SUM aggregate expression
 #[derive(Debug)]
@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
         DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
             Ok(DataType::UInt64)
         }
-        DataType::Float32 => Ok(DataType::Float32),
-        DataType::Float64 => Ok(DataType::Float64),
+        // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
+        // the result type of floating-point is FLOAT64 with the double precision.
+        DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
+        DataType::Decimal(precision, scale) => {
+            // In Spark, the result type is DECIMAL(min(38,precision+10), s)
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+            let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10);
+            Ok(DataType::Decimal(new_precision, *scale))
+        }
         other => Err(DataFusionError::Plan(format!(
             "SUM does not support type \"{:?}\"",
             other
@@ -76,6 +85,7 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
             | DataType::Int64
             | DataType::Float32
             | DataType::Float64
+            | DataType::Decimal(_, _)
     )
 }
 
@@ -109,6 +119,10 @@ impl AggregateExpr for Sum {
         ))
     }
 
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
+    }
+
     fn state_fields(&self) -> Result<Vec<Field>> {
         Ok(vec![Field::new(
             &format_state_name(&self.name, "sum"),
@@ -121,10 +135,6 @@ impl AggregateExpr for Sum {
         vec![self.expr.clone()]
     }
 
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
-    }
-
     fn name(&self) -> &str {
         &self.name
     }
@@ -153,9 +163,34 @@ macro_rules! typed_sum_delta_batch {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal_batch(
+    values: &ArrayRef,
+    precision: &usize,
+    scale: &usize,
+) -> Result<ScalarValue> {
+    let array = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+
+    if array.null_count() == array.len() {
+        return Ok(ScalarValue::Decimal128(None, *precision, *scale));
+    }
+
+    let mut result = 0_i128;
+    for i in 0..array.len() {
+        if array.is_valid(i) {
+            result += array.value(i);
+        }
+    }
+    Ok(ScalarValue::Decimal128(Some(result), *precision, *scale))
+}
+
 // sums the array and returns a ScalarValue of its corresponding type.
 pub(super) fn sum_batch(values: &ArrayRef) -> Result<ScalarValue> {
     Ok(match values.data_type() {
+        DataType::Decimal(precision, scale) => {
+            sum_decimal_batch(values, precision, scale)?
+        }
         DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64),
         DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32),
         DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64),
@@ -170,7 +205,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result<ScalarValue> {
             return Err(DataFusionError::Internal(format!(
                 "Sum is not expected to receive the type {:?}",
                 e
-            )))
+            )));
         }
     })
 }
@@ -187,8 +222,62 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (Some(lhs_value), Some(rhs_value)) => {
+            ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // lhs_scale must be greater than or equal to rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, Some(rhs_value)) => {
+            let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+        (Some(lhs_value), Some(rhs_value)) => {
+            let new_value =
+                rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value;
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+    }
+}
+
 pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
     Ok(match (lhs, rhs) {
+        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+            let max_precision = p1.max(p2);
+            if s1.eq(s2) {
+                // s1 = s2
+                sum_decimal(v1, v2, max_precision, s1)
+            } else if s1.gt(s2) {
+                // s1 > s2
+                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
+            } else {
+                // s1 < s2
+                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+            }
+        }
         // float64 coerces everything to f64
         (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
             typed_sum!(lhs, rhs, Float64, f64)
@@ -254,16 +343,14 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
             return Err(DataFusionError::Internal(format!(
                 "Sum is not expected to receive a scalar {:?}",
                 e
-            )))
+            )));
         }
     })
 }
 
 impl Accumulator for SumAccumulator {
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
-        self.sum = sum(&self.sum, &sum_batch(values)?)?;
-        Ok(())
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.sum.clone()])
     }
 
     fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
@@ -272,6 +359,12 @@ impl Accumulator for SumAccumulator {
         Ok(())
     }
 
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &values[0];
+        self.sum = sum(&self.sum, &sum_batch(values)?)?;
+        Ok(())
+    }
+
     fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
         // sum(sum1, sum2) = sum1 + sum2
         self.update(states)
@@ -282,11 +375,9 @@ impl Accumulator for SumAccumulator {
         self.update_batch(states)
     }
 
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.sum.clone()])
-    }
-
     fn evaluate(&self) -> Result<ScalarValue> {
+        // TODO: add the checker for overflow
+        // For the decimal(precision,_) data type, the absolute value must be less than 10^precision.
         Ok(self.sum.clone())
     }
 }
@@ -294,12 +385,146 @@ impl Accumulator for SumAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::arrow::array::DecimalBuilder;
     use crate::physical_plan::expressions::col;
     use crate::{error::Result, generic_test_op};
     use arrow::datatypes::*;
     use arrow::record_batch::RecordBatch;
 
     #[test]
+    fn test_sum_return_data_type() -> Result<()> {
+        let data_type = DataType::Decimal(10, 5);
+        let result_type = sum_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(20, 5), result_type);
+
+        let data_type = DataType::Decimal(36, 10);
+        let result_type = sum_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(38, 10), result_type);
+        Ok(())
+    }
+
+    #[test]
+    fn sum_decimal() -> Result<()> {
+        // test sum
+        let left = ScalarValue::Decimal128(Some(123), 10, 2);
+        let right = ScalarValue::Decimal128(Some(124), 10, 2);
+        let result = sum(&left, &right)?;
+        assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result);
+        // test sum decimal with diff scale
+        let left = ScalarValue::Decimal128(Some(123), 10, 3);
+        let right = ScalarValue::Decimal128(Some(124), 10, 2);
+        let result = sum(&left, &right)?;
+        assert_eq!(
+            ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3),
+            result
+        );
+        // diff precision and scale for decimal data type
+        let left = ScalarValue::Decimal128(Some(123), 10, 2);
+        let right = ScalarValue::Decimal128(Some(124), 11, 3);
+        let result = sum(&left, &right);
+        assert_eq!(
+            ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3),
+            result.unwrap()
+        );
+
+        // test sum batch
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = sum_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result);
+
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            decimal_builder.append_value(i as i128)?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Sum,
+            ScalarValue::Decimal128(Some(15), 20, 0),
+            DataType::Decimal(20, 0)
+        )
+    }
+
+    #[test]
+    fn sum_decimal_with_nulls() -> Result<()> {
+        // test sum
+        let left = ScalarValue::Decimal128(None, 10, 2);
+        let right = ScalarValue::Decimal128(Some(123), 10, 2);
+        let result = sum(&left, &right)?;
+        assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result);
+
+        // test with batch
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for i in 1..6 {
+            if i == 2 {
+                decimal_builder.append_null()?;
+            } else {
+                decimal_builder.append_value(i)?;
+            }
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = sum_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result);
+
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(5, 35, 0);
+        for i in 1..6 {
+            if i == 2 {
+                decimal_builder.append_null()?;
+            } else {
+                decimal_builder.append_value(i)?;
+            }
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(35, 0),
+            Sum,
+            ScalarValue::Decimal128(Some(13), 38, 0),
+            DataType::Decimal(38, 0)
+        )
+    }
+
+    #[test]
+    fn sum_decimal_all_nulls() -> Result<()> {
+        // test sum
+        let left = ScalarValue::Decimal128(None, 10, 2);
+        let right = ScalarValue::Decimal128(None, 10, 2);
+        let result = sum(&left, &right)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 2), result);
+
+        // test with batch
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for _i in 1..6 {
+            decimal_builder.append_null()?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        let result = sum_batch(&array)?;
+        assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+        for _i in 1..6 {
+            decimal_builder.append_null()?;
+        }
+        let array: ArrayRef = Arc::new(decimal_builder.finish());
+        generic_test_op!(
+            array,
+            DataType::Decimal(10, 0),
+            Sum,
+            ScalarValue::Decimal128(None, 20, 0),
+            DataType::Decimal(20, 0)
+        )
+    }
+
+    #[test]
     fn sum_i32() -> Result<()> {
         let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
         generic_test_op!(
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index e9eafe1..35ebb2a 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto};
 use std::str::FromStr;
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
 
+// TODO may need to be moved to arrow-rs
+/// The max precision and scale for decimal128
+pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
+pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38;
+
 /// Represents a dynamically typed, nullable single value.
 /// This is the single-valued counter-part of arrow’s `Array`.
 #[derive(Clone)]
@@ -480,8 +485,7 @@ impl ScalarValue {
         scale: usize,
     ) -> Result<Self> {
         // make sure the precision and scale is valid
-        // TODO const the max precision and min scale
-        if precision <= 38 && scale <= precision {
+        if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision {
             return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
         }
         return Err(DataFusionError::Internal(format!(
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index bce50e5..0ede5ad 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -20,7 +20,7 @@
 use arrow::datatypes::DataType;
 
 use crate::logical_plan::{Expr, LogicalPlan};
-use crate::scalar::ScalarValue;
+use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
 use crate::{
     error::{DataFusionError, Result},
     logical_plan::{Column, ExpressionVisitor, Recursion},
@@ -520,7 +520,7 @@ pub(crate) fn make_decimal_type(
         }
         (Some(p), Some(s)) => {
             // Arrow decimal is i128 meaning 38 maximum decimal digits
-            if p > 38 || s > p {
+            if (p as usize) > MAX_PRECISION_FOR_DECIMAL128 || s > p {
                 return Err(DataFusionError::Internal(format!(
                     "For decimal(precision, scale) precision must be less than or equal to 38 and scale can't be greater than precision. Got ({}, {})",
                     p, s