You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/12/14 19:49:30 UTC

[GitHub] [arrow-datafusion] alamb commented on a change in pull request #1408: support sum/avg agg for decimal, change sum(float32) --> float64

alamb commented on a change in pull request #1408:
URL: https://github.com/apache/arrow-datafusion/pull/1408#discussion_r768967806



##########
File path: datafusion/src/execution/context.rs
##########
@@ -1878,6 +1883,58 @@ mod tests {
             "| 110.009         |",
             "+-----------------+",
         ];
+        assert_eq!(
+            &DataType::Decimal(10, 3),
+            result[0].schema().field(0).data_type()
+        );
+        assert_batches_sorted_eq!(expected, &result);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn aggregate_decimal_sum() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+        // the data type of c1 is decimal(10,3)
+        ctx.register_table("d_table", test::table_with_decimal())
+            .unwrap();
+        let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table")
+            .await
+            .unwrap();
+        let expected = vec![
+            "+-----------------+",
+            "| SUM(d_table.c1) |",
+            "+-----------------+",
+            "| 100.000         |",
+            "+-----------------+",
+        ];
+        assert_eq!(
+            &DataType::Decimal(20, 3),

Review comment:
       👍 

##########
File path: datafusion/src/physical_plan/expressions/average.rs
##########
@@ -220,6 +243,74 @@ mod tests {
     use arrow::record_batch::RecordBatch;
     use arrow::{array::*, datatypes::*};
 
+    #[test]
+    fn test_avg_return_data_type() -> Result<()> {
+        let data_type = DataType::Decimal(10, 5);
+        let result_type = avg_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(14, 9), result_type);
+
+        let data_type = DataType::Decimal(36, 10);
+        let result_type = avg_return_type(&data_type)?;
+        assert_eq!(DataType::Decimal(38, 14), result_type);
+        Ok(())
+    }
+
+    #[test]
+    fn avg_decimal() -> Result<()> {
+        // test agg
+        let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
+        for i in 1..7 {
+            // the avg is 3.5, but we get the result of 3

Review comment:
       I don't understand this comment: the result is `3.5` below (`35000` with precision `14` and scale `4`)

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -187,8 +222,63 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (lhs, rhs) => {
+            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // the lhs_scale must be greater or equal rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, rhs) => {
+            let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+        (lhs, rhs) => {

Review comment:
       ```suggestion
           (Some(lhs), rhs) => {
   ```
   ```suggestion
           (lhs, rhs) => {
   ```

##########
File path: datafusion/src/physical_plan/expressions/average.rs
##########
@@ -102,7 +114,14 @@ impl AggregateExpr for Avg {
     }
 
     fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, DataType::Float64, true))
+        Ok(Field::new(&self.name, self.data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(AvgAccumulator::try_new(
+            // avg is f64 or decimal
+            &self.data_type,

Review comment:
       if a sum of `decimal(10,2)` can be `decimal(20,2)` shouldn't the accumulator state also be `decimal(20,2)` to avoid overflow?
   
   I think handling overflow is probably fine to leave for a later date / PR, but it is strange to me that there is a discrepancy between the type for `sum` and the accumulator type for computing `avg`

##########
File path: datafusion/src/physical_plan/expressions/average.rs
##########
@@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator {
             ScalarValue::Float64(e) => {
                 Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
             }
+            ScalarValue::Decimal128(value, precision, scale) => {
+                Ok(match value {
+                    None => ScalarValue::Decimal128(None, precision, scale),
+                    // TODO add the checker for overflow the precision

Review comment:
       👍 

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
         DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
             Ok(DataType::UInt64)
         }
-        DataType::Float32 => Ok(DataType::Float32),
-        DataType::Float64 => Ok(DataType::Float64),
+        // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
+        // the result type of floating-point is FLOAT64 with the double precision.
+        DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
+        // Max precision is 38
+        DataType::Decimal(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+            let new_precision = 38.min(*precision + 10);

Review comment:
       As I mentioned above,  I think it will improve usability if we used symbolic constants in the code rather than hard coded numbers such as `38` and `10`
   
   Something like
   
   ```rust
   /// Maximum precision for a decimal number
   let DECIMAL_MAX_PRECISION: usize = 38;
   ```

##########
File path: datafusion/src/physical_plan/expressions/average.rs
##########
@@ -86,11 +94,15 @@ impl Avg {
         // Average is always Float64, but Avg::new() has a data_type
         // parameter to keep a consistent signature with the other
         // Aggregate expressions.
-        assert_eq!(data_type, DataType::Float64);
-
-        Self {
-            name: name.into(),
-            expr,
+        match data_type {

Review comment:
       The comment above looks out of date -- I think it should simply be removed. 
   
   And perhaps we can change this code so it doesn't use `unreachable`, as I think it would be fairly easy to reach this code by calling `Avg::new(..)` with some incorrect parameters
   
   How about something like
   
   ```rust
   assert!(matches!(data_type, DataType::Float64 | DataType::Decimal(_, _)));
   ```
   
   Which I think might be easier to diagnose if anyone hits it

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -153,9 +163,34 @@ macro_rules! typed_sum_delta_batch {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal_batch(
+    values: &ArrayRef,
+    precision: &usize,
+    scale: &usize,
+) -> Result<ScalarValue> {
+    let array = values.as_any().downcast_ref::<DecimalArray>().unwrap();
+
+    if array.null_count() == array.len() {
+        return Ok(ScalarValue::Decimal128(None, *precision, *scale));
+    }
+
+    let mut result = 0_i128;
+    for i in 0..array.len() {
+        if array.is_valid(i) {
+            result += array.value(i);

Review comment:
       I wonder if we need to check for overflow here as well?

##########
File path: datafusion/src/physical_plan/expressions/average.rs
##########
@@ -38,11 +38,18 @@ use super::{format_state_name, sum};
 pub struct Avg {
     name: String,
     expr: Arc<dyn PhysicalExpr>,
+    data_type: DataType,
 }
 
 /// function return type of an average
 pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
     match arg_type {
+        DataType::Decimal(precision, scale) => {
+            // the new precision and scale for return type of avg function

Review comment:
       Can you please document the rationale for the `4` and `38` constants below (or, even better, pull them into named constants somewhere)?
   
   I also don't understand where the additional `4` came from. I tried to see if it was what postgres did, but when I checked, the output schema for `avg(numeric(10,3))` appears to be `numeric` without the precision or scale specified 🤔
   
   
   ```shell
   (arrow_dev) alamb@MacBook-Pro-2:~/Downloads$ psql
   psql (14.1)
   Type "help" for help.
   
   alamb=# create table test(x decimal(10, 3));
   CREATE TABLE
   alamb=# insert into test values (1.02);
   INSERT 0 1
   alamb=# create table test2 as select avg(x) from test;
   SELECT 1
   
   alamb=# select table_name, column_name, numeric_precision, numeric_scale, data_type from information_schema.columns where table_name='test2';
    table_name | column_name | numeric_precision | numeric_scale | data_type 
   ------------+-------------+-------------------+---------------+-----------
    test2      | avg         |                   |               | numeric
   (1 row)
   ```
   

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
         DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
             Ok(DataType::UInt64)
         }
-        DataType::Float32 => Ok(DataType::Float32),
-        DataType::Float64 => Ok(DataType::Float64),
+        // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
+        // the result type of floating-point is FLOAT64 with the double precision.
+        DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
+        // Max precision is 38
+        DataType::Decimal(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)

Review comment:
       thank you for the context

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -187,8 +222,63 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (lhs, rhs) => {
+            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // the lhs_scale must be greater or equal rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, rhs) => {
+            let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+        (lhs, rhs) => {
+            let new_value =
+                rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap();
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+    }
+}
+
 pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
     Ok(match (lhs, rhs) {
+        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+            if s1.eq(s2) {
+                sum_decimal(v1, v2, p1, s1)
+            } else if s1.gt(s2) && p1.ge(p2) {
+                // For avg aggravate function.
+                // In the avg function, the scale of result data type is different with the scale of the input data type.

Review comment:
       I think this comment also applies to `sum`, not just `avg`

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -187,8 +222,63 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (lhs, rhs) => {
+            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // the lhs_scale must be greater or equal rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, rhs) => {

Review comment:
       ```suggestion
           (None, Some(rhs)) => {
   ```
   I think you could avoid the `unwrap` below (which does a redundant check) by using a pattern match here
   

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -187,8 +222,63 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (lhs, rhs) => {
+            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // the lhs_scale must be greater or equal rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, rhs) => {
+            let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+        (lhs, rhs) => {
+            let new_value =
+                rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap();
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+    }
+}
+
 pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
     Ok(match (lhs, rhs) {
+        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+            if s1.eq(s2) {
+                sum_decimal(v1, v2, p1, s1)
+            } else if s1.gt(s2) && p1.ge(p2) {
+                // For avg aggravate function.

Review comment:
       ```suggestion
                   // For avg aggregate function.
   ```
   ```suggestion
                   // For avg aggravate function.
   ```

##########
File path: datafusion/src/physical_plan/expressions/sum.rs
##########
@@ -187,8 +222,63 @@ macro_rules! typed_sum {
     }};
 }
 
+// TODO implement this in arrow-rs with simd
+// https://github.com/apache/arrow-rs/issues/1010
+fn sum_decimal(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    scale: &usize,
+) -> ScalarValue {
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
+        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
+        (lhs, rhs) => {
+            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
+        }
+    }
+}
+
+fn sum_decimal_with_diff_scale(
+    lhs: &Option<i128>,
+    rhs: &Option<i128>,
+    precision: &usize,
+    lhs_scale: &usize,
+    rhs_scale: &usize,
+) -> ScalarValue {
+    // the lhs_scale must be greater or equal rhs_scale.
+    match (lhs, rhs) {
+        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
+        (None, rhs) => {
+            let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32);
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
+        (lhs, rhs) => {
+            let new_value =
+                rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap();
+            ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
+        }
+    }
+}
+
 pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
     Ok(match (lhs, rhs) {
+        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+            if s1.eq(s2) {
+                sum_decimal(v1, v2, p1, s1)
+            } else if s1.gt(s2) && p1.ge(p2) {

Review comment:
       I don't understand the need for this clause. Among other things, it would seem to make `sum` for decimal non-commutative, which is confusing
   
   I would expect that `sum(lhs, rhs) == sum(rhs, lhs)` for any specific `lhs` and `rhs`
   
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org