Posted to commits@arrow.apache.org by jo...@apache.org on 2021/05/27 04:58:06 UTC

[arrow-datafusion] branch master updated: Speed up `create_batch_from_map` (#339)

This is an automated email from the ASF dual-hosted git repository.

jorgecarleitao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e7bd2d  Speed up `create_batch_from_map` (#339)
9e7bd2d is described below

commit 9e7bd2d13643c81e474e023749998ec8efa770a4
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Thu May 27 06:57:48 2021 +0200

    Speed up `create_batch_from_map` (#339)
---
 datafusion/src/physical_plan/hash_aggregate.rs | 156 +++++++++++--------------
 datafusion/src/scalar.rs                       | 140 ++++++++++++++++++----
 2 files changed, 182 insertions(+), 114 deletions(-)
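
For context on the hash_aggregate.rs change below: the previous `create_batch_from_map` built
one single-row array per group key (and per accumulator value) and then stitched the rows
together with `compute::concat`; the new code builds each output column in a single pass over
all groups via the updated `ScalarValue::iter_to_array`. The following is a minimal sketch of
that before/after pattern, not the patch itself; the function names are illustrative only and
it assumes the arrow 4.x-era API used by this repository.

    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Int64Array};
    use arrow::compute;

    // Old shape: one single-row array per group, concatenated afterwards.
    fn per_group_then_concat(group_keys: &[i64]) -> ArrayRef {
        let singles: Vec<ArrayRef> = group_keys
            .iter()
            .map(|k| Arc::new(Int64Array::from(vec![*k])) as ArrayRef)
            .collect();
        let refs: Vec<&dyn Array> = singles.iter().map(|a| a.as_ref()).collect();
        compute::concat(&refs).expect("all inputs share one type, so concat succeeds")
    }

    // New shape: collect all group values directly into one column.
    fn single_pass(group_keys: &[i64]) -> ArrayRef {
        Arc::new(Int64Array::from(group_keys.to_vec()))
    }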

diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 5008f49..ffb51b2 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -20,6 +20,7 @@
 use std::any::Any;
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use std::vec;
 
 use ahash::RandomState;
 use futures::{
@@ -32,6 +33,7 @@ use crate::physical_plan::{
     Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan,
     Partitioning, PhysicalExpr, SQLMetric,
 };
+use crate::scalar::ScalarValue;
 
 use arrow::{
     array::{Array, UInt32Builder},
@@ -623,10 +625,12 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
             DataType::UInt64 => {
                 dictionary_create_key_for_col::<UInt64Type>(col, row, vec)?;
             }
-            _ => return Err(DataFusionError::Internal(format!(
+            _ => {
+                return Err(DataFusionError::Internal(format!(
                 "Unsupported GROUP BY type (dictionary index type not supported creating key) {}",
                 col.data_type(),
-            ))),
+            )))
+            }
         },
         _ => {
             // This is internal because we should have caught this before.
@@ -957,20 +961,6 @@ impl RecordBatchStream for HashAggregateStream {
     }
 }
 
-/// Given Vec<Vec<ArrayRef>>, concatenates the inners `Vec<ArrayRef>` into `ArrayRef`, returning `Vec<ArrayRef>`
-/// This assumes that `arrays` is not empty.
-fn concatenate(arrays: Vec<Vec<ArrayRef>>) -> ArrowResult<Vec<ArrayRef>> {
-    (0..arrays[0].len())
-        .map(|column| {
-            let array_list = arrays
-                .iter()
-                .map(|a| a[column].as_ref())
-                .collect::<Vec<_>>();
-            compute::concat(&array_list)
-        })
-        .collect::<ArrowResult<Vec<_>>>()
-}
-
 /// Create a RecordBatch with all group keys and accumulator' states or values.
 fn create_batch_from_map(
     mode: &AggregateMode,
@@ -978,84 +968,72 @@ fn create_batch_from_map(
     num_group_expr: usize,
     output_schema: &Schema,
 ) -> ArrowResult<RecordBatch> {
-    // 1. for each key
-    // 2. create single-row ArrayRef with all group expressions
-    // 3. create single-row ArrayRef with all aggregate states or values
-    // 4. collect all in a vector per key of vec<ArrayRef>, vec[i][j]
-    // 5. concatenate the arrays over the second index [j] into a single vec<ArrayRef>.
-    let arrays = accumulators
-        .iter()
-        .map(|(_, (group_by_values, accumulator_set, _))| {
-            // 2.
-            let mut groups = (0..num_group_expr)
-                .map(|i| match &group_by_values[i] {
-                    GroupByScalar::Float32(n) => {
-                        Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>))
-                            as ArrayRef
-                    }
-                    GroupByScalar::Float64(n) => {
-                        Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>))
-                            as ArrayRef
-                    }
-                    GroupByScalar::Int8(n) => {
-                        Arc::new(Int8Array::from(vec![*n])) as ArrayRef
-                    }
-                    GroupByScalar::Int16(n) => Arc::new(Int16Array::from(vec![*n])),
-                    GroupByScalar::Int32(n) => Arc::new(Int32Array::from(vec![*n])),
-                    GroupByScalar::Int64(n) => Arc::new(Int64Array::from(vec![*n])),
-                    GroupByScalar::UInt8(n) => Arc::new(UInt8Array::from(vec![*n])),
-                    GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])),
-                    GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])),
-                    GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])),
-                    GroupByScalar::Utf8(str) => {
-                        Arc::new(StringArray::from(vec![&***str]))
-                    }
-                    GroupByScalar::LargeUtf8(str) => {
-                        Arc::new(LargeStringArray::from(vec![&***str]))
-                    }
-                    GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])),
-                    GroupByScalar::TimeMillisecond(n) => {
-                        Arc::new(TimestampMillisecondArray::from(vec![*n]))
-                    }
-                    GroupByScalar::TimeMicrosecond(n) => {
-                        Arc::new(TimestampMicrosecondArray::from(vec![*n]))
-                    }
-                    GroupByScalar::TimeNanosecond(n) => {
-                        Arc::new(TimestampNanosecondArray::from_vec(vec![*n], None))
-                    }
-                    GroupByScalar::Date32(n) => Arc::new(Date32Array::from(vec![*n])),
-                })
-                .collect::<Vec<ArrayRef>>();
+    if accumulators.is_empty() {
+        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+    }
+    let (_, (_, accs, _)) = accumulators.iter().next().unwrap();
+    let mut acc_data_types: Vec<usize> = vec![];
 
-            // 3.
-            groups.extend(
-                finalize_aggregation(accumulator_set, mode)
-                    .map_err(DataFusionError::into_arrow_external_error)?,
-            );
+    // Calculate number/shape of state arrays
+    match mode {
+        AggregateMode::Partial => {
+            for acc in accs.iter() {
+                let state = acc
+                    .state()
+                    .map_err(DataFusionError::into_arrow_external_error)?;
+                acc_data_types.push(state.len());
+            }
+        }
+        AggregateMode::Final | AggregateMode::FinalPartitioned => {
+            acc_data_types = vec![1; accs.len()];
+        }
+    }
 
-            Ok(groups)
+    let mut columns = (0..num_group_expr)
+        .map(|i| {
+            ScalarValue::iter_to_array(accumulators.into_iter().map(
+                |(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]),
+            ))
         })
-        // 4.
-        .collect::<ArrowResult<Vec<Vec<ArrayRef>>>>()?;
+        .collect::<Result<Vec<_>>>()
+        .map_err(|x| x.into_arrow_external_error())?;
+
+    // add state / evaluated arrays
+    for (x, &state_len) in acc_data_types.iter().enumerate() {
+        for y in 0..state_len {
+            match mode {
+                AggregateMode::Partial => {
+                    let res = ScalarValue::iter_to_array(accumulators.into_iter().map(
+                        |(_, (_, accumulator, _))| {
+                            let x = accumulator[x].state().unwrap();
+                            x[y].clone()
+                        },
+                    ))
+                    .map_err(DataFusionError::into_arrow_external_error)?;
+
+                    columns.push(res);
+                }
+                AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                    let res = ScalarValue::iter_to_array(accumulators.into_iter().map(
+                        |(_, (_, accumulator, _))| accumulator[x].evaluate().unwrap(),
+                    ))
+                    .map_err(DataFusionError::into_arrow_external_error)?;
+                    columns.push(res);
+                }
+            }
+        }
+    }
 
-    let batch = if !arrays.is_empty() {
-        // 5.
-        let columns = concatenate(arrays)?;
+    // cast output if needed (e.g. for types like Dictionary where
+    // the intermediate GroupByScalar type was not the same as the
+    // output
+    let columns = columns
+        .iter()
+        .zip(output_schema.fields().iter())
+        .map(|(col, desired_field)| cast(col, desired_field.data_type()))
+        .collect::<ArrowResult<Vec<_>>>()?;
 
-        // cast output if needed (e.g. for types like Dictionary where
-        // the intermediate GroupByScalar type was not the same as the
-        // output
-        let columns = columns
-            .iter()
-            .zip(output_schema.fields().iter())
-            .map(|(col, desired_field)| cast(col, desired_field.data_type()))
-            .collect::<ArrowResult<Vec<_>>>()?;
-
-        RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)?
-    } else {
-        RecordBatch::new_empty(Arc::new(output_schema.to_owned()))
-    };
-    Ok(batch)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
 }
 
 fn create_accumulators(
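
The scalar.rs changes that follow enable the columnar build above: `ScalarValue::iter_to_array`
now consumes owned `ScalarValue`s instead of references, collects primitives and strings
directly into the target array type, and gains support for `List` arrays. A minimal usage
sketch of the new signature, mirroring the updated doc example further down and assuming the
datafusion and arrow crates at this commit:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use datafusion::scalar::ScalarValue;

    // iter_to_array now takes owned scalars, so call sites pass into_iter().
    let scalars = vec![
        ScalarValue::Int32(Some(1)),
        ScalarValue::Int32(None),
        ScalarValue::Int32(Some(3)),
    ];
    let array: ArrayRef = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
    let expected: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    assert_eq!(&array, &expected);
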
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index f3fa5b2..ac7deee 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -21,10 +21,10 @@ use crate::error::{DataFusionError, Result};
 use arrow::{
     array::*,
     datatypes::{
-        ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type,
-        Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit, TimestampMicrosecondType,
-        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
-        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+        ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type,
+        Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, TimeUnit,
+        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+        TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
     },
 };
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
@@ -311,7 +311,7 @@ impl ScalarValue {
     /// ];
     ///
     /// // Build an Array from the list of ScalarValues
-    /// let array = ScalarValue::iter_to_array(scalars.iter())
+    /// let array = ScalarValue::iter_to_array(scalars.into_iter())
     ///   .unwrap();
     ///
     /// let expected: ArrayRef = std::sync::Arc::new(
@@ -324,8 +324,8 @@ impl ScalarValue {
     ///
     /// assert_eq!(&array, &expected);
     /// ```
-    pub fn iter_to_array<'a>(
-        scalars: impl IntoIterator<Item = &'a ScalarValue>,
+    pub fn iter_to_array(
+        scalars: impl IntoIterator<Item = ScalarValue>,
     ) -> Result<ArrayRef> {
         let mut scalars = scalars.into_iter().peekable();
 
@@ -344,10 +344,10 @@ impl ScalarValue {
         macro_rules! build_array_primitive {
             ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
                 {
-                    let values = scalars
+                    let array = scalars
                         .map(|sv| {
                             if let ScalarValue::$SCALAR_TY(v) = sv {
-                                Ok(*v)
+                                Ok(v)
                             } else {
                                 Err(DataFusionError::Internal(format!(
                                     "Inconsistent types in ScalarValue::iter_to_array. \
@@ -356,9 +356,8 @@ impl ScalarValue {
                                 )))
                             }
                         })
-                        .collect::<Result<Vec<_>>>()?;
+                        .collect::<Result<$ARRAY_TY>>()?;
 
-                    let array: $ARRAY_TY = values.iter().collect();
                     Arc::new(array)
                 }
             }};
@@ -369,7 +368,7 @@ impl ScalarValue {
         macro_rules! build_array_string {
             ($ARRAY_TY:ident, $SCALAR_TY:ident) => {{
                 {
-                    let values = scalars
+                    let array = scalars
                         .map(|sv| {
                             if let ScalarValue::$SCALAR_TY(v) = sv {
                                 Ok(v)
@@ -381,19 +380,74 @@ impl ScalarValue {
                                 )))
                             }
                         })
-                        .collect::<Result<Vec<_>>>()?;
-
-                    // it is annoying that one can not create
-                    // StringArray et al directly from iter of &String,
-                    // requiring this map to &str
-                    let values = values.iter().map(|s| s.as_ref());
-
-                    let array: $ARRAY_TY = values.collect();
+                        .collect::<Result<$ARRAY_TY>>()?;
                     Arc::new(array)
                 }
             }};
         }
 
+        macro_rules! build_array_list_primitive {
+            ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{
+                Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
+                    scalars.into_iter().map(|x| match x {
+                        ScalarValue::List(xs, _) => xs.map(|x| {
+                            x.iter()
+                                .map(|x| match x {
+                                    ScalarValue::$SCALAR_TY(i) => *i,
+                                    sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \
+                                    Expected {:?}, got {:?}", data_type, sv),
+                                })
+                                .collect::<Vec<Option<$NATIVE_TYPE>>>()
+                        }),
+                        sv => panic!("Inconsistent types in ScalarValue::iter_to_array. \
+                        Expected {:?}, got {:?}", data_type, sv),
+        }),
+                ))
+            }};
+        }
+
+        macro_rules! build_array_list_string {
+            ($BUILDER:ident, $SCALAR_TY:ident) => {{
+                let mut builder = ListBuilder::new($BUILDER::new(0));
+
+                for scalar in scalars.into_iter() {
+                    match scalar {
+                        ScalarValue::List(Some(xs), _) => {
+                            for s in xs {
+                                match s {
+                                    ScalarValue::$SCALAR_TY(Some(val)) => {
+                                        builder.values().append_value(val)?;
+                                    }
+                                    ScalarValue::$SCALAR_TY(None) => {
+                                        builder.values().append_null()?;
+                                    }
+                                    sv => return Err(DataFusionError::Internal(format!(
+                                        "Inconsistent types in ScalarValue::iter_to_array. \
+                                         Expected Utf8, got {:?}",
+                                        sv
+                                    ))),
+                                }
+                            }
+                            builder.append(true)?;
+                        }
+                        ScalarValue::List(None, _) => {
+                            builder.append(false)?;
+                        }
+                        sv => {
+                            return Err(DataFusionError::Internal(format!(
+                                "Inconsistent types in ScalarValue::iter_to_array. \
+                             Expected List, got {:?}",
+                                sv
+                            )))
+                        }
+                    }
+                }
+
+                Arc::new(builder.finish())
+
+            }}
+        }
+
         let array: ArrayRef = match &data_type {
             DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
             DataType::Float32 => build_array_primitive!(Float32Array, Float32),
@@ -430,6 +484,42 @@ impl ScalarValue {
             DataType::Interval(IntervalUnit::YearMonth) => {
                 build_array_primitive!(IntervalYearMonthArray, IntervalYearMonth)
             }
+            DataType::List(fields) if fields.data_type() == &DataType::Int8 => {
+                build_array_list_primitive!(Int8Type, Int8, i8)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Int16 => {
+                build_array_list_primitive!(Int16Type, Int16, i16)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Int32 => {
+                build_array_list_primitive!(Int32Type, Int32, i32)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Int64 => {
+                build_array_list_primitive!(Int64Type, Int64, i64)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::UInt8 => {
+                build_array_list_primitive!(UInt8Type, UInt8, u8)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::UInt16 => {
+                build_array_list_primitive!(UInt16Type, UInt16, u16)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::UInt32 => {
+                build_array_list_primitive!(UInt32Type, UInt32, u32)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::UInt64 => {
+                build_array_list_primitive!(UInt64Type, UInt64, u64)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Float32 => {
+                build_array_list_primitive!(Float32Type, Float32, f32)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Float64 => {
+                build_array_list_primitive!(Float64Type, Float64, f64)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::Utf8 => {
+                build_array_list_string!(StringBuilder, Utf8)
+            }
+            DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => {
+                build_array_list_string!(LargeStringBuilder, LargeUtf8)
+            }
             _ => {
                 return Err(DataFusionError::Internal(format!(
                     "Unsupported creation of {:?} array from ScalarValue {:?}",
@@ -1102,7 +1192,7 @@ mod tests {
             let scalars: Vec<_> =
                 $INPUT.iter().map(|v| ScalarValue::$SCALAR_T(*v)).collect();
 
-            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+            let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
 
             let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT));
 
@@ -1119,7 +1209,7 @@ mod tests {
                 .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_string())))
                 .collect();
 
-            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+            let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
 
             let expected: ArrayRef = Arc::new($ARRAYTYPE::from($INPUT));
 
@@ -1136,7 +1226,7 @@ mod tests {
                 .map(|v| ScalarValue::$SCALAR_T(v.map(|v| v.to_vec())))
                 .collect();
 
-            let array = ScalarValue::iter_to_array(scalars.iter()).unwrap();
+            let array = ScalarValue::iter_to_array(scalars.into_iter()).unwrap();
 
             let expected: $ARRAYTYPE =
                 $INPUT.iter().map(|v| v.map(|v| v.to_vec())).collect();
@@ -1210,7 +1300,7 @@ mod tests {
     fn scalar_iter_to_array_empty() {
         let scalars = vec![] as Vec<ScalarValue>;
 
-        let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err();
+        let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err();
         assert!(
             result
                 .to_string()
@@ -1226,7 +1316,7 @@ mod tests {
         // If the scalar values are not all the correct type, error here
         let scalars: Vec<ScalarValue> = vec![Boolean(Some(true)), Int32(Some(5))];
 
-        let result = ScalarValue::iter_to_array(scalars.iter()).unwrap_err();
+        let result = ScalarValue::iter_to_array(scalars.into_iter()).unwrap_err();
         assert!(result.to_string().contains("Inconsistent types in ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"),
                 "{}", result);
     }
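
The new `build_array_list_primitive!` arm above delegates to arrow's
`ListArray::from_iter_primitive`, which takes an iterator of optional rows, each itself an
iterator of optional native values. A small standalone sketch of that arrow API, assuming the
arrow 4.x-era crate used here:

    use arrow::array::{Array, ListArray};
    use arrow::datatypes::Int32Type;

    // Each outer Option is one list row (None = null list); the inner Options
    // are the nullable elements of that row.
    let rows = vec![
        Some(vec![Some(1), Some(2), None]),
        None,
        Some(vec![Some(3)]),
    ];
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(rows);
    assert_eq!(list.len(), 3);
    assert!(list.is_null(1));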