You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/09/12 10:40:18 UTC

[GitHub] [arrow] alamb commented on a change in pull request #8172: ARROW-9937: [Rust] [DataFusion] Improved aggregations

alamb commented on a change in pull request #8172:
URL: https://github.com/apache/arrow/pull/8172#discussion_r487395554



##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -125,192 +128,188 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
     }
 }
 
-impl AggregateExpr for Sum {
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        sum_return_type(&self.expr.data_type(input_schema)?)
+impl Sum {
+    /// Create a new SUM aggregate function
+    pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
+        Self {
+            name,
+            expr,
+            data_type,
+            nullable: true,
+        }
     }
+}
 
-    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
-        // null should be returned if no rows are aggregated
-        Ok(true)
+impl AggregateExpr for Sum {
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(
+            &self.name,
+            self.data_type.clone(),
+            self.nullable,
+        ))
     }
 
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "sum"),
+            self.data_type.clone(),
+            self.nullable,
+        )])
     }
 
-    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
-        Rc::new(RefCell::new(SumAccumulator { sum: None }))
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
     }
 
-    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Sum::new(Arc::new(Column::new(column_name))))
+    fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
+        Ok(Rc::new(RefCell::new(SumAccumulator::try_new(
+            &self.data_type,
+        )?)))
     }
 }
 
-macro_rules! sum_accumulate {
-    ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
-        $SELF.sum = match $SELF.sum {
-            Some(ScalarValue::$SCALAR_VARIANT(n)) => {
-                Some(ScalarValue::$SCALAR_VARIANT(n + $VALUE as $TY))
-            }
-            Some(_) => {
-                return Err(ExecutionError::InternalError(
-                    "Unexpected ScalarValue variant".to_string(),
-                ))
-            }
-            None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)),
-        };
-    }};
-}
-
 #[derive(Debug)]
 struct SumAccumulator {
-    sum: Option<ScalarValue>,
+    sum: ScalarValue,
 }
 
