You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text copy.)
Posted to commits@arrow.apache.org by mi...@apache.org on 2023/04/13 14:57:09 UTC

[arrow-datafusion] branch main updated: Row `AVG` accumulator support Decimal type (#5973)

This is an automated email from the ASF dual-hosted git repository.

mingmwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fcd8b899e2 Row `AVG` accumulator support Decimal type (#5973)
fcd8b899e2 is described below

commit fcd8b899e2a62f798413c536f47078289ece9d05
Author: mingmwang <mi...@gmail.com>
AuthorDate: Thu Apr 13 22:57:02 2023 +0800

    Row `AVG` accumulator support Decimal type (#5973)
    
    * RowAccumulator support for Decimal128
    
    * add test
---
 .../tests/sqllogictests/test_files/aggregate.slt   | 25 +++++++++
 datafusion/physical-expr/src/aggregate/average.rs  | 59 +++++++++++++++++-----
 datafusion/physical-expr/src/aggregate/min_max.rs  |  3 ++
 .../physical-expr/src/aggregate/row_accumulator.rs |  1 +
 datafusion/physical-expr/src/aggregate/sum.rs      |  3 ++
 datafusion/row/src/accessor.rs                     | 15 ++++++
 datafusion/row/src/layout.rs                       | 13 +++--
 datafusion/row/src/reader.rs                       | 15 ++++++
 datafusion/row/src/writer.rs                       | 18 ++++++-
 9 files changed, 132 insertions(+), 20 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 10368341d8..9e122d3a26 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1615,6 +1615,31 @@ SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
 ----
 NULL
 
+# Create a table with DECIMAL columns (c4 is NULL in every inserted row)
+statement ok
+CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1))
+
+# Insert data; the c1 = 3 group contains one NULL c3
+statement ok
+INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL)
+
+# aggregate_decimal_with_group_by: COUNT/AVG/SUM/MIN/MAX over DECIMAL, grouped by INT
+query IIRRRRIIR rowsort
+select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1
+----
+1 2 15.15 30.3 10.1 20.2 2 0 NULL
+2 2 15.15 30.3 10.1 20.2 2 0 NULL
+3 2 10.1 20.2 10.1 10.1 1 0 NULL
+
+# aggregate_decimal_with_group_by_decimal: same aggregates, grouped by a DECIMAL key (incl. NULL)
+query RIRRRRIR rowsort
+select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3
+----
+100.1 2 10.1 20.2 10.1 10.1 0 NULL
+200.2 1 20.2 20.2 20.2 20.2 0 NULL
+700.1 2 15.15 30.3 10.1 20.2 0 NULL
+NULL 1 10.1 10.1 10.1 10.1 0 NULL
+
 # Restore the default dialect
 statement ok
 set datafusion.sql_parser.dialect = 'Generic';
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 7888b8b17f..f898214b4b 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -144,7 +144,8 @@ impl AggregateExpr for Avg {
     ) -> Result<Box<dyn RowAccumulator>> {
         Ok(Box::new(AvgRowAccumulator::new(
             start_index,
-            self.sum_data_type.clone(),
+            &self.sum_data_type,
+            &self.rt_data_type,
         )))
     }
 
@@ -251,7 +252,7 @@ impl Accumulator for AvgAccumulator {
                 })
             }
             _ => Err(DataFusionError::Internal(
-                "Sum should be f64 on average".to_string(),
+                "Sum should be f64 or decimal128 on average".to_string(),
             )),
         }
     }
