You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by mi...@apache.org on 2023/04/13 14:57:09 UTC
[arrow-datafusion] branch main updated: Row `AVG` accumulator support Decimal type (#5973)
This is an automated email from the ASF dual-hosted git repository.
mingmwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fcd8b899e2 Row `AVG` accumulator support Decimal type (#5973)
fcd8b899e2 is described below
commit fcd8b899e2a62f798413c536f47078289ece9d05
Author: mingmwang <mi...@gmail.com>
AuthorDate: Thu Apr 13 22:57:02 2023 +0800
Row `AVG` accumulator support Decimal type (#5973)
* RowAccumulator support for Decimal128
* add test
---
.../tests/sqllogictests/test_files/aggregate.slt | 25 +++++++++
datafusion/physical-expr/src/aggregate/average.rs | 59 +++++++++++++++++-----
datafusion/physical-expr/src/aggregate/min_max.rs | 3 ++
.../physical-expr/src/aggregate/row_accumulator.rs | 1 +
datafusion/physical-expr/src/aggregate/sum.rs | 3 ++
datafusion/row/src/accessor.rs | 15 ++++++
datafusion/row/src/layout.rs | 13 +++--
datafusion/row/src/reader.rs | 15 ++++++
datafusion/row/src/writer.rs | 18 ++++++-
9 files changed, 132 insertions(+), 20 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index 10368341d8..9e122d3a26 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1615,6 +1615,31 @@ SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
----
NULL
+# Creating the decimal table
+statement ok
+CREATE TABLE test_decimal_table (c1 INT, c2 DECIMAL(5, 2), c3 DECIMAL(5, 1), c4 DECIMAL(5, 1))
+
+# Inserting data
+statement ok
+INSERT INTO test_decimal_table VALUES (1, 10.10, 100.1, NULL), (1, 20.20, 200.2, NULL), (2, 10.10, 700.1, NULL), (2, 20.20, 700.1, NULL), (3, 10.1, 100.1, NULL), (3, 10.1, NULL, NULL)
+
+# aggregate_decimal_with_group_by
+query IIRRRRIIR rowsort
+select c1, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c3), count(c4), sum(c4) from test_decimal_table group by c1
+----
+1 2 15.15 30.3 10.1 20.2 2 0 NULL
+2 2 15.15 30.3 10.1 20.2 2 0 NULL
+3 2 10.1 20.2 10.1 10.1 1 0 NULL
+
+# aggregate_decimal_with_group_by_decimal
+query RIRRRRIR rowsort
+select c3, count(c2), avg(c2), sum(c2), min(c2), max(c2), count(c4), sum(c4) from test_decimal_table group by c3
+----
+100.1 2 10.1 20.2 10.1 10.1 0 NULL
+200.2 1 20.2 20.2 20.2 20.2 0 NULL
+700.1 2 15.15 30.3 10.1 20.2 0 NULL
+NULL 1 10.1 10.1 10.1 10.1 0 NULL
+
# Restore the default dialect
statement ok
set datafusion.sql_parser.dialect = 'Generic';
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 7888b8b17f..f898214b4b 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -144,7 +144,8 @@ impl AggregateExpr for Avg {
) -> Result<Box<dyn RowAccumulator>> {
Ok(Box::new(AvgRowAccumulator::new(
start_index,
- self.sum_data_type.clone(),
+ &self.sum_data_type,
+ &self.rt_data_type,
)))
}
@@ -251,7 +252,7 @@ impl Accumulator for AvgAccumulator {
})
}
_ => Err(DataFusionError::Internal(
- "Sum should be f64 on average".to_string(),
+ "Sum should be f64 or decimal128 on average".to_string(),
)),
}
}
@@ -265,13 +266,19 @@ impl Accumulator for AvgAccumulator {
struct AvgRowAccumulator {
state_index: usize,
sum_datatype: DataType,
+ return_data_type: DataType,
}
impl AvgRowAccumulator {
- pub fn new(start_index: usize, sum_datatype: DataType) -> Self {
+ pub fn new(
+ start_index: usize,
+ sum_datatype: &DataType,
+ return_data_type: &DataType,
+ ) -> Self {
Self {
state_index: start_index,
- sum_datatype,
+ sum_datatype: sum_datatype.clone(),
+ return_data_type: return_data_type.clone(),
}
}
}
@@ -313,16 +320,40 @@ impl RowAccumulator for AvgRowAccumulator {
}
fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
- assert_eq!(self.sum_datatype, DataType::Float64);
- Ok(match accessor.get_u64_opt(self.state_index()) {
- None => ScalarValue::Float64(None),
- Some(0) => ScalarValue::Float64(None),
- Some(n) => ScalarValue::Float64(
- accessor
- .get_f64_opt(self.state_index() + 1)
- .map(|f| f / n as f64),
- ),
- })
+ match self.sum_datatype {
+ DataType::Decimal128(p, s) => {
+ match accessor.get_u64_opt(self.state_index()) {
+ None => Ok(ScalarValue::Decimal128(None, p, s)),
+ Some(0) => Ok(ScalarValue::Decimal128(None, p, s)),
+ Some(n) => {
// The sum type and the return type are no longer the same, so the sum has to be converted to the return type.
+ accessor.get_i128_opt(self.state_index() + 1).map_or_else(
+ || Ok(ScalarValue::Decimal128(None, p, s)),
+ |f| {
+ calculate_result_decimal_for_avg(
+ f,
+ n as i128,
+ s,
+ &self.return_data_type,
+ )
+ },
+ )
+ }
+ }
+ }
+ DataType::Float64 => Ok(match accessor.get_u64_opt(self.state_index()) {
+ None => ScalarValue::Float64(None),
+ Some(0) => ScalarValue::Float64(None),
+ Some(n) => ScalarValue::Float64(
+ accessor
+ .get_f64_opt(self.state_index() + 1)
+ .map(|f| f / n as f64),
+ ),
+ }),
+ _ => Err(DataFusionError::Internal(
+ "Sum should be f64 or decimal128 on average".to_string(),
+ )),
+ }
}
#[inline(always)]
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index 4cda3779ba..077804b1a6 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -500,6 +500,9 @@ macro_rules! min_max_v2 {
ScalarValue::Int8(rhs) => {
typed_min_max_v2!($INDEX, $ACC, rhs, i8, $OP)
}
+ ScalarValue::Decimal128(rhs, ..) => {
+ typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP)
+ }
e => {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
diff --git a/datafusion/physical-expr/src/aggregate/row_accumulator.rs b/datafusion/physical-expr/src/aggregate/row_accumulator.rs
index d26da8f4ce..00717a113f 100644
--- a/datafusion/physical-expr/src/aggregate/row_accumulator.rs
+++ b/datafusion/physical-expr/src/aggregate/row_accumulator.rs
@@ -79,5 +79,6 @@ pub fn is_row_accumulator_support_dtype(data_type: &DataType) -> bool {
| DataType::Int64
| DataType::Float32
| DataType::Float64
+ | DataType::Decimal128(_, _)
)
}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index f8d4d303ea..abf67933eb 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -258,6 +258,9 @@ pub(crate) fn add_to_row(
ScalarValue::Int64(rhs) => {
sum_row!(index, accessor, rhs, i64)
}
+ ScalarValue::Decimal128(rhs, _, _) => {
+ sum_row!(index, accessor, rhs, i128)
+ }
_ => {
let msg =
format!("Row sum updater is not expected to receive a scalar {s:?}");
diff --git a/datafusion/row/src/accessor.rs b/datafusion/row/src/accessor.rs
index f8e34578db..e7b4ed8501 100644
--- a/datafusion/row/src/accessor.rs
+++ b/datafusion/row/src/accessor.rs
@@ -193,6 +193,7 @@ impl<'a> RowAccessor<'a> {
fn_get_idx!(i64, 8);
fn_get_idx!(f32, 4);
fn_get_idx!(f64, 8);
+ fn_get_idx!(i128, 16);
fn_get_idx_opt!(bool);
fn_get_idx_opt!(u8);
@@ -205,6 +206,7 @@ impl<'a> RowAccessor<'a> {
fn_get_idx_opt!(i64);
fn_get_idx_opt!(f32);
fn_get_idx_opt!(f64);
+ fn_get_idx_opt!(i128);
fn_get_idx_scalar!(bool, Boolean);
fn_get_idx_scalar!(u8, UInt8);
@@ -218,6 +220,14 @@ impl<'a> RowAccessor<'a> {
fn_get_idx_scalar!(f32, Float32);
fn_get_idx_scalar!(f64, Float64);
+ fn get_decimal128_scalar(&self, idx: usize, p: u8, s: i8) -> ScalarValue {
+ if self.is_valid_at(idx) {
+ ScalarValue::Decimal128(Some(self.get_i128(idx)), p, s)
+ } else {
+ ScalarValue::Decimal128(None, p, s)
+ }
+ }
+
pub fn get_as_scalar(&self, dt: &DataType, index: usize) -> ScalarValue {
match dt {
DataType::Boolean => self.get_bool_scalar(index),
@@ -231,6 +241,7 @@ impl<'a> RowAccessor<'a> {
DataType::UInt64 => self.get_u64_scalar(index),
DataType::Float32 => self.get_f32_scalar(index),
DataType::Float64 => self.get_f64_scalar(index),
+ DataType::Decimal128(p, s) => self.get_decimal128_scalar(index, *p, *s),
_ => unreachable!(),
}
}
@@ -264,6 +275,7 @@ impl<'a> RowAccessor<'a> {
fn_set_idx!(i64, 8);
fn_set_idx!(f32, 4);
fn_set_idx!(f64, 8);
+ fn_set_idx!(i128, 16);
fn set_i8(&mut self, idx: usize, value: i8) {
self.assert_index_valid(idx);
@@ -285,6 +297,7 @@ impl<'a> RowAccessor<'a> {
fn_add_idx!(i64);
fn_add_idx!(f32);
fn_add_idx!(f64);
+ fn_add_idx!(i128);
fn_max_min_idx!(u8, max);
fn_max_min_idx!(u16, max);
@@ -296,6 +309,7 @@ impl<'a> RowAccessor<'a> {
fn_max_min_idx!(i64, max);
fn_max_min_idx!(f32, max);
fn_max_min_idx!(f64, max);
+ fn_max_min_idx!(i128, max);
fn_max_min_idx!(u8, min);
fn_max_min_idx!(u16, min);
@@ -307,4 +321,5 @@ impl<'a> RowAccessor<'a> {
fn_max_min_idx!(i64, min);
fn_max_min_idx!(f32, min);
fn_max_min_idx!(f64, min);
+ fn_max_min_idx!(i128, min);
}
diff --git a/datafusion/row/src/layout.rs b/datafusion/row/src/layout.rs
index 502812cb9f..6a8e8a78ec 100644
--- a/datafusion/row/src/layout.rs
+++ b/datafusion/row/src/layout.rs
@@ -164,11 +164,13 @@ fn word_aligned_offsets(null_width: usize, schema: &Schema) -> (Vec<usize>, usiz
let mut offset = null_width;
for f in schema.fields() {
offsets.push(offset);
- assert!(!matches!(f.data_type(), DataType::Decimal128(_, _)));
- // All of the current support types can fit into one single 8-bytes word.
- // When we decide to support Decimal type in the future, its width would be
- // of two 8-bytes words and should adapt the width calculation below.
- offset += 8;
+ assert!(!matches!(f.data_type(), DataType::Decimal256(_, _)));
+ // All currently supported types fit into a single 8-byte word, except for Decimal128,
+ // which occupies two 8-byte words.
+ match f.data_type() {
+ DataType::Decimal128(_, _) => offset += 16,
+ _ => offset += 8,
+ }
}
(offsets, offset - null_width)
}
@@ -241,6 +243,7 @@ fn supported_type(dt: &DataType, row_type: RowType) -> bool {
| Float64
| Date32
| Date64
+ | Decimal128(_, _)
)
}
}
diff --git a/datafusion/row/src/reader.rs b/datafusion/row/src/reader.rs
index 634b814ad3..a8dc8211f0 100644
--- a/datafusion/row/src/reader.rs
+++ b/datafusion/row/src/reader.rs
@@ -213,6 +213,10 @@ impl<'a> RowReader<'a> {
get_idx!(i64, self, idx, 8)
}
+ fn get_decimal128(&self, idx: usize) -> i128 {
+ get_idx!(i128, self, idx, 16)
+ }
+
fn get_utf8(&self, idx: usize) -> &str {
self.assert_index_valid(idx);
let offset_size = self.get_u64(idx);
@@ -260,6 +264,14 @@ impl<'a> RowReader<'a> {
}
}
+ fn get_decimal128_opt(&self, idx: usize) -> Option<i128> {
+ if self.is_valid_at(idx) {
+ Some(self.get_decimal128(idx))
+ } else {
+ None
+ }
+ }
+
fn get_utf8_opt(&self, idx: usize) -> Option<&str> {
if self.is_valid_at(idx) {
Some(self.get_utf8(idx))
@@ -328,6 +340,7 @@ fn_read_field!(f64, Float64Builder);
fn_read_field!(date32, Date32Builder);
fn_read_field!(date64, Date64Builder);
fn_read_field!(utf8, StringBuilder);
+fn_read_field!(decimal128, Decimal128Builder);
pub(crate) fn read_field_binary(
to: &mut Box<dyn ArrayBuilder>,
@@ -374,6 +387,7 @@ fn read_field(
Date64 => read_field_date64(to, col_idx, row),
Utf8 => read_field_utf8(to, col_idx, row),
Binary => read_field_binary(to, col_idx, row),
+ Decimal128(_, _) => read_field_decimal128(to, col_idx, row),
_ => unimplemented!(),
}
}
@@ -401,6 +415,7 @@ fn read_field_null_free(
Date64 => read_field_date64_null_free(to, col_idx, row),
Utf8 => read_field_utf8_null_free(to, col_idx, row),
Binary => read_field_binary_null_free(to, col_idx, row),
+ Decimal128(_, _) => read_field_decimal128_null_free(to, col_idx, row),
_ => unimplemented!(),
}
}
diff --git a/datafusion/row/src/writer.rs b/datafusion/row/src/writer.rs
index 12339afe77..7bf9ac0267 100644
--- a/datafusion/row/src/writer.rs
+++ b/datafusion/row/src/writer.rs
@@ -23,7 +23,8 @@ use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util::{round_upto_power_of_2, set_bit_raw, unset_bit_raw};
use datafusion_common::cast::{
- as_binary_array, as_date32_array, as_date64_array, as_string_array,
+ as_binary_array, as_date32_array, as_date64_array, as_decimal128_array,
+ as_string_array,
};
use datafusion_common::Result;
use std::cmp::max;
@@ -225,6 +226,10 @@ impl RowWriter {
set_idx!(8, self, idx, value)
}
+ fn set_decimal128(&mut self, idx: usize, value: i128) {
+ set_idx!(16, self, idx, value)
+ }
+
fn set_offset_size(&mut self, idx: usize, size: u32) {
let offset_and_size: u64 = (self.varlena_offset as u64) << 32 | (size as u64);
self.set_u64(idx, offset_and_size);
@@ -375,6 +380,16 @@ pub(crate) fn write_field_binary(
to.set_binary(col_idx, s);
}
+pub(crate) fn write_field_decimal128(
+ to: &mut RowWriter,
+ from: &Arc<dyn Array>,
+ col_idx: usize,
+ row_idx: usize,
+) {
+ let from = as_decimal128_array(from).unwrap();
+ to.set_decimal128(col_idx, from.value(row_idx));
+}
+
fn write_field(
col_idx: usize,
row_idx: usize,
@@ -399,6 +414,7 @@ fn write_field(
Date64 => write_field_date64(row, col, col_idx, row_idx),
Utf8 => write_field_utf8(row, col, col_idx, row_idx),
Binary => write_field_binary(row, col, col_idx, row_idx),
+ Decimal128(_, _) => write_field_decimal128(row, col, col_idx, row_idx),
_ => unimplemented!(),
}
}