-impl Accumulator for SumAccumulator {
-    fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
-        if let Some(value) = value {
-            match value {
-                ScalarValue::Int8(value) => {
-                    sum_accumulate!(self, value, Int8Array, Int64, i64);
-                }
-                ScalarValue::Int16(value) => {
-                    sum_accumulate!(self, value, Int16Array, Int64, i64);
-                }
-                ScalarValue::Int32(value) => {
-                    sum_accumulate!(self, value, Int32Array, Int64, i64);
-                }
-                ScalarValue::Int64(value) => {
-                    sum_accumulate!(self, value, Int64Array, Int64, i64);
-                }
-                ScalarValue::UInt8(value) => {
-                    sum_accumulate!(self, value, UInt8Array, UInt64, u64);
-                }
-                ScalarValue::UInt16(value) => {
-                    sum_accumulate!(self, value, UInt16Array, UInt64, u64);
-                }
-                ScalarValue::UInt32(value) => {
-                    sum_accumulate!(self, value, UInt32Array, UInt64, u64);
-                }
-                ScalarValue::UInt64(value) => {
-                    sum_accumulate!(self, value, UInt64Array, UInt64, u64);
-                }
-                ScalarValue::Float32(value) => {
-                    sum_accumulate!(self, value, Float32Array, Float32, f32);
-                }
-                ScalarValue::Float64(value) => {
-                    sum_accumulate!(self, value, Float64Array, Float64, f64);
-                }
-                other => {
-                    return Err(ExecutionError::General(format!(
-                        "SUM does not support {:?}",
-                        other
-                    )))
-                }
-            }
-        }
-        Ok(())
+impl SumAccumulator {
+    /// new sum accumulator
+    pub fn try_new(data_type: &DataType) -> Result<Self> {
+        Ok(Self {
+            sum: ScalarValue::try_from(data_type)?,
+        })
     }
+}
 
-    fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
-        let sum = match array.data_type() {
-            DataType::UInt8 => {
-                match compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
-                    Some(n) => Ok(Some(ScalarValue::UInt8(n))),
-                    None => Ok(None),
-                }
+// returns the new value after sum with the new values, taking nullability into account
+macro_rules! typed_sum_accumulate {
+    ($OLD_VALUE:expr, $NEW_VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TYPE:ident) => {{
+        let array = $NEW_VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
+        let delta = compute::sum(array);
+        if $OLD_VALUE.is_none() {
+            ScalarValue::$SCALAR(delta.and_then(|e| Some(e as $TYPE)))

Review comment:
       👍 

##########
File path: rust/datafusion/src/physical_plan/planner.rs
##########
@@ -218,32 +217,18 @@ impl DefaultPhysicalPlanner {
                     .collect::<Result<Vec<_>>>()?;
                 let aggregates = aggr_expr
                     .iter()
-                    .map(|e| {
-                        tuple_err((
-                            self.create_aggregate_expr(e, &input_schema, ctx_state),
-                            e.name(&input_schema),
-                        ))
-                    })
+                    .map(|e| self.create_aggregate_expr(e, &input_schema, ctx_state))
                     .collect::<Result<Vec<_>>>()?;
 
-                let initial_aggr = HashAggregateExec::try_new(
+                let initial_aggr = Arc::new(HashAggregateExec::try_new(
                     AggregateMode::Partial,
                     groups.clone(),
                     aggregates.clone(),
                     input,
-                )?;
+                )?);
 
-                if initial_aggr.output_partitioning().partition_count() == 1 {

Review comment:
       I am not sure I understand what you are getting at with this comment

##########
File path: rust/arrow/src/compute/kernels/aggregate.rs
##########
@@ -19,9 +19,42 @@
 
 use std::ops::Add;
 
-use crate::array::{Array, PrimitiveArray};
+use crate::array::{Array, LargeStringArray, PrimitiveArray, StringArray};
 use crate::datatypes::ArrowNumericType;
 
+/// Helper macro to perform min/max of strings
+macro_rules! min_max_string_helper {
+    ($array:expr, $cmp:tt) => {{
+        let null_count = $array.null_count();
+
+        if null_count == $array.len() {
+            return None
+        }
+        let mut n = "";
+        let mut has_value = false;
+        let data = $array.data();
+
+        if null_count == 0 {

Review comment:
       So the idea here is that the code is faster if we don't have to check each element for nulls?

##########
File path: rust/datafusion/src/physical_plan/aggregates.rs
##########
@@ -103,42 +103,54 @@ pub fn create_aggregate_expr(
     fun: &AggregateFunction,
     args: &Vec<Arc<dyn PhysicalExpr>>,
     input_schema: &Schema,
+    name: String,
 ) -> Result<Arc<dyn AggregateExpr>> {
     // coerce
     let arg = coerce(args, input_schema, &signature(fun))?[0].clone();
 
+    let arg_types = args
+        .iter()
+        .map(|e| e.data_type(input_schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let return_type = return_type(&fun, &arg_types)?;
+
     Ok(match fun {
-        AggregateFunction::Count => expressions::count(arg),
-        AggregateFunction::Sum => expressions::sum(arg),
-        AggregateFunction::Min => expressions::min(arg),
-        AggregateFunction::Max => expressions::max(arg),
-        AggregateFunction::Avg => expressions::avg(arg),
+        AggregateFunction::Count => {
+            Arc::new(expressions::Count::new(arg, name, return_type))
+        }
+        AggregateFunction::Sum => Arc::new(expressions::Sum::new(arg, name, return_type)),
+        AggregateFunction::Min => Arc::new(expressions::Min::new(arg, name, return_type)),
+        AggregateFunction::Max => Arc::new(expressions::Max::new(arg, name, return_type)),
+        AggregateFunction::Avg => Arc::new(expressions::Avg::new(arg, name, return_type)),
     })
 }
 
+static NUMERICS: &'static [DataType] = &[
+    DataType::Int8,
+    DataType::Int16,
+    DataType::Int32,
+    DataType::Int64,
+    DataType::UInt8,
+    DataType::UInt16,
+    DataType::UInt32,
+    DataType::UInt64,
+    DataType::Float32,
+    DataType::Float64,
+];
+
 /// the signatures supported by the function `fun`.
 fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
-
     match fun {
         AggregateFunction::Count => Signature::Any(1),
-        AggregateFunction::Min
-        | AggregateFunction::Max
-        | AggregateFunction::Avg
-        | AggregateFunction::Sum => Signature::Uniform(
-            1,
-            vec![
-                DataType::Int8,
-                DataType::Int16,
-                DataType::Int32,
-                DataType::Int64,
-                DataType::UInt8,
-                DataType::UInt16,
-                DataType::UInt32,
-                DataType::UInt64,
-                DataType::Float32,
-                DataType::Float64,
-            ],
-        ),
+        AggregateFunction::Min | AggregateFunction::Max => {
+            let mut valid = vec![DataType::Utf8, DataType::LargeUtf8];
+            valid.extend_from_slice(NUMERICS);

Review comment:
       This is fancy -- do I correctly understand that it allows multiple columns in an aggregate expression (like `MAX(c1, c2)...`)?

##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -1835,88 +1797,18 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn sum_contract() -> Result<()> {

Review comment:
       makes sense




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org