@@ -265,13 +266,19 @@ impl Accumulator for AvgAccumulator {
 struct AvgRowAccumulator {
     state_index: usize,
     sum_datatype: DataType,
+    return_data_type: DataType, // result type; differs from sum_datatype for Decimal128
 }
 
 impl AvgRowAccumulator {
-    pub fn new(start_index: usize, sum_datatype: DataType) -> Self {
+    pub fn new(
+        start_index: usize,
+        sum_datatype: &DataType,
+        return_data_type: &DataType,
+    ) -> Self {
         Self {
             state_index: start_index,
-            sum_datatype,
+            sum_datatype: sum_datatype.clone(),
+            return_data_type: return_data_type.clone(),
         }
     }
 }
@@ -313,16 +320,40 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 
     fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
-        assert_eq!(self.sum_datatype, DataType::Float64);
-        Ok(match accessor.get_u64_opt(self.state_index()) {
-            None => ScalarValue::Float64(None),
-            Some(0) => ScalarValue::Float64(None),
-            Some(n) => ScalarValue::Float64(
-                accessor
-                    .get_f64_opt(self.state_index() + 1)
-                    .map(|f| f / n as f64),
-            ),
-        })
+        match (&self.sum_datatype, &self.return_data_type) {
+            (DataType::Decimal128(_, sum_scale), DataType::Decimal128(p, s)) => {
+                // A NULL average must use the return type's precision/scale (not
+                // the sum's); otherwise it would not match the output schema.
+                let null = ScalarValue::Decimal128(None, *p, *s);
+                match accessor.get_u64_opt(self.state_index()) {
+                    None | Some(0) => Ok(null),
+                    // Sum uses the sum type's scale; convert to the return type.
+                    Some(n) => accessor.get_i128_opt(self.state_index() + 1).map_or(
+                        Ok(null),
+                        |sum| {
+                            calculate_result_decimal_for_avg(
+                                sum,
+                                n as i128,
+                                *sum_scale,
+                                &self.return_data_type,
+                            )
+                        },
+                    ),
+                }
+            }
+            (DataType::Float64, _) => Ok(match accessor.get_u64_opt(self.state_index()) {
+                None => ScalarValue::Float64(None),
+                Some(0) => ScalarValue::Float64(None),
+                Some(n) => ScalarValue::Float64(
+                    accessor
+                        .get_f64_opt(self.state_index() + 1)
+                        .map(|f| f / n as f64),
+                ),
+            }),
+            _ => Err(DataFusionError::Internal(
+                "Sum should be f64 or decimal128 on average".to_string(),
+            )),
+        }
     }
 
     #[inline(always)]
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index 4cda3779ba..077804b1a6 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -500,6 +500,9 @@ macro_rules! min_max_v2 {
             ScalarValue::Int8(rhs) => {
                 typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP)
             }
+            ScalarValue::Decimal128(rhs, ..) => { // compare unscaled i128 values
+                typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP)
+            }
             e => {
                 return Err(DataFusionError::Internal(format!(
                     "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs
index d26da8f4ce..00717a113f 100644
--- a/datafusion/physical-expr/src/aggregate/row_accumulator.rs
+++ b/datafusion/physical-expr/src/aggregate/row_accumulator.rs
@@ -79,5 +79,6 @@ pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool {
             | DataType::Int64
             | DataType::Float32
             | DataType::Float64
+            | DataType::Decimal128(_, _) // stored as i128 in the row state
     )
 }
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index f8d4d303ea..abf67933eb 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -258,6 +258,9 @@ pub(crate) fn add_to_row(
         ScalarValue::Int64(rhs) => {
             sum_row!(index, accessor, rhs, i64)
         }
+        ScalarValue::Decimal128(rhs, _, _) => { // adds the unscaled i128 value
+            sum_row!(index, accessor, rhs, i128) // NOTE(review): check overflow
+        }
         _ => {
             let msg =
                 format!("Row sum updater is not expected to receive a scalar {s:?}");
diff --git a/datafusion/row/src/accessor.rs b/datafusion/row/src/accessor.rs
index f8e34578db..e7b4ed8501 100644
--- a/datafusion/row/src/accessor.rs
+++ b/datafusion/row/src/accessor.rs
@@ -193,6 +193,7 @@ impl<'a> RowAccessor<'a> {
     fn_get_idx!(i64, 8);
     fn_get_idx!(f32, 4);
     fn_get_idx!(f64, 8);
+    fn_get_idx!(i128, 16); // Decimal128 values are stored as 16-byte i128
 
     fn_get_idx_opt!(bool);
     fn_get_idx_opt!(u8);
@@ -205,6 +206,7 @@ impl<'a> RowAccessor<'a> {
     fn_get_idx_opt!(i64);
     fn_get_idx_opt!(f32);
     fn_get_idx_opt!(f64);
+    fn_get_idx_opt!(i128);
 
     fn_get_idx_scalar!(bool, Boolean);
     fn_get_idx_scalar!(u8, UInt8);
@@ -218,6 +220,14 @@ impl<'a> RowAccessor<'a> {
     fn_get_idx_scalar!(f32, Float32);
     fn_get_idx_scalar!(f64, Float64);
 
+    fn get_decimal128_scalar(&self, idx: usize, p: u8, s: i8) -> ScalarValue {
+        if self.is_valid_at(idx) {
+            ScalarValue::Decimal128(Some(self.get_i128(idx)), p, s)
+        } else {
+            ScalarValue::Decimal128(None, p, s)
+        }
+    }
+
     pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue {
         match dt {
             DataType::Boolean => self.get_bool_scalar(index),
@@ -231,6 +241,7 @@ impl<'a> RowAccessor<'a> {
             DataType::UInt64 => self.get_u64_scalar(index),
             DataType::Float32 => self.get_f32_scalar(index),
             DataType::Float64 => self.get_f64_scalar(index),
+            DataType::Decimal128(p, s) => self.get_decimal128_scalar(index, *p, *s),
             _ => unreachable!(),
         }
     }
@@ -264,6 +275,7 @@ impl<'a> RowAccessor<'a> {
     fn_set_idx!(i64, 8);
     fn_set_idx!(f32, 4);
     fn_set_idx!(f64, 8);
+    fn_set_idx!(i128, 16); // 16-byte slot used for Decimal128 state
 
     fn set_i8(&mut self, idx: usize, value: i8) {
         self.assert_index_valid(idx);
@@ -285,6 +297,7 @@ impl<'a> RowAccessor<'a> {
     fn_add_idx!(i64);
     fn_add_idx!(f32);
     fn_add_idx!(f64);
+    fn_add_idx!(i128);
 
     fn_max_min_idx!(u8, max);
     fn_max_min_idx!(u16, max);
@@ -296,6 +309,7 @@ impl<'a> RowAccessor<'a> {
     fn_max_min_idx!(i64, max);
     fn_max_min_idx!(f32, max);
     fn_max_min_idx!(f64, max);
+    fn_max_min_idx!(i128, max);
 
     fn_max_min_idx!(u8, min);
     fn_max_min_idx!(u16, min);
@@ -307,4 +321,5 @@ impl<'a> RowAccessor<'a> {
     fn_max_min_idx!(i64, min);
     fn_max_min_idx!(f32, min);
     fn_max_min_idx!(f64, min);
+    fn_max_min_idx!(i128, min);
 }
diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs
index 502812cb9f..6a8e8a78ec 100644
--- a/datafusion/row/src/layout.rs
+++ b/datafusion/row/src/layout.rs
@@ -164,11 +164,13 @@ fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec<usize>, usiz
     let mut offset = null_width;
     for f in schema.fields() {
         offsets.push(offset);
-        assert!(!matches!(f.data_type(), DataType::Decimal128(_, _)));
-        // All of the current support types can fit into one single 8-bytes word.
-        // When we decide to support Decimal type in the future, its width would be
-        // of two 8-bytes words and should adapt the width calculation below.
-        offset += 8;
+        assert!(!matches!(f.data_type(), DataType::Decimal256(_, _)));
+        // Decimal128 takes two 8-byte words (16 bytes); every other supported type
+        // fits in a single 8-byte word. Decimal256 is rejected by the assert above.
+        match f.data_type() {
+            DataType::Decimal128(_, _) => offset += 16,
+            _ => offset += 8,
+        }
     }
     (offsets, offset - null_width)
 }
@@ -241,6 +243,7 @@ fn supported_type(dt: &DataType, row_type: RowType) -> bool {
                     | Float64
                     | Date32
                     | Date64
+                    | Decimal128(_, _)
             )
         }
     }
diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs
index 634b814ad3..a8dc8211f0 100644
--- a/datafusion/row/src/reader.rs
+++ b/datafusion/row/src/reader.rs
@@ -213,6 +213,10 @@ impl<'a> RowReader<'a> {
         get_idx!(i64, self, idx, 8)
     }
 
+    fn get_decimal128(&self, idx: usize) -> i128 {
+        get_idx!(i128, self, idx, 16) // Decimal128 occupies a 16-byte slot
+    }
+
     fn get_utf8(&self, idx: usize) -> &str {
         self.assert_index_valid(idx);
         let offset_size = self.get_u64(idx);
@@ -260,6 +264,14 @@ impl<'a> RowReader<'a> {
         }
     }
 
+    fn get_decimal128_opt(&self, idx: usize) -> Option<i128> {
+        if self.is_valid_at(idx) {
+            Some(self.get_decimal128(idx))
+        } else {
+            None
+        }
+    }
+
     fn get_utf8_opt(&self, idx: usize) -> Option<&str> {
         if self.is_valid_at(idx) {
             Some(self.get_utf8(idx))
@@ -328,6 +340,7 @@ fn_read_field!(f64, Float64Builder);
 fn_read_field!(date32, Date32Builder);
 fn_read_field!(date64, Date64Builder);
 fn_read_field!(utf8, StringBuilder);
+fn_read_field!(decimal128, Decimal128Builder);
 
 pub(crate) fn read_field_binary(
     to: &mut Box<dyn ArrayBuilder>,
@@ -374,6 +387,7 @@ fn read_field(
         Date64 => read_field_date64(to, col_idx, row),
         Utf8 => read_field_utf8(to, col_idx, row),
         Binary => read_field_binary(to, col_idx, row),
+        Decimal128(_, _) => read_field_decimal128(to, col_idx, row),
         _ => unimplemented!(),
     }
 }
@@ -401,6 +415,7 @@ fn read_field_null_free(
         Date64 => read_field_date64_null_free(to, col_idx, row),
         Utf8 => read_field_utf8_null_free(to, col_idx, row),
         Binary => read_field_binary_null_free(to, col_idx, row),
+        Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row),
         _ => unimplemented!(),
     }
 }
diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs
index 12339afe77..7bf9ac0267 100644
--- a/datafusion/row/src/writer.rs
+++ b/datafusion/row/src/writer.rs
@@ -23,7 +23,8 @@ use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw};
 use datafusion_common::cast::{
-    as_binary_array, as_date32_array, as_date64_array, as_string_array,
+    as_binary_array, as_date32_array, as_date64_array, as_decimal128_array,
+    as_string_array,
 };
 use datafusion_common::Result;
 use std::cmp::max;
@@ -225,6 +226,10 @@ impl RowWriter {
         set_idx!(8, self, idx, value)
     }
 
+    fn set_decimal128(&mut self, idx: usize, value: i128) {
+        set_idx!(16, self, idx, value) // Decimal128 occupies a 16-byte slot
+    }
+
     fn set_offset_size(&mut self, idx: usize, size: u32) {
         let offset_and_size: u64 = (self.varlena_offset as u64) << 32 | (size as u64);
         self.set_u64(idx, offset_and_size);
@@ -375,6 +380,16 @@ pub(crate) fn write_field_binary(
     to.set_binary(col_idx, s);
 }
 
+pub(crate) fn write_field_decimal128(
+    to: &mut RowWriter,
+    from: &Arc<dyn Array>,
+    col_idx: usize,
+    row_idx: usize,
+) {
+    let from = as_decimal128_array(from).unwrap(); // expects a Decimal128 column
+    to.set_decimal128(col_idx, from.value(row_idx));
+}
+
 fn write_field(
     col_idx: usize,
     row_idx: usize,
@@ -399,6 +414,7 @@ fn write_field(
         Date64 => write_field_date64(row, col, col_idx, row_idx),
         Utf8 => write_field_utf8(row, col, col_idx, row_idx),
         Binary => write_field_binary(row, col, col_idx, row_idx),
+        Decimal128(_, _) => write_field_decimal128(row, col, col_idx, row_idx),
         _ => unimplemented!(),
     }
 }