You are viewing a plain text version of this content. The canonical link for it is not included in this plain text version.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/10/06 19:28:02 UTC
[arrow-datafusion] branch master updated: Remove type coercions from ScalarValue and aggregation function code (#3705)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8dcef9180 Remove type coercions from ScalarValue and aggregation function code (#3705)
8dcef9180 is described below
commit 8dcef91806443f9b9b512bf6d819dc20961b29c8
Author: Mehmet Ozan Kabak <oz...@gmail.com>
AuthorDate: Thu Oct 6 14:27:57 2022 -0500
Remove type coercions from ScalarValue and aggregation function code (#3705)
* Sanitize ScalarValue and aggregation code from type coercions
* Remove forced type cast from sum_row! macro used in SumRowAccumulator
---
datafusion/common/src/scalar.rs | 470 +++++++++------------
datafusion/physical-expr/src/aggregate/average.rs | 66 +--
.../physical-expr/src/aggregate/correlation.rs | 24 +-
datafusion/physical-expr/src/aggregate/count.rs | 48 +--
.../physical-expr/src/aggregate/covariance.rs | 27 +-
datafusion/physical-expr/src/aggregate/min_max.rs | 214 ++--------
datafusion/physical-expr/src/aggregate/stddev.rs | 29 +-
datafusion/physical-expr/src/aggregate/sum.rs | 160 ++-----
.../physical-expr/src/aggregate/sum_distinct.rs | 28 +-
datafusion/physical-expr/src/aggregate/variance.rs | 49 +--
datafusion/physical-expr/src/expressions/mod.rs | 14 +
11 files changed, 334 insertions(+), 795 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 42f0a7d16..c3f91dd9b 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -312,155 +312,186 @@ impl Eq for ScalarValue {}
// TODO implement this in arrow-rs with simd
// https://github.com/apache/arrow-rs/issues/1010
macro_rules! decimal_op {
- ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt ) => {{
- let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
- ($LHS_SCALE - $RHS_SCALE, true)
- } else {
- ($RHS_SCALE - $LHS_SCALE, false)
- };
- let scale = max($LHS_SCALE, $RHS_SCALE);
- match ($LHS, $RHS, difference) {
- (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale),
- (None, Some(rhs_value), 0) => ScalarValue::Decimal128(Some((0 as i128) $OPERATION rhs_value), $PRECISION, scale),
- (None, Some(rhs_value), _) => {
- let mut new_value = ((0 as i128) $OPERATION rhs_value);
- if side {
- new_value *= 10_i128.pow((difference) as u32)
- };
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- (Some(lhs_value), None, 0) => ScalarValue::Decimal128(Some(lhs_value $OPERATION (0 as i128)), $PRECISION, scale),
- (Some(lhs_value), None, _) => {
- let mut new_value = (lhs_value $OPERATION (0 as i128));
- if !!!side {
- new_value *= 10_i128.pow((difference) as u32)
+ ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt) => {{
+ let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
+ ($LHS_SCALE - $RHS_SCALE, true)
+ } else {
+ ($RHS_SCALE - $LHS_SCALE, false)
+ };
+ let scale = max($LHS_SCALE, $RHS_SCALE);
+ Ok(match ($LHS, $RHS, difference) {
+ (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale),
+ (lhs, None, 0) => ScalarValue::Decimal128(*lhs, $PRECISION, scale),
+ (Some(lhs_value), None, _) => {
+ let mut new_value = *lhs_value;
+ if !side {
+ new_value *= 10_i128.pow(difference as u32)
+ }
+ ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
}
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- (Some(lhs_value), Some(rhs_value), 0) => {
- ScalarValue::Decimal128(Some(lhs_value $OPERATION rhs_value), $PRECISION, scale)
- }
- (Some(lhs_value), Some(rhs_value), _) => {
- let new_value = if side {
- rhs_value * 10_i128.pow((difference) as u32) $OPERATION lhs_value
- } else {
- lhs_value * 10_i128.pow((difference) as u32) $OPERATION rhs_value
- };
- ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
- }
- }}
+ (None, Some(rhs_value), 0) => {
+ let value = decimal_right!(*rhs_value, $OPERATION);
+ ScalarValue::Decimal128(Some(value), $PRECISION, scale)
+ }
+ (None, Some(rhs_value), _) => {
+ let mut new_value = decimal_right!(*rhs_value, $OPERATION);
+ if side {
+ new_value *= 10_i128.pow(difference as u32)
+ };
+ ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
+ }
+ (Some(lhs_value), Some(rhs_value), 0) => {
+ decimal_binary_op!(lhs_value, rhs_value, $OPERATION, $PRECISION, scale)
+ }
+ (Some(lhs_value), Some(rhs_value), _) => {
+ let (left_arg, right_arg) = if side {
+ (*lhs_value, rhs_value * 10_i128.pow(difference as u32))
+ } else {
+ (lhs_value * 10_i128.pow(difference as u32), *rhs_value)
+ };
+ decimal_binary_op!(left_arg, right_arg, $OPERATION, $PRECISION, scale)
+ }
+ })
+ }};
+}
- }
+macro_rules! decimal_binary_op {
+ ($LHS:expr, $RHS:expr, $OPERATION:tt, $PRECISION:expr, $SCALE:expr) => {
+ // TODO: This simple implementation loses precision for calculations like
+ // multiplication and division. Improve this implementation for such
+ // operations.
+ ScalarValue::Decimal128(Some($LHS $OPERATION $RHS), $PRECISION, $SCALE)
+ };
}
-// Returns the result of applying operation to two scalar values, including coercion into $TYPE.
-macro_rules! typed_op {
- ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $TYPE:ident, $OPERATION:tt) => {
- Some(ScalarValue::$SCALAR(match ($LEFT, $RIGHT) {
- (None, None) => None,
- (Some(a), None) => Some((*a as $TYPE) $OPERATION (0 as $TYPE)),
- (None, Some(b)) => Some((0 as $TYPE) $OPERATION (*b as $TYPE)),
- (Some(a), Some(b)) => Some((*a as $TYPE) $OPERATION (*b as $TYPE)),
- }))
+macro_rules! decimal_right {
+ ($TERM:expr, +) => {
+ $TERM
+ };
+ ($TERM:expr, *) => {
+ $TERM
+ };
+ ($TERM:expr, -) => {
+ -$TERM
+ };
+ ($TERM:expr, /) => {
+ Err(DataFusionError::NotImplemented(format!(
+ "Decimal reciprocation not yet supported",
+ )))
};
}
-macro_rules! impl_common_symmetric_cases_op {
- ($LHS:expr, $RHS:expr, $OPERATION:tt, [$([$L_TYPE:ident, $R_TYPE:ident, $O_TYPE:ident, $O_PRIM:ident]),+]) => {
- match ($LHS, $RHS) {
- $(
- (ScalarValue::$L_TYPE(lhs), ScalarValue::$R_TYPE(rhs)) => {
- typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
- }
- (ScalarValue::$R_TYPE(lhs), ScalarValue::$L_TYPE(rhs)) => {
- typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
- }
- )+
- _ => None
+// Returns the result of applying operation to two scalar values.
+macro_rules! primitive_op {
+ ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $OPERATION:tt) => {
+ match ($LEFT, $RIGHT) {
+ (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)),
+ #[allow(unused_variables)]
+ (None, Some(b)) => { primitive_right!(*b, $OPERATION, $SCALAR) },
+ (Some(a), Some(b)) => Ok(ScalarValue::$SCALAR(Some(*a $OPERATION *b))),
}
- }
+ };
}
-macro_rules! impl_common_cases_op {
+macro_rules! primitive_right {
+ ($TERM:expr, +, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM)))
+ };
+ ($TERM:expr, *, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM)))
+ };
+ ($TERM:expr, -, UInt64) => {
+ unsigned_subtraction_error!("UInt64")
+ };
+ ($TERM:expr, -, UInt32) => {
+ unsigned_subtraction_error!("UInt32")
+ };
+ ($TERM:expr, -, UInt16) => {
+ unsigned_subtraction_error!("UInt16")
+ };
+ ($TERM:expr, -, UInt8) => {
+ unsigned_subtraction_error!("UInt8")
+ };
+ ($TERM:expr, -, $SCALAR:ident) => {
+ Ok(ScalarValue::$SCALAR(Some(-$TERM)))
+ };
+ ($TERM:expr, /, Float64) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+ };
+ ($TERM:expr, /, Float32) => {
+ Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+ };
+ ($TERM:expr, /, $SCALAR:ident) => {
+ Err(DataFusionError::Internal(format!(
+ "Can not divide an uninitialized value to a non-floating point value",
+ )))
+ };
+}
+
+macro_rules! unsigned_subtraction_error {
+ ($SCALAR:expr) => {{
+ let msg = format!(
+ "Can not subtract a {} value from an uninitialized value",
+ $SCALAR
+ );
+ Err(DataFusionError::Internal(msg))
+ }};
+}
+
+macro_rules! impl_op {
($LHS:expr, $RHS:expr, $OPERATION:tt) => {
match ($LHS, $RHS) {
(
ScalarValue::Decimal128(v1, p1, s1),
ScalarValue::Decimal128(v2, p2, s2),
) => {
- let max_precision = *p1.max(p2);
- Some(decimal_op!(v1, v2, max_precision, *s1, *s2, $OPERATION))
+ decimal_op!(v1, v2, *p1.max(p2), *s1, *s2, $OPERATION)
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
- typed_op!(lhs, rhs, Float64, f64, $OPERATION)
+ primitive_op!(lhs, rhs, Float64, $OPERATION)
}
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
- typed_op!(lhs, rhs, Float32, f32, $OPERATION)
+ primitive_op!(lhs, rhs, Float32, $OPERATION)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
- typed_op!(lhs, rhs, UInt64, u64, $OPERATION)
+ primitive_op!(lhs, rhs, UInt64, $OPERATION)
}
(ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
- typed_op!(lhs, rhs, Int64, i64, $OPERATION)
+ primitive_op!(lhs, rhs, Int64, $OPERATION)
}
(ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
- typed_op!(lhs, rhs, UInt32, u32, $OPERATION)
+ primitive_op!(lhs, rhs, UInt32, $OPERATION)
}
(ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
- typed_op!(lhs, rhs, Int32, i32, $OPERATION)
+ primitive_op!(lhs, rhs, Int32, $OPERATION)
}
(ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
- typed_op!(lhs, rhs, UInt16, u16, $OPERATION)
+ primitive_op!(lhs, rhs, UInt16, $OPERATION)
}
(ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
- typed_op!(lhs, rhs, Int16, i16, $OPERATION)
+ primitive_op!(lhs, rhs, Int16, $OPERATION)
}
(ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
- typed_op!(lhs, rhs, UInt8, u8, $OPERATION)
+ primitive_op!(lhs, rhs, UInt8, $OPERATION)
}
(ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
- typed_op!(lhs, rhs, Int8, i8, $OPERATION)
+ primitive_op!(lhs, rhs, Int8, $OPERATION)
+ }
+ _ => {
+ impl_distinct_cases_op!($LHS, $RHS, $OPERATION)
}
- _ => impl_common_symmetric_cases_op!(
- $LHS,
- $RHS,
- $OPERATION,
- [
- // Float64 coerces everything to f64:
- [Float64, Float32, Float64, f64],
- [Float64, Int64, Float64, f64],
- [Float64, Int32, Float64, f64],
- [Float64, Int16, Float64, f64],
- [Float64, Int8, Float64, f64],
- [Float64, UInt64, Float64, f64],
- [Float64, UInt32, Float64, f64],
- [Float64, UInt16, Float64, f64],
- [Float64, UInt8, Float64, f64],
- // UInt64 coerces all smaller unsigned types to u64:
- [UInt64, UInt32, UInt64, u64],
- [UInt64, UInt16, UInt64, u64],
- [UInt64, UInt8, UInt64, u64],
- // Int64 coerces all smaller integral types to i64:
- [Int64, Int32, Int64, i64],
- [Int64, Int16, Int64, i64],
- [Int64, Int8, Int64, i64],
- [Int64, UInt32, Int64, i64],
- [Int64, UInt16, Int64, i64],
- [Int64, UInt8, Int64, i64]
- ]
- ),
}
};
}
-/// If we want a special implementation for an ooperation this is the place to implement it
-/// For instance, in the future we may want to implement subtraction for dates but not summation
-/// so we can implement special case in the corresponding place
+// If we want a special implementation for an operation this is the place to implement it.
+// For instance, in the future we may want to implement subtraction for dates but not addition.
+// We can implement such special cases here.
macro_rules! impl_distinct_cases_op {
($LHS:expr, $RHS:expr, +) => {
match ($LHS, $RHS) {
e => Err(DataFusionError::Internal(format!(
- "Summation is not implemented for {:?}",
+ "Addition is not implemented for {:?}",
e
))),
}
@@ -475,15 +506,6 @@ macro_rules! impl_distinct_cases_op {
};
}
-macro_rules! impl_op {
- ($LHS:expr, $RHS:expr, $OPERATION:tt) => {
- match impl_common_cases_op!($LHS, $RHS, $OPERATION) {
- Some(elem) => Ok(elem),
- None => impl_distinct_cases_op!($LHS, $RHS, $OPERATION),
- }
- };
-}
-
// manual implementation of `Hash` that uses OrderedFloat to
// get defined behavior for floating point
impl std::hash::Hash for ScalarValue {
@@ -938,11 +960,13 @@ impl ScalarValue {
}
pub fn is_unsigned(&self) -> bool {
- let value_type = self.get_datatype();
- value_type == DataType::UInt64
- || value_type == DataType::UInt32
- || value_type == DataType::UInt16
- || value_type == DataType::UInt8
+ matches!(
+ self,
+ ScalarValue::UInt8(_)
+ | ScalarValue::UInt16(_)
+ | ScalarValue::UInt32(_)
+ | ScalarValue::UInt64(_)
+ )
}
/// whether this value is null or not.
@@ -2180,35 +2204,43 @@ impl TryFrom<&DataType> for ScalarValue {
}
}
+// TODO: Remove these coercions once the hardcoded "u64" offset is changed to a
+// ScalarValue in WindowFrameBound.
pub trait TryFromValue<T> {
fn try_from_value(datatype: &DataType, value: T) -> Result<ScalarValue>;
}
macro_rules! impl_try_from_value {
- ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:tt]),+]) => {
+ ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:ty]),+]) => {
impl TryFromValue<$NATIVE> for ScalarValue {
fn try_from_value(datatype: &DataType, value: $NATIVE) -> Result<ScalarValue> {
- Ok(match datatype {
- $(DataType::$SCALAR => ScalarValue::$SCALAR(Some(value as $PRIMITIVE)),)+
+ match datatype {
+ $(DataType::$SCALAR => Ok(ScalarValue::$SCALAR(Some(value as $PRIMITIVE))),)+
_ => {
- return Err(DataFusionError::NotImplemented(format!(
- "Can't create a scalar from data_type \"{:?}\"",
- datatype
- )));
+ let msg = format!("Can't create a scalar from data_type \"{:?}\"", datatype);
+ Err(DataFusionError::NotImplemented(msg))
}
- })
+ }
}
}
};
}
-macro_rules! impl_try_from_value_all {
- ([$($NATIVE:ty),+]) => {
- $(impl_try_from_value!($NATIVE, [[Float64, f64], [Float32, f32], [UInt64, u64], [UInt32, u32], [UInt16, u16], [UInt8, u8], [Int64, i64], [Int32, i32], [Int16, i16], [Int8, i8]]);)+
- }
-}
-
-impl_try_from_value_all!([f64, f32, u64, u32, u16, u8, i64, i32, i16, i8]);
+impl_try_from_value!(
+ u64,
+ [
+ [Float64, f64],
+ [Float32, f32],
+ [UInt64, u64],
+ [UInt32, u32],
+ [UInt16, u16],
+ [UInt8, u8],
+ [Int64, i64],
+ [Int32, i32],
+ [Int16, i16],
+ [Int8, i8]
+ ]
+);
macro_rules! format_option {
($F:expr, $EXPR:expr) => {{
@@ -2440,18 +2472,6 @@ mod tests {
float_value.sub(&float_value_2)?,
ScalarValue::Float64(Some(0.))
);
- assert_eq!(
- float_value.sub(&float_value_2)?,
- ScalarValue::Float64(Some(0.))
- );
- assert_eq!(
- float_value.sub(&float_value_2)?,
- ScalarValue::Float64(Some(0.))
- );
- assert_eq!(
- float_value.sub(&float_value_2)?,
- ScalarValue::Float64(Some(0.))
- );
assert_eq!(
float_value.sub(float_value_2)?,
ScalarValue::Float64(Some(0.))
@@ -3693,37 +3713,36 @@ mod tests {
Ok(())
}
- #[test]
- fn test_subtraction() {
- let lhs = ScalarValue::Float64(Some(11.0));
- let rhs = ScalarValue::Float64(Some(12.0));
- assert_eq!(lhs.sub(rhs).unwrap(), ScalarValue::Float64(Some(-1.0)));
- }
-
- #[test]
- fn expect_subtraction_error() {
- let lhs = ScalarValue::UInt64(Some(12));
- let rhs = ScalarValue::Int32(Some(-3));
- let expected_error = "Subtraction is not implemented";
- match lhs.sub(&rhs) {
- Ok(_result) => {
- panic!(
- "Expected summation error between lhs: '{:?}', rhs: {:?}",
- lhs, rhs
- );
- }
- Err(e) => {
- let error_message = e.to_string();
- assert!(
- error_message.contains(expected_error),
- "Expected error '{}' not found in actual error '{}'",
- expected_error,
- error_message
- );
+ macro_rules! expect_operation_error {
+ ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
+ #[test]
+ fn $TEST_NAME() {
+ let lhs = ScalarValue::UInt64(Some(12));
+ let rhs = ScalarValue::Int32(Some(-3));
+ match lhs.$FUNCTION(&rhs) {
+ Ok(_result) => {
+ panic!(
+ "Expected summation error between lhs: '{:?}', rhs: {:?}",
+ lhs, rhs
+ );
+ }
+ Err(e) => {
+ let error_message = e.to_string();
+ assert!(
+ error_message.contains($EXPECTED_ERROR),
+ "Expected error '{}' not found in actual error '{}'",
+ $EXPECTED_ERROR,
+ error_message
+ );
+ }
+ }
}
- }
+ };
}
+ expect_operation_error!(expect_add_error, add, "Addition is not implemented");
+ expect_operation_error!(expect_sub_error, sub, "Subtraction is not implemented");
+
macro_rules! decimal_op_test_cases {
($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => {
$(
@@ -3791,109 +3810,4 @@ mod tests {
]
);
}
-
- macro_rules! op_test_cases {
- ($LHS:expr, $RHS:expr, $OUT:expr, $OPERATION:ident, [$([$L_TYPE:ident, $L_PRIM:ident, $R_TYPE:ident, $R_PRIM:ident, $O_TYPE:ident, $O_PRIM:ident]),+]) => {
- $(
- // From left
- let lhs = ScalarValue::$L_TYPE(Some($LHS as $L_PRIM));
- let rhs = ScalarValue::$R_TYPE(Some($RHS as $R_PRIM));
- assert_eq!(lhs.$OPERATION(rhs).unwrap(), ScalarValue::$O_TYPE(Some($OUT as $O_PRIM)));
- // From right. The values ($RHS and $LHS) also crossed to produce same output for subtraction.
- let lhs = ScalarValue::$L_TYPE(Some($RHS as $L_PRIM));
- let rhs = ScalarValue::$R_TYPE(Some($LHS as $R_PRIM));
- assert_eq!(rhs.$OPERATION(lhs).unwrap(), ScalarValue::$O_TYPE(Some($OUT as $O_PRIM)));
- )+
- };
- }
-
- #[test]
- fn test_sum_operation_different_types() {
- op_test_cases!(
- 11,
- 12,
- 23,
- add,
- [
- // FloatXY coerces everything to fXY:
- [Float64, f64, Float32, f32, Float64, f64],
- [Float64, f64, Int64, i64, Float64, f64],
- [Float64, f64, Int32, i32, Float64, f64],
- [Float64, f64, Int16, i16, Float64, f64],
- [Float64, f64, Int8, i8, Float64, f64],
- [Float64, f64, UInt64, u64, Float64, f64],
- [Float64, f64, UInt32, u32, Float64, f64],
- [Float64, f64, UInt16, u16, Float64, f64],
- [Float64, f64, UInt8, u8, Float64, f64],
- // UIntXY coerces all smaller unsigned types to uXY:
- [UInt64, u64, UInt32, u32, UInt64, u64],
- [UInt64, u64, UInt16, u16, UInt64, u64],
- [UInt64, u64, UInt8, u8, UInt64, u64],
- // IntXY types coerce smaller integral types to iXY:
- [Int64, i64, Int32, i32, Int64, i64],
- [Int64, i64, Int16, i16, Int64, i64],
- [Int64, i64, Int8, i8, Int64, i64],
- [Int64, i64, UInt32, u32, Int64, i64],
- [Int64, i64, UInt16, u16, Int64, i64],
- [Int64, i64, UInt8, u8, Int64, i64]
- ]
- );
- }
-
- #[test]
- fn test_sub_operation_different_types() {
- op_test_cases!(
- 20,
- 8,
- 12,
- sub,
- [
- // FloatXY coerces everything to fXY:
- [Float64, f64, Float32, f32, Float64, f64],
- [Float64, f64, Int64, i64, Float64, f64],
- [Float64, f64, Int32, i32, Float64, f64],
- [Float64, f64, Int16, i16, Float64, f64],
- [Float64, f64, Int8, i8, Float64, f64],
- [Float64, f64, UInt64, u64, Float64, f64],
- [Float64, f64, UInt32, u32, Float64, f64],
- [Float64, f64, UInt16, u16, Float64, f64],
- [Float64, f64, UInt8, u8, Float64, f64],
- // UIntXY coerces all smaller unsigned types to uXY:
- [UInt64, u64, UInt32, u32, UInt64, u64],
- [UInt64, u64, UInt16, u16, UInt64, u64],
- [UInt64, u64, UInt8, u8, UInt64, u64],
- // IntXY types coerce smaller integral types to iXY:
- [Int64, i64, Int32, i32, Int64, i64],
- [Int64, i64, Int16, i16, Int64, i64],
- [Int64, i64, Int8, i8, Int64, i64],
- [Int64, i64, UInt32, u32, Int64, i64],
- [Int64, i64, UInt16, u16, Int64, i64],
- [Int64, i64, UInt8, u8, Int64, i64]
- ]
- );
- }
-
- #[test]
- fn expect_summation_error() {
- let lhs = ScalarValue::UInt64(Some(12));
- let rhs = ScalarValue::Int32(Some(-3));
- let expected_error = "Summation is not implemented";
- match lhs.add(&rhs) {
- Ok(_result) => {
- panic!(
- "Expected summation error between lhs: '{:?}', rhs: {:?}",
- lhs, rhs
- );
- }
- Err(e) => {
- let error_message = e.to_string();
- assert!(
- error_message.contains(expected_error),
- "Expected error '{}' not found in actual error '{}'",
- expected_error,
- error_message
- );
- }
- }
- }
}
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 723ae7e9a..f034e3d56 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -230,7 +230,6 @@ impl RowAccumulator for AvgRowAccumulator {
// sum
sum::add_to_row(
- &self.sum_datatype,
self.state_index() + 1,
accessor,
&sum::sum_batch(values, &self.sum_datatype)?,
@@ -249,12 +248,8 @@ impl RowAccumulator for AvgRowAccumulator {
accessor.add_u64(self.state_index(), delta);
// sum
- sum::add_to_row(
- &self.sum_datatype,
- self.state_index() + 1,
- accessor,
- &sum::sum_batch(&states[1], &self.sum_datatype)?,
- )?;
+ let difference = sum::sum_batch(&states[1], &self.sum_datatype)?;
+ sum::add_to_row(self.state_index() + 1, accessor, &difference)?;
Ok(())
}
@@ -301,8 +296,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
- ScalarValue::Decimal128(Some(35000), 14, 4),
- DataType::Decimal128(14, 4)
+ ScalarValue::Decimal128(Some(35000), 14, 4)
)
}
@@ -318,8 +312,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
- ScalarValue::Decimal128(Some(32500), 14, 4),
- DataType::Decimal128(14, 4)
+ ScalarValue::Decimal128(Some(32500), 14, 4)
)
}
@@ -337,21 +330,14 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Avg,
- ScalarValue::Decimal128(None, 14, 4),
- DataType::Decimal128(14, 4)
+ ScalarValue::Decimal128(None, 14, 4)
)
}
#[test]
fn avg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- Avg,
- ScalarValue::from(3_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3_f64))
}
#[test]
@@ -363,63 +349,33 @@ mod tests {
Some(4),
Some(5),
]));
- generic_test_op!(
- a,
- DataType::Int32,
- Avg,
- ScalarValue::from(3.25f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3.25f64))
}
#[test]
fn avg_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- generic_test_op!(
- a,
- DataType::Int32,
- Avg,
- ScalarValue::Float64(None),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Int32, Avg, ScalarValue::Float64(None))
}
#[test]
fn avg_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- generic_test_op!(
- a,
- DataType::UInt32,
- Avg,
- ScalarValue::from(3.0f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::UInt32, Avg, ScalarValue::from(3.0f64))
}
#[test]
fn avg_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- generic_test_op!(
- a,
- DataType::Float32,
- Avg,
- ScalarValue::from(3_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float32, Avg, ScalarValue::from(3_f64))
}
#[test]
fn avg_f64() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- Avg,
- ScalarValue::from(3_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, Avg, ScalarValue::from(3_f64))
}
}
diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs
index f25cb5790..8645bd549 100644
--- a/datafusion/physical-expr/src/aggregate/correlation.rs
+++ b/datafusion/physical-expr/src/aggregate/correlation.rs
@@ -217,8 +217,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
- ScalarValue::from(0.9819805060619659),
- DataType::Float64
+ ScalarValue::from(0.9819805060619659_f64)
)
}
@@ -233,8 +232,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
- ScalarValue::from(0.17066403719657236),
- DataType::Float64
+ ScalarValue::from(0.17066403719657236_f64)
)
}
@@ -249,8 +247,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
@@ -269,8 +266,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Correlation,
- ScalarValue::from(0.9860135594710389),
- DataType::Float64
+ ScalarValue::from(0.9860135594710389_f64)
)
}
@@ -285,8 +281,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
Correlation,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
@@ -300,8 +295,7 @@ mod tests {
DataType::UInt32,
DataType::UInt32,
Correlation,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
@@ -315,8 +309,7 @@ mod tests {
DataType::Float32,
DataType::Float32,
Correlation,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
@@ -333,8 +326,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
Correlation,
- ScalarValue::from(0.1889822365046137),
- DataType::Float64
+ ScalarValue::from(0.1889822365046137_f64)
)
}
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 8cfb85fa9..b64328aa3 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -210,13 +210,7 @@ mod tests {
#[test]
fn count_elements() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- Count,
- ScalarValue::from(5i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(5i64))
}
#[test]
@@ -229,13 +223,7 @@ mod tests {
Some(3),
None,
]));
- generic_test_op!(
- a,
- DataType::Int32,
- Count,
- ScalarValue::from(3i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(3i64))
}
#[test]
@@ -243,51 +231,27 @@ mod tests {
let a: ArrayRef = Arc::new(BooleanArray::from(vec![
None, None, None, None, None, None, None, None,
]));
- generic_test_op!(
- a,
- DataType::Boolean,
- Count,
- ScalarValue::from(0i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64))
}
#[test]
fn count_empty() -> Result<()> {
let a: Vec<bool> = vec![];
let a: ArrayRef = Arc::new(BooleanArray::from(a));
- generic_test_op!(
- a,
- DataType::Boolean,
- Count,
- ScalarValue::from(0i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64))
}
#[test]
fn count_utf8() -> Result<()> {
let a: ArrayRef =
Arc::new(StringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
- generic_test_op!(
- a,
- DataType::Utf8,
- Count,
- ScalarValue::from(5i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Utf8, Count, ScalarValue::from(5i64))
}
#[test]
fn count_large_utf8() -> Result<()> {
let a: ArrayRef =
Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
- generic_test_op!(
- a,
- DataType::LargeUtf8,
- Count,
- ScalarValue::from(5i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::LargeUtf8, Count, ScalarValue::from(5i64))
}
}
diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs
index 0911111a3..63a8137c2 100644
--- a/datafusion/physical-expr/src/aggregate/covariance.rs
+++ b/datafusion/physical-expr/src/aggregate/covariance.rs
@@ -397,8 +397,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
- ScalarValue::from(0.6666666666666666),
- DataType::Float64
+ ScalarValue::from(0.6666666666666666_f64)
)
}
@@ -413,8 +412,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Covariance,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
@@ -429,8 +427,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
Covariance,
- ScalarValue::from(0.9033333333333335_f64),
- DataType::Float64
+ ScalarValue::from(0.9033333333333335_f64)
)
}
@@ -445,8 +442,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
- ScalarValue::from(0.6022222222222223_f64),
- DataType::Float64
+ ScalarValue::from(0.6022222222222223_f64)
)
}
@@ -465,8 +461,7 @@ mod tests {
DataType::Float64,
DataType::Float64,
CovariancePop,
- ScalarValue::from(0.7616666666666666),
- DataType::Float64
+ ScalarValue::from(0.7616666666666666_f64)
)
}
@@ -481,8 +476,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
CovariancePop,
- ScalarValue::from(0.6666666666666666_f64),
- DataType::Float64
+ ScalarValue::from(0.6666666666666666_f64)
)
}
@@ -496,8 +490,7 @@ mod tests {
DataType::UInt32,
DataType::UInt32,
CovariancePop,
- ScalarValue::from(0.6666666666666666_f64),
- DataType::Float64
+ ScalarValue::from(0.6666666666666666_f64)
)
}
@@ -511,8 +504,7 @@ mod tests {
DataType::Float32,
DataType::Float32,
CovariancePop,
- ScalarValue::from(0.6666666666666666_f64),
- DataType::Float64
+ ScalarValue::from(0.6666666666666666_f64)
)
}
@@ -527,8 +519,7 @@ mod tests {
DataType::Int32,
DataType::Int32,
CovariancePop,
- ScalarValue::from(1_f64),
- DataType::Float64
+ ScalarValue::from(1_f64)
)
}
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index bdccdf522..36d58c780 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -154,16 +154,10 @@ macro_rules! typed_min_max_batch_string {
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
macro_rules! typed_min_max_batch {
- ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
- let array = downcast_value!($VALUES, $ARRAYTYPE);
- let value = compute::$OP(array);
- ScalarValue::$SCALAR(value)
- }};
-
- ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+ ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
let array = downcast_value!($VALUES, $ARRAYTYPE);
let value = compute::$OP(array);
- ScalarValue::$SCALAR(value, $TZ.clone())
+ ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*)
}};
}
@@ -296,41 +290,18 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
_ => min_max_batch!(values, max),
})
}
-macro_rules! typed_min_max_decimal {
- ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
- ScalarValue::$SCALAR(
- match ($VALUE, $DELTA) {
- (None, None) => None,
- (Some(a), None) => Some(a.clone()),
- (None, Some(b)) => Some(b.clone()),
- (Some(a), Some(b)) => Some((*a).$OP(*b)),
- },
- $PRECISION.clone(),
- $SCALE.clone(),
- )
- }};
-}
// min/max of two non-string scalar values.
macro_rules! typed_min_max {
- ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
- ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
- (None, None) => None,
- (Some(a), None) => Some(a.clone()),
- (None, Some(b)) => Some(b.clone()),
- (Some(a), Some(b)) => Some((*a).$OP(*b)),
- })
- }};
-
- ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+ ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{
ScalarValue::$SCALAR(
match ($VALUE, $DELTA) {
(None, None) => None,
- (Some(a), None) => Some(a.clone()),
- (None, Some(b)) => Some(b.clone()),
+ (Some(a), None) => Some(*a),
+ (None, Some(b)) => Some(*b),
(Some(a), Some(b)) => Some((*a).$OP(*b)),
},
- $TZ.clone(),
+ $($EXTRA_ARGS.clone()),*
)
}};
}
@@ -363,13 +334,16 @@ macro_rules! typed_min_max_string {
macro_rules! min_max {
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
Ok(match ($VALUE, $DELTA) {
- (ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => {
+ (
+ lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
+ rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
+ ) => {
if lhsp.eq(rhsp) && lhss.eq(rhss) {
- typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP)
+ typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss)
} else {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
- (ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss))
+ (lhs, rhs)
)));
}
}
@@ -815,8 +789,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Min,
- ScalarValue::Decimal128(Some(1), 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(Some(1), 10, 0)
)
}
@@ -834,8 +807,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Min,
- ScalarValue::Decimal128(None, 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(None, 10, 0)
)
}
@@ -853,8 +825,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Min,
- ScalarValue::Decimal128(Some(1), 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(Some(1), 10, 0)
)
}
@@ -906,8 +877,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Max,
- ScalarValue::Decimal128(Some(5), 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(Some(5), 10, 0)
)
}
@@ -923,8 +893,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Max,
- ScalarValue::Decimal128(Some(5), 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(Some(5), 10, 0)
)
}
@@ -941,33 +910,20 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Min,
- ScalarValue::Decimal128(None, 10, 0),
- DataType::Decimal128(10, 0)
+ ScalarValue::Decimal128(None, 10, 0)
)
}
#[test]
fn max_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- Max,
- ScalarValue::from(5i32),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32))
}
#[test]
fn min_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- Min,
- ScalarValue::from(1i32),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32))
}
#[test]
@@ -977,8 +933,7 @@ mod tests {
a,
DataType::Utf8,
Max,
- ScalarValue::Utf8(Some("d".to_string())),
- DataType::Utf8
+ ScalarValue::Utf8(Some("d".to_string()))
)
}
@@ -989,8 +944,7 @@ mod tests {
a,
DataType::LargeUtf8,
Max,
- ScalarValue::LargeUtf8(Some("d".to_string())),
- DataType::LargeUtf8
+ ScalarValue::LargeUtf8(Some("d".to_string()))
)
}
@@ -1001,8 +955,7 @@ mod tests {
a,
DataType::Utf8,
Min,
- ScalarValue::Utf8(Some("a".to_string())),
- DataType::Utf8
+ ScalarValue::Utf8(Some("a".to_string()))
)
}
@@ -1013,8 +966,7 @@ mod tests {
a,
DataType::LargeUtf8,
Min,
- ScalarValue::LargeUtf8(Some("a".to_string())),
- DataType::LargeUtf8
+ ScalarValue::LargeUtf8(Some("a".to_string()))
)
}
@@ -1027,13 +979,7 @@ mod tests {
Some(4),
Some(5),
]));
- generic_test_op!(
- a,
- DataType::Int32,
- Max,
- ScalarValue::from(5i32),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32))
}
#[test]
@@ -1045,163 +991,85 @@ mod tests {
Some(4),
Some(5),
]));
- generic_test_op!(
- a,
- DataType::Int32,
- Min,
- ScalarValue::from(1i32),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32))
}
#[test]
fn max_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- generic_test_op!(
- a,
- DataType::Int32,
- Max,
- ScalarValue::Int32(None),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Max, ScalarValue::Int32(None))
}
#[test]
fn min_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- generic_test_op!(
- a,
- DataType::Int32,
- Min,
- ScalarValue::Int32(None),
- DataType::Int32
- )
+ generic_test_op!(a, DataType::Int32, Min, ScalarValue::Int32(None))
}
#[test]
fn max_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- generic_test_op!(
- a,
- DataType::UInt32,
- Max,
- ScalarValue::from(5_u32),
- DataType::UInt32
- )
+ generic_test_op!(a, DataType::UInt32, Max, ScalarValue::from(5_u32))
}
#[test]
fn min_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- generic_test_op!(
- a,
- DataType::UInt32,
- Min,
- ScalarValue::from(1u32),
- DataType::UInt32
- )
+ generic_test_op!(a, DataType::UInt32, Min, ScalarValue::from(1u32))
}
#[test]
fn max_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- generic_test_op!(
- a,
- DataType::Float32,
- Max,
- ScalarValue::from(5_f32),
- DataType::Float32
- )
+ generic_test_op!(a, DataType::Float32, Max, ScalarValue::from(5_f32))
}
#[test]
fn min_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- generic_test_op!(
- a,
- DataType::Float32,
- Min,
- ScalarValue::from(1_f32),
- DataType::Float32
- )
+ generic_test_op!(a, DataType::Float32, Min, ScalarValue::from(1_f32))
}
#[test]
fn max_f64() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- Max,
- ScalarValue::from(5_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, Max, ScalarValue::from(5_f64))
}
#[test]
fn min_f64() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- Min,
- ScalarValue::from(1_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, Min, ScalarValue::from(1_f64))
}
#[test]
fn min_date32() -> Result<()> {
let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Date32,
- Min,
- ScalarValue::Date32(Some(1)),
- DataType::Date32
- )
+ generic_test_op!(a, DataType::Date32, Min, ScalarValue::Date32(Some(1)))
}
#[test]
fn min_date64() -> Result<()> {
let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Date64,
- Min,
- ScalarValue::Date64(Some(1)),
- DataType::Date64
- )
+ generic_test_op!(a, DataType::Date64, Min, ScalarValue::Date64(Some(1)))
}
#[test]
fn max_date32() -> Result<()> {
let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Date32,
- Max,
- ScalarValue::Date32(Some(5)),
- DataType::Date32
- )
+ generic_test_op!(a, DataType::Date32, Max, ScalarValue::Date32(Some(5)))
}
#[test]
fn max_date64() -> Result<()> {
let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Date64,
- Max,
- ScalarValue::Date64(Some(5)),
- DataType::Date64
- )
+ generic_test_op!(a, DataType::Date64, Max, ScalarValue::Date64(Some(5)))
}
#[test]
@@ -1211,8 +1079,7 @@ mod tests {
a,
DataType::Time64(TimeUnit::Nanosecond),
Max,
- ScalarValue::Time64(Some(5)),
- DataType::Time64(TimeUnit::Nanosecond)
+ ScalarValue::Time64(Some(5))
)
}
@@ -1223,8 +1090,7 @@ mod tests {
a,
DataType::Time64(TimeUnit::Nanosecond),
Max,
- ScalarValue::Time64(Some(5)),
- DataType::Time64(TimeUnit::Nanosecond)
+ ScalarValue::Time64(Some(5))
)
}
}
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 77f080293..5197018a5 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -227,13 +227,7 @@ mod tests {
#[test]
fn stddev_f64_1() -> Result<()> {
let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- StddevPop,
- ScalarValue::from(0.5_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, StddevPop, ScalarValue::from(0.5_f64))
}
#[test]
@@ -243,8 +237,7 @@ mod tests {
a,
DataType::Float64,
StddevPop,
- ScalarValue::from(0.7760297817881877),
- DataType::Float64
+ ScalarValue::from(0.7760297817881877_f64)
)
}
@@ -256,8 +249,7 @@ mod tests {
a,
DataType::Float64,
StddevPop,
- ScalarValue::from(std::f64::consts::SQRT_2),
- DataType::Float64
+ ScalarValue::from(std::f64::consts::SQRT_2)
)
}
@@ -268,8 +260,7 @@ mod tests {
a,
DataType::Float64,
Stddev,
- ScalarValue::from(0.9504384952922168),
- DataType::Float64
+ ScalarValue::from(0.9504384952922168_f64)
)
}
@@ -280,8 +271,7 @@ mod tests {
a,
DataType::Int32,
StddevPop,
- ScalarValue::from(std::f64::consts::SQRT_2),
- DataType::Float64
+ ScalarValue::from(std::f64::consts::SQRT_2)
)
}
@@ -293,8 +283,7 @@ mod tests {
a,
DataType::UInt32,
StddevPop,
- ScalarValue::from(std::f64::consts::SQRT_2),
- DataType::Float64
+ ScalarValue::from(std::f64::consts::SQRT_2)
)
}
@@ -306,8 +295,7 @@ mod tests {
a,
DataType::Float32,
StddevPop,
- ScalarValue::from(std::f64::consts::SQRT_2),
- DataType::Float64
+ ScalarValue::from(std::f64::consts::SQRT_2)
)
}
@@ -341,8 +329,7 @@ mod tests {
a,
DataType::Int32,
StddevPop,
- ScalarValue::from(1.479019945774904),
- DataType::Float64
+ ScalarValue::from(1.479019945774904_f64)
)
}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index 892ef5964..ca9b4c819 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -168,12 +168,11 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: u8) -> Result<Scal
return Ok(ScalarValue::Decimal128(None, precision, scale));
}
- let mut result = 0_i128;
- for i in 0..array.len() {
- if array.is_valid(i) {
- result += array.value(i).as_i128();
- }
- }
+ let result = array.into_iter().fold(0_i128, |s, element| match element {
+ Some(v) => s + v.as_i128(),
+ None => s,
+ });
+
Ok(ScalarValue::Decimal128(Some(result), precision, scale))
}
@@ -206,87 +205,37 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<Scalar
macro_rules! sum_row {
($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{
paste::item! {
- match $DELTA {
- None => {}
- Some(v) => $ACC.[<add_ $TYPE>]($INDEX, *v as $TYPE)
+ if let Some(v) = $DELTA {
+ $ACC.[<add_ $TYPE>]($INDEX, *v)
}
}
}};
}
pub(crate) fn add_to_row(
- dt: &DataType,
index: usize,
accessor: &mut RowAccessor,
s: &ScalarValue,
) -> Result<()> {
- match (dt, s) {
- // float64 coerces everything to f64
- (DataType::Float64, ScalarValue::Float64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Float32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int16(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::Int8(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt64(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt32(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt16(rhs)) => {
- sum_row!(index, accessor, rhs, f64)
- }
- (DataType::Float64, ScalarValue::UInt8(rhs)) => {
+ match s {
+ ScalarValue::Float64(rhs) => {
sum_row!(index, accessor, rhs, f64)
}
- // float32 has no cast
- (DataType::Float32, ScalarValue::Float32(rhs)) => {
+ ScalarValue::Float32(rhs) => {
sum_row!(index, accessor, rhs, f32)
}
- // u64 coerces u* to u64
- (DataType::UInt64, ScalarValue::UInt64(rhs)) => {
+ ScalarValue::UInt64(rhs) => {
sum_row!(index, accessor, rhs, u64)
}
- (DataType::UInt64, ScalarValue::UInt32(rhs)) => {
- sum_row!(index, accessor, rhs, u64)
- }
- (DataType::UInt64, ScalarValue::UInt16(rhs)) => {
- sum_row!(index, accessor, rhs, u64)
- }
- (DataType::UInt64, ScalarValue::UInt8(rhs)) => {
- sum_row!(index, accessor, rhs, u64)
- }
- // i64 coerces i* to i64
- (DataType::Int64, ScalarValue::Int64(rhs)) => {
+ ScalarValue::Int64(rhs) => {
sum_row!(index, accessor, rhs, i64)
}
- (DataType::Int64, ScalarValue::Int32(rhs)) => {
- sum_row!(index, accessor, rhs, i64)
- }
- (DataType::Int64, ScalarValue::Int16(rhs)) => {
- sum_row!(index, accessor, rhs, i64)
- }
- (DataType::Int64, ScalarValue::Int8(rhs)) => {
- sum_row!(index, accessor, rhs, i64)
- }
- e => {
- return Err(DataFusionError::Internal(format!(
+ _ => {
+ let msg = format!(
"Row sum updater is not expected to receive a scalar {:?}",
- e
- )));
+ s
+ );
+ return Err(DataFusionError::Internal(msg));
}
}
Ok(())
@@ -303,18 +252,16 @@ impl Accumulator for SumAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
self.count += (values.len() - values.data().null_count()) as u64;
- self.sum = self
- .sum
- .add(&sum_batch(values, &self.sum.get_datatype())?)?;
+ let delta = sum_batch(values, &self.sum.get_datatype())?;
+ self.sum = self.sum.add(&delta)?;
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
self.count -= (values.len() - values.data().null_count()) as u64;
- self.sum = self
- .sum
- .sub(&sum_batch(values, &self.sum.get_datatype())?)?;
+ let delta = sum_batch(values, &self.sum.get_datatype())?;
+ self.sum = self.sum.sub(&delta)?;
Ok(())
}
@@ -353,12 +300,8 @@ impl RowAccumulator for SumRowAccumulator {
accessor: &mut RowAccessor,
) -> Result<()> {
let values = &values[0];
- add_to_row(
- &self.datatype,
- self.index,
- accessor,
- &sum_batch(values, &self.datatype)?,
- )?;
+ let delta = sum_batch(values, &self.datatype)?;
+ add_to_row(self.index, accessor, &delta)?;
Ok(())
}
@@ -414,8 +357,7 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Sum,
- ScalarValue::Decimal128(Some(15), 20, 0),
- DataType::Decimal128(20, 0)
+ ScalarValue::Decimal128(Some(15), 20, 0)
)
}
@@ -442,8 +384,7 @@ mod tests {
array,
DataType::Decimal128(35, 0),
Sum,
- ScalarValue::Decimal128(Some(13), 38, 0),
- DataType::Decimal128(38, 0)
+ ScalarValue::Decimal128(Some(13), 38, 0)
)
}
@@ -465,21 +406,14 @@ mod tests {
array,
DataType::Decimal128(10, 0),
Sum,
- ScalarValue::Decimal128(None, 20, 0),
- DataType::Decimal128(20, 0)
+ ScalarValue::Decimal128(None, 20, 0)
)
}
#[test]
fn sum_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- Sum,
- ScalarValue::from(15i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(15i32))
}
#[test]
@@ -491,63 +425,33 @@ mod tests {
Some(4),
Some(5),
]));
- generic_test_op!(
- a,
- DataType::Int32,
- Sum,
- ScalarValue::from(13i64),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(13i32))
}
#[test]
fn sum_i32_all_nulls() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- generic_test_op!(
- a,
- DataType::Int32,
- Sum,
- ScalarValue::Int64(None),
- DataType::Int64
- )
+ generic_test_op!(a, DataType::Int32, Sum, ScalarValue::Int32(None))
}
#[test]
fn sum_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- generic_test_op!(
- a,
- DataType::UInt32,
- Sum,
- ScalarValue::from(15u64),
- DataType::UInt64
- )
+ generic_test_op!(a, DataType::UInt32, Sum, ScalarValue::from(15u32))
}
#[test]
fn sum_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- generic_test_op!(
- a,
- DataType::Float32,
- Sum,
- ScalarValue::from(15_f32),
- DataType::Float32
- )
+ generic_test_op!(a, DataType::Float32, Sum, ScalarValue::from(15_f32))
}
#[test]
fn sum_f64() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- Sum,
- ScalarValue::from(15_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, Sum, ScalarValue::from(15_f64))
}
}
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index c8abdcac0..73c4828e8 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -200,7 +200,7 @@ mod tests {
}
macro_rules! generic_test_sum_distinct {
- ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
+ ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr) => {{
let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
@@ -208,7 +208,7 @@ mod tests {
let agg = Arc::new(DistinctSum::new(
vec![col("a", &schema)?],
"count_distinct_a".to_string(),
- $EXPECTED_DATATYPE,
+ $EXPECTED.get_datatype(),
));
let actual = aggregate(&batch, agg)?;
let expected = ScalarValue::from($EXPECTED);
@@ -241,12 +241,7 @@ mod tests {
Some(2),
Some(3),
]));
- generic_test_sum_distinct!(
- array,
- DataType::Int32,
- ScalarValue::from(6i64),
- DataType::Int64
- )
+ generic_test_sum_distinct!(array, DataType::Int32, ScalarValue::from(6_i32))
}
#[test]
@@ -258,24 +253,14 @@ mod tests {
Some(3_u32),
None,
]));
- generic_test_sum_distinct!(
- array,
- DataType::UInt32,
- ScalarValue::from(4i64),
- DataType::Int64
- )
+ generic_test_sum_distinct!(array, DataType::UInt32, ScalarValue::from(4_u32))
}
#[test]
fn sum_distinct_f64() -> Result<()> {
let array: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64]));
- generic_test_sum_distinct!(
- array,
- DataType::Float64,
- ScalarValue::from(4_f64),
- DataType::Float64
- )
+ generic_test_sum_distinct!(array, DataType::Float64, ScalarValue::from(4_f64))
}
#[test]
@@ -289,8 +274,7 @@ mod tests {
generic_test_sum_distinct!(
array,
DataType::Decimal128(35, 0),
- ScalarValue::Decimal128(Some(1), 38, 0),
- DataType::Decimal128(38, 0)
+ ScalarValue::Decimal128(Some(1), 38, 0)
)
}
}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index 7ccec55ac..d6ed8c957 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -326,8 +326,7 @@ mod tests {
a,
DataType::Float64,
VariancePop,
- ScalarValue::from(0.25_f64),
- DataType::Float64
+ ScalarValue::from(0.25_f64)
)
}
@@ -335,26 +334,14 @@ mod tests {
fn variance_f64_2() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- VariancePop,
- ScalarValue::from(2_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64))
}
#[test]
fn variance_f64_3() -> Result<()> {
let a: ArrayRef =
Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- generic_test_op!(
- a,
- DataType::Float64,
- Variance,
- ScalarValue::from(2.5_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64))
}
#[test]
@@ -364,47 +351,28 @@ mod tests {
a,
DataType::Float64,
Variance,
- ScalarValue::from(0.9033333333333333_f64),
- DataType::Float64
+ ScalarValue::from(0.9033333333333333_f64)
)
}
#[test]
fn variance_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- generic_test_op!(
- a,
- DataType::Int32,
- VariancePop,
- ScalarValue::from(2_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Int32, VariancePop, ScalarValue::from(2_f64))
}
#[test]
fn variance_u32() -> Result<()> {
let a: ArrayRef =
Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- generic_test_op!(
- a,
- DataType::UInt32,
- VariancePop,
- ScalarValue::from(2.0f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::UInt32, VariancePop, ScalarValue::from(2.0f64))
}
#[test]
fn variance_f32() -> Result<()> {
let a: ArrayRef =
Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- generic_test_op!(
- a,
- DataType::Float32,
- VariancePop,
- ScalarValue::from(2_f64),
- DataType::Float64
- )
+ generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64))
}
#[test]
@@ -437,8 +405,7 @@ mod tests {
a,
DataType::Int32,
VariancePop,
- ScalarValue::from(2.1875f64),
- DataType::Float64
+ ScalarValue::from(2.1875_f64)
)
}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 00bf6aafa..208e6d0b5 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -101,6 +101,9 @@ pub(crate) mod tests {
/// macro to perform an aggregation and verify the result.
#[macro_export]
macro_rules! generic_test_op {
+ ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
+ generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.get_datatype())
+ };
($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
@@ -123,6 +126,17 @@ pub(crate) mod tests {
/// macro to perform an aggregation with two inputs and verify the result.
#[macro_export]
macro_rules! generic_test_op2 {
+ ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => {
+ generic_test_op2!(
+ $ARRAY1,
+ $ARRAY2,
+ $DATATYPE1,
+ $DATATYPE2,
+ $OP,
+ $EXPECTED,
+ $EXPECTED.get_datatype()
+ )
+ };
($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
let schema = Schema::new(vec![
Field::new("a", $DATATYPE1, true),