You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/17 15:29:17 UTC
[arrow-datafusion] branch master updated: support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 9d31866 support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)
9d31866 is described below
commit 9d3186693b614db57143adbd81c82a60752a8bac
Author: Kun Liu <li...@apache.org>
AuthorDate: Fri Dec 17 23:29:10 2021 +0800
support sum/avg agg for decimal, change sum(float32) --> float64 (#1408)
* support sum/avg agg for decimal
* support sum/avg agg for decimal
* support the avg and add test
* add comments and const
---
datafusion/src/execution/context.rs | 59 ++++-
datafusion/src/physical_plan/aggregates.rs | 34 ++-
.../physical_plan/coercion_rule/aggregate_rule.rs | 3 +-
.../src/physical_plan/expressions/average.rs | 120 ++++++++--
datafusion/src/physical_plan/expressions/sum.rs | 259 +++++++++++++++++++--
datafusion/src/scalar.rs | 8 +-
datafusion/src/sql/utils.rs | 4 +-
7 files changed, 447 insertions(+), 40 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index d7c536e..8c3df46 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1845,9 +1845,9 @@ mod tests {
#[tokio::test]
async fn aggregate_decimal_min() -> Result<()> {
let mut ctx = ExecutionContext::new();
+ // the data type of c1 is decimal(10,3)
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();
-
let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
.await
.unwrap();
@@ -1858,6 +1858,10 @@ mod tests {
"| -100.009 |",
"+-----------------+",
];
+ assert_eq!(
+ &DataType::Decimal(10, 3),
+ result[0].schema().field(0).data_type()
+ );
assert_batches_sorted_eq!(expected, &result);
Ok(())
}
@@ -1865,6 +1869,7 @@ mod tests {
#[tokio::test]
async fn aggregate_decimal_max() -> Result<()> {
let mut ctx = ExecutionContext::new();
+ // the data type of c1 is decimal(10,3)
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();
@@ -1878,6 +1883,58 @@ mod tests {
"| 110.009 |",
"+-----------------+",
];
+ assert_eq!(
+ &DataType::Decimal(10, 3),
+ result[0].schema().field(0).data_type()
+ );
+ assert_batches_sorted_eq!(expected, &result);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn aggregate_decimal_sum() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ // the data type of c1 is decimal(10,3)
+ ctx.register_table("d_table", test::table_with_decimal())
+ .unwrap();
+ let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table")
+ .await
+ .unwrap();
+ let expected = vec![
+ "+-----------------+",
+ "| SUM(d_table.c1) |",
+ "+-----------------+",
+ "| 100.000 |",
+ "+-----------------+",
+ ];
+ assert_eq!(
+ &DataType::Decimal(20, 3),
+ result[0].schema().field(0).data_type()
+ );
+ assert_batches_sorted_eq!(expected, &result);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn aggregate_decimal_avg() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ // the data type of c1 is decimal(10,3)
+ ctx.register_table("d_table", test::table_with_decimal())
+ .unwrap();
+ let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table")
+ .await
+ .unwrap();
+ let expected = vec![
+ "+-----------------+",
+ "| AVG(d_table.c1) |",
+ "+-----------------+",
+ "| 5.0000000 |",
+ "+-----------------+",
+ ];
+ assert_eq!(
+ &DataType::Decimal(14, 7),
+ result[0].schema().field(0).data_type()
+ );
assert_batches_sorted_eq!(expected, &result);
Ok(())
}
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 50e1a82..e9f9696 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -426,7 +426,7 @@ mod tests {
| DataType::Int16
| DataType::Int32
| DataType::Int64 => DataType::Int64,
- DataType::Float32 | DataType::Float64 => data_type.clone(),
+ DataType::Float32 | DataType::Float64 => DataType::Float64,
_ => data_type.clone(),
};
@@ -471,6 +471,29 @@ mod tests {
}
#[test]
+ fn test_sum_return_type() -> Result<()> {
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?;
+ assert_eq!(DataType::Int64, observed);
+
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?;
+ assert_eq!(DataType::UInt64, observed);
+
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?;
+ assert_eq!(DataType::Float64, observed);
+
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?;
+ assert_eq!(DataType::Float64, observed);
+
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?;
+ assert_eq!(DataType::Decimal(20, 5), observed);
+
+ let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?;
+ assert_eq!(DataType::Decimal(38, 5), observed);
+
+ Ok(())
+ }
+
+ #[test]
fn test_sum_no_utf8() {
let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
assert!(observed.is_err());
@@ -504,6 +527,15 @@ mod tests {
let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);
+
+ let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?;
+ assert_eq!(DataType::Float64, observed);
+
+ let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?;
+ assert_eq!(DataType::Decimal(14, 10), observed);
+
+ let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?;
+ assert_eq!(DataType::Decimal(38, 10), observed);
Ok(())
}
diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
index d7b4375..e76e4a6 100644
--- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
+++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
@@ -193,8 +193,7 @@ mod tests {
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Float32],
- // support the decimal data type
- // vec![DataType::Decimal(20, 3)],
+ vec![DataType::Decimal(20, 3)],
];
for fun in funs {
for input_type in &input_types {
diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs
index feb568c..f092989 100644
--- a/datafusion/src/physical_plan/expressions/average.rs
+++ b/datafusion/src/physical_plan/expressions/average.rs
@@ -23,7 +23,9 @@ use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
-use crate::scalar::ScalarValue;
+use crate::scalar::{
+ ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
+};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
@@ -38,11 +40,19 @@ use super::{format_state_name, sum};
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
+ data_type: DataType,
}
/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
+ DataType::Decimal(precision, scale) => {
+ // In Spark, the result type is DECIMAL(min(38, precision + 4), min(38, scale + 4)).
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+ let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
+ let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
+ Ok(DataType::Decimal(new_precision, new_scale))
+ }
DataType::Int8
| DataType::Int16
| DataType::Int32
@@ -73,6 +83,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
| DataType::Int64
| DataType::Float32
| DataType::Float64
+ | DataType::Decimal(_, _)
)
}
@@ -83,14 +94,15 @@ impl Avg {
name: impl Into<String>,
data_type: DataType,
) -> Self {
- // Average is always Float64, but Avg::new() has a data_type
- // parameter to keep a consistent signature with the other
- // Aggregate expressions.
- assert_eq!(data_type, DataType::Float64);
-
+ // The result type of avg must be either Float64 or Decimal.
+ assert!(matches!(
+ data_type,
+ DataType::Float64 | DataType::Decimal(_, _)
+ ));
Self {
name: name.into(),
expr,
+ data_type,
}
}
}
@@ -102,7 +114,14 @@ impl AggregateExpr for Avg {
}
fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, DataType::Float64, true))
+ Ok(Field::new(&self.name, self.data_type.clone(), true))
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(AvgAccumulator::try_new(
+ // avg is f64 or decimal
+ &self.data_type,
+ )?))
}
fn state_fields(&self) -> Result<Vec<Field>> {
@@ -114,19 +133,12 @@ impl AggregateExpr for Avg {
),
Field::new(
&format_state_name(&self.name, "sum"),
- DataType::Float64,
+ self.data_type.clone(),
true,
),
])
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(AvgAccumulator::try_new(
- // avg is f64
- &DataType::Float64,
- )?))
- }
-
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}
@@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator {
ScalarValue::Float64(e) => {
Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
}
+ ScalarValue::Decimal128(value, precision, scale) => {
+ Ok(match value {
+ None => ScalarValue::Decimal128(None, precision, scale),
+ // TODO: add a check that the result does not overflow the precision
+ Some(v) => ScalarValue::Decimal128(
+ Some(v / self.count as i128),
+ precision,
+ scale,
+ ),
+ })
+ }
_ => Err(DataFusionError::Internal(
"Sum should be f64 on average".to_string(),
)),
@@ -221,6 +244,73 @@ mod tests {
use arrow::{array::*, datatypes::*};
#[test]
+ fn test_avg_return_data_type() -> Result<()> {
+ let data_type = DataType::Decimal(10, 5);
+ let result_type = avg_return_type(&data_type)?;
+ assert_eq!(DataType::Decimal(14, 9), result_type);
+
+ let data_type = DataType::Decimal(36, 10);
+ let result_type = avg_return_type(&data_type)?;
+ assert_eq!(DataType::Decimal(38, 14), result_type);
+ Ok(())
+ }
+
+ #[test]
+ fn avg_decimal() -> Result<()> {
+ // test agg
+ let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+ for i in 1..7 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Avg,
+ ScalarValue::Decimal128(Some(35000), 14, 4),
+ DataType::Decimal(14, 4)
+ )
+ }
+
+ #[test]
+ fn avg_decimal_with_nulls() -> Result<()> {
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.append_null()?;
+ } else {
+ decimal_builder.append_value(i)?;
+ }
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Avg,
+ ScalarValue::Decimal128(Some(32500), 14, 4),
+ DataType::Decimal(14, 4)
+ )
+ }
+
+ #[test]
+ fn avg_decimal_all_nulls() -> Result<()> {
+ // test agg
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for _i in 1..6 {
+ decimal_builder.append_null()?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Avg,
+ ScalarValue::Decimal128(None, 14, 4),
+ DataType::Decimal(14, 4)
+ )
+ }
+
+ #[test]
fn avg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
generic_test_op!(
diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs
index c570aef..027736d 100644
--- a/datafusion/src/physical_plan/expressions/sum.rs
+++ b/datafusion/src/physical_plan/expressions/sum.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
-use crate::scalar::ScalarValue;
+use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
@@ -35,6 +35,8 @@ use arrow::{
};
use super::format_state_name;
+use crate::arrow::array::Array;
+use arrow::array::DecimalArray;
/// SUM aggregate expression
#[derive(Debug)]
@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
Ok(DataType::UInt64)
}
- DataType::Float32 => Ok(DataType::Float32),
- DataType::Float64 => Ok(DataType::Float64),
+ // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
+ // the result type of floating-point is FLOAT64 with the double precision.
+ DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
+ DataType::Decimal(precision, scale) => {
+ // In Spark, the result type is DECIMAL(min(38, precision + 10), scale).
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10);
+ Ok(DataType::Decimal(new_precision, *scale))
+ }
other => Err(DataFusionError::Plan(format!(
"SUM does not support type \"{:?}\"",
other
@@ -76,6 +85,7 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
| DataType::Int64
| DataType::Float32
| DataType::Float64
+ | DataType::Decimal(_, _)
)
}
@@ -109,6 +119,10 @@ impl AggregateExpr for Sum {
))
}
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
+ }
+
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new(
&format_state_name(&self.name, "sum"),
@@ -121,10 +135,6 @@ impl AggregateExpr for Sum {
vec![self.expr.clone()]
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
- }
-
fn name(&self) -> &str {
&self.name
}
@@ -153,9 +163,34 @@ macro_rules! typed_sum_delta_batch {
}};
}
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal_batch(
+ values: &ArrayRef,
+ precision: &usize,
+ scale: &usize,
+) -> Result<ScalarValue> {
+ let array = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+
+ if array.null_count() == array.len() {
+ return Ok(ScalarValue::Decimal128(None, *precision, *scale));
+ }
+
+ let mut result = 0_i128;
+ for i in 0..array.len() {
+ if array.is_valid(i) {
+ result += array.value(i);
+ }
+ }
+ Ok(ScalarValue::Decimal128(Some(result), *precision, *scale))
+}
+
// sums the array and returns a ScalarValue of its corresponding type.
pub(super) fn sum_batch(values: &ArrayRef) -> Result<ScalarValue> {
Ok(match values.data_type() {
+ DataType::Decimal(precision, scale) => {
+ sum_decimal_batch(values, precision, scale)?
+ }
DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64),
DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32),
DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64),
@@ -170,7 +205,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result<ScalarValue> {
return Err(DataFusionError::Internal(format!(
"Sum is not expected to receive the type {:?}",
e
- )))
+ )));
}
})
}
@@ -187,8 +222,62 @@ macro_rules! typed_sum {
}};
}
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+ lhs: &Option<i128>,
+ rhs: &Option<i128>,
+ precision: &usize,
+ scale: &usize,
+) -> ScalarValue {
+ match (lhs, rhs) {
+ (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+ (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+ (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+ (Some(lhs_value), Some(rhs_value)) => {
+ ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale)
+ }
+ }
+}
+
+fn sum_decimal_with_diff_scale(
+ lhs: &Option<i128>,
+ rhs: &Option<i128>,
+ precision: &usize,
+ lhs_scale: &usize,
+ rhs_scale: &usize,
+) -> ScalarValue {
+ // lhs_scale must be greater than or equal to rhs_scale.
+ match (lhs, rhs) {
+ (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+ (None, Some(rhs_value)) => {
+ let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+ ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+ }
+ (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+ (Some(lhs_value), Some(rhs_value)) => {
+ let new_value =
+ rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value;
+ ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+ }
+ }
+}
+
pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
Ok(match (lhs, rhs) {
+ (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+ let max_precision = p1.max(p2);
+ if s1.eq(s2) {
+ // s1 = s2
+ sum_decimal(v1, v2, max_precision, s1)
+ } else if s1.gt(s2) {
+ // s1 > s2
+ sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
+ } else {
+ // s1 < s2
+ sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+ }
+ }
// float64 coerces everything to f64
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
typed_sum!(lhs, rhs, Float64, f64)
@@ -254,16 +343,14 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
return Err(DataFusionError::Internal(format!(
"Sum is not expected to receive a scalar {:?}",
e
- )))
+ )));
}
})
}
impl Accumulator for SumAccumulator {
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = &values[0];
- self.sum = sum(&self.sum, &sum_batch(values)?)?;
- Ok(())
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.sum.clone()])
}
fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
@@ -272,6 +359,12 @@ impl Accumulator for SumAccumulator {
Ok(())
}
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = &values[0];
+ self.sum = sum(&self.sum, &sum_batch(values)?)?;
+ Ok(())
+ }
+
fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
// sum(sum1, sum2) = sum1 + sum2
self.update(states)
@@ -282,11 +375,9 @@ impl Accumulator for SumAccumulator {
self.update_batch(states)
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![self.sum.clone()])
- }
-
fn evaluate(&self) -> Result<ScalarValue> {
+ // TODO: add an overflow check.
+ // For the decimal(precision, _) data type, the absolute value of the result must be less than 10^precision.
Ok(self.sum.clone())
}
}
@@ -294,12 +385,146 @@ impl Accumulator for SumAccumulator {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::arrow::array::DecimalBuilder;
use crate::physical_plan::expressions::col;
use crate::{error::Result, generic_test_op};
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
#[test]
+ fn test_sum_return_data_type() -> Result<()> {
+ let data_type = DataType::Decimal(10, 5);
+ let result_type = sum_return_type(&data_type)?;
+ assert_eq!(DataType::Decimal(20, 5), result_type);
+
+ let data_type = DataType::Decimal(36, 10);
+ let result_type = sum_return_type(&data_type)?;
+ assert_eq!(DataType::Decimal(38, 10), result_type);
+ Ok(())
+ }
+
+ #[test]
+ fn sum_decimal() -> Result<()> {
+ // test sum
+ let left = ScalarValue::Decimal128(Some(123), 10, 2);
+ let right = ScalarValue::Decimal128(Some(124), 10, 2);
+ let result = sum(&left, &right)?;
+ assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result);
+ // test sum decimal with diff scale
+ let left = ScalarValue::Decimal128(Some(123), 10, 3);
+ let right = ScalarValue::Decimal128(Some(124), 10, 2);
+ let result = sum(&left, &right)?;
+ assert_eq!(
+ ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3),
+ result
+ );
+ // diff precision and scale for decimal data type
+ let left = ScalarValue::Decimal128(Some(123), 10, 2);
+ let right = ScalarValue::Decimal128(Some(124), 11, 3);
+ let result = sum(&left, &right);
+ assert_eq!(
+ ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3),
+ result.unwrap()
+ );
+
+ // test sum batch
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = sum_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result);
+
+ // test agg
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ decimal_builder.append_value(i as i128)?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Sum,
+ ScalarValue::Decimal128(Some(15), 20, 0),
+ DataType::Decimal(20, 0)
+ )
+ }
+
+ #[test]
+ fn sum_decimal_with_nulls() -> Result<()> {
+ // test sum
+ let left = ScalarValue::Decimal128(None, 10, 2);
+ let right = ScalarValue::Decimal128(Some(123), 10, 2);
+ let result = sum(&left, &right)?;
+ assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result);
+
+ // test with batch
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.append_null()?;
+ } else {
+ decimal_builder.append_value(i)?;
+ }
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = sum_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result);
+
+ // test agg
+ let mut decimal_builder = DecimalBuilder::new(5, 35, 0);
+ for i in 1..6 {
+ if i == 2 {
+ decimal_builder.append_null()?;
+ } else {
+ decimal_builder.append_value(i)?;
+ }
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(35, 0),
+ Sum,
+ ScalarValue::Decimal128(Some(13), 38, 0),
+ DataType::Decimal(38, 0)
+ )
+ }
+
+ #[test]
+ fn sum_decimal_all_nulls() -> Result<()> {
+ // test sum
+ let left = ScalarValue::Decimal128(None, 10, 2);
+ let right = ScalarValue::Decimal128(None, 10, 2);
+ let result = sum(&left, &right)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 2), result);
+
+ // test with batch
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for _i in 1..6 {
+ decimal_builder.append_null()?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ let result = sum_batch(&array)?;
+ assert_eq!(ScalarValue::Decimal128(None, 10, 0), result);
+
+ // test agg
+ let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
+ for _i in 1..6 {
+ decimal_builder.append_null()?;
+ }
+ let array: ArrayRef = Arc::new(decimal_builder.finish());
+ generic_test_op!(
+ array,
+ DataType::Decimal(10, 0),
+ Sum,
+ ScalarValue::Decimal128(None, 20, 0),
+ DataType::Decimal(20, 0)
+ )
+ }
+
+ #[test]
fn sum_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
generic_test_op!(
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index e9eafe1..35ebb2a 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
+// TODO may need to be moved to arrow-rs
+/// The max precision and scale for decimal128
+pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
+pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38;
+
/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
@@ -480,8 +485,7 @@ impl ScalarValue {
scale: usize,
) -> Result<Self> {
// make sure the precision and scale is valid
- // TODO const the max precision and min scale
- if precision <= 38 && scale <= precision {
+ if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision {
return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
}
return Err(DataFusionError::Internal(format!(
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index bce50e5..0ede5ad 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -20,7 +20,7 @@
use arrow::datatypes::DataType;
use crate::logical_plan::{Expr, LogicalPlan};
-use crate::scalar::ScalarValue;
+use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
use crate::{
error::{DataFusionError, Result},
logical_plan::{Column, ExpressionVisitor, Recursion},
@@ -520,7 +520,7 @@ pub(crate) fn make_decimal_type(
}
(Some(p), Some(s)) => {
// Arrow decimal is i128 meaning 38 maximum decimal digits
- if p > 38 || s > p {
+ if (p as usize) > MAX_PRECISION_FOR_DECIMAL128 || s > p {
return Err(DataFusionError::Internal(format!(
"For decimal(precision, scale) precision must be less than or equal to 38 and scale can't be greater than precision. Got ({}, {})",
p, s