You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/07 16:43:17 UTC

[arrow-datafusion] branch master updated: Refactor window aggregation, simplify batch processing logic (#516)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f73e79  Refactor window aggregation, simplify batch processing logic (#516)
2f73e79 is described below

commit 2f73e795d3ae68638d6509bfa02388bfa3727381
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Tue Jun 8 00:43:09 2021 +0800

    Refactor window aggregation, simplify batch processing logic (#516)
    
    * refactor sort exec stream and combine batches
    
    * refactor async function
---
 datafusion/src/physical_plan/sort.rs    |   1 -
 datafusion/src/physical_plan/windows.rs | 149 +++++++++++++++-----------------
 2 files changed, 71 insertions(+), 79 deletions(-)

diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index 7747030..437519a 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -241,7 +241,6 @@ impl SortStream {
         sort_time: Arc<SQLMetric>,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
-
         let schema = input.schema();
         tokio::spawn(async move {
             let schema = input.schema();
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 659d218..7eb1494 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -19,7 +19,7 @@
 
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
-    aggregates,
+    aggregates, common,
     expressions::{Literal, NthValue, RowNumber},
     type_coercion::coerce,
     window_functions::signature_for_built_in,
@@ -29,20 +29,18 @@ use crate::physical_plan::{
     RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr,
 };
 use crate::scalar::ScalarValue;
-use arrow::compute::concat;
 use arrow::{
-    array::{Array, ArrayRef},
+    array::ArrayRef,
     datatypes::{Field, Schema, SchemaRef},
     error::{ArrowError, Result as ArrowResult},
     record_batch::RecordBatch,
 };
 use async_trait::async_trait;
-use futures::stream::{Stream, StreamExt};
+use futures::stream::Stream;
 use futures::Future;
 use pin_project_lite::pin_project;
 use std::any::Any;
 use std::convert::TryInto;
-use std::iter;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -339,22 +337,15 @@ fn window_aggregate_batch(
     window_accumulators: &mut [WindowAccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
 ) -> Result<Vec<Option<ArrayRef>>> {
-    // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge window accumulators with the expressions' values
-
-    // 1.1
     window_accumulators
         .iter_mut()
         .zip(expressions)
         .map(|(window_acc, expr)| {
-            // 1.2
             let values = &expr
                 .iter()
-                .map(|e| e.evaluate(batch))
+                .map(|e| e.evaluate(&batch))
                 .map(|r| r.map(|v| v.into_array(batch.num_rows())))
                 .collect::<Result<Vec<_>>>()?;
-
             window_acc.scan_batch(batch.num_rows(), values)
         })
         .into_iter()
@@ -380,60 +371,50 @@ fn create_window_accumulators(
         .collect::<Result<Vec<_>>>()
 }
 
-async fn compute_window_aggregate(
-    schema: SchemaRef,
+/// Compute the window aggregate columns
+///
+/// 1. get a list of window accumulators
+/// 2. evaluate the args
+/// 3. scan args with window functions
+/// 4. concat with final aggregations
+///
+/// FIXME so far this fn does not support:
+/// 1. partition by
+/// 2. order by
+/// 3. window frame
+///
+/// which will require further work:
+/// 1. inter-partition order by using vec partition-point (https://github.com/apache/arrow-datafusion/issues/360)
+/// 2. inter-partition parallelism using one-shot channel (https://github.com/apache/arrow-datafusion/issues/299)
+/// 3. convert aggregation-based window functions to be self-contained so that: (https://github.com/apache/arrow-datafusion/issues/361)
+///    a. some can be grow-only window-accumulating
+///    b. some can be grow-and-shrink window-accumulating
+///    c. some can be based on segment tree
+fn compute_window_aggregates(
     window_expr: Vec<Arc<dyn WindowExpr>>,
-    mut input: SendableRecordBatchStream,
-) -> ArrowResult<RecordBatch> {
-    let mut window_accumulators = create_window_accumulators(&window_expr)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let expressions = window_expressions(&window_expr)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let expressions = Arc::new(expressions);
-
-    // TODO each element shall have some size hint
-    let mut accumulator: Vec<Vec<ArrayRef>> =
-        iter::repeat(vec![]).take(window_expr.len()).collect();
-
-    let mut original_batches: Vec<RecordBatch> = vec![];
-
-    let mut total_num_rows = 0;
-
-    while let Some(batch) = input.next().await {
-        let batch = batch?;
-        total_num_rows += batch.num_rows();
-        original_batches.push(batch.clone());
-
-        let batch_aggregated =
-            window_aggregate_batch(&batch, &mut window_accumulators, &expressions)
-                .map_err(DataFusionError::into_arrow_external_error)?;
-        accumulator.iter_mut().zip(batch_aggregated).for_each(
-            |(acc_for_window, window_batch)| {
-                if let Some(data) = window_batch {
-                    acc_for_window.push(data);
-                }
-            },
-        );
+    batch: &RecordBatch,
+) -> Result<Vec<ArrayRef>> {
+    let mut window_accumulators = create_window_accumulators(&window_expr)?;
+    let expressions = Arc::new(window_expressions(&window_expr)?);
+    let num_rows = batch.num_rows();
+    let window_aggregates =
+        window_aggregate_batch(batch, &mut window_accumulators, &expressions)?;
+    let final_aggregates = finalize_window_aggregation(&window_accumulators)?;
+
+    // both must equal to window_expr.len()
+    if window_aggregates.len() != final_aggregates.len() {
+        return Err(DataFusionError::Internal(
+            "Impossibly got len mismatch".to_owned(),
+        ));
     }
 
-    let aggregated_mapped = finalize_window_aggregation(&window_accumulators)
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    let mut columns: Vec<ArrayRef> = accumulator
+    window_aggregates
         .iter()
-        .zip(aggregated_mapped)
-        .map(|(acc, agg)| {
-            Ok(match (acc, agg) {
-                (acc, Some(scalar_value)) if acc.is_empty() => {
-                    scalar_value.to_array_of_size(total_num_rows)
-                }
-                (acc, None) if !acc.is_empty() => {
-                    let vec_array: Vec<&dyn Array> =
-                        acc.iter().map(|arc| arc.as_ref()).collect();
-                    concat(&vec_array)?
-                }
+        .zip(final_aggregates)
+        .map(|(wa, fa)| {
+            Ok(match (wa, fa) {
+                (None, Some(fa)) => fa.to_array_of_size(num_rows),
+                (Some(wa), None) if wa.len() == num_rows => wa.clone(),
                 _ => {
                     return Err(DataFusionError::Execution(
                         "Invalid window function behavior".to_owned(),
@@ -441,20 +422,7 @@ async fn compute_window_aggregate(
                 }
             })
         })
-        .collect::<Result<Vec<ArrayRef>>>()
-        .map_err(DataFusionError::into_arrow_external_error)?;
-
-    for i in 0..(schema.fields().len() - window_expr.len()) {
-        let col = concat(
-            &original_batches
-                .iter()
-                .map(|batch| batch.column(i).as_ref())
-                .collect::<Vec<_>>(),
-        )?;
-        columns.push(col);
-    }
-
-    RecordBatch::try_new(schema.clone(), columns)
+        .collect()
 }
 
 impl WindowAggStream {
@@ -467,7 +435,8 @@ impl WindowAggStream {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema_clone = schema.clone();
         tokio::spawn(async move {
-            let result = compute_window_aggregate(schema_clone, window_expr, input).await;
+            let schema = schema_clone.clone();
+            let result = WindowAggStream::process(input, window_expr, schema).await;
             tx.send(result)
         });
 
@@ -477,6 +446,30 @@ impl WindowAggStream {
             schema,
         }
     }
+
+    async fn process(
+        input: SendableRecordBatchStream,
+        window_expr: Vec<Arc<dyn WindowExpr>>,
+        schema: SchemaRef,
+    ) -> ArrowResult<RecordBatch> {
+        let input_schema = input.schema();
+        let batches = common::collect(input)
+            .await
+            .map_err(DataFusionError::into_arrow_external_error)?;
+        let batch = common::combine_batches(&batches, input_schema.clone())?;
+        if let Some(batch) = batch {
+            // calculate window cols
+            let mut columns = compute_window_aggregates(window_expr, &batch)
+                .map_err(DataFusionError::into_arrow_external_error)?;
+            // combine with the original cols
+            // note the setup of window aggregates is that the newly calculated window
+            // expressions are always prepended to the columns
+            columns.extend_from_slice(batch.columns());
+            RecordBatch::try_new(schema, columns)
+        } else {
+            Ok(RecordBatch::new_empty(schema))
+        }
+    }
 }
 
 impl Stream for WindowAggStream {