You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/22 14:29:24 UTC

[arrow-datafusion] branch main updated: Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new eb290a0bb9 Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)
eb290a0bb9 is described below

commit eb290a0bb93eefce7f901152d77d22e39648fb67
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jun 22 10:29:18 2023 -0400

    Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)
    
    * Update documentation for creating User Defined Aggregates (AggregateUDF)
    
    * Fix other references
---
 datafusion/core/src/lib.rs                         |   9 +-
 datafusion/core/tests/user_defined_aggregates.rs   |   6 +-
 datafusion/expr/src/accumulator.rs                 | 174 ++++++++++++++++-----
 datafusion/expr/src/expr_fn.rs                     |   4 +-
 datafusion/expr/src/function.rs                    |   2 +-
 datafusion/expr/src/lib.rs                         |   2 +-
 datafusion/expr/src/udaf.rs                        |  26 ++-
 datafusion/optimizer/src/analyzer/type_coercion.rs |   9 +-
 .../optimizer/src/common_subexpr_eliminate.rs      |   5 +-
 datafusion/proto/src/physical_plan/mod.rs          |   5 +-
 10 files changed, 171 insertions(+), 71 deletions(-)

diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 8d090b0862..0accf77399 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -17,9 +17,9 @@
 #![warn(missing_docs, clippy::needless_borrow)]
 
 //! [DataFusion] is an extensible query engine written in Rust that
-//! uses [Apache Arrow] as its in-memory format. DataFusion's [use
-//! cases] include building very fast database and analytic systems,
-//! customized to particular workloads.
+//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use
+//! cases] help developers build very fast and feature-rich database
+//! and analytic systems, customized to particular workloads.
 //!
 //! "Out of the box," DataFusion quickly runs complex [SQL] and
 //! [`DataFrame`] queries using a sophisticated query planner, a columnar,
@@ -143,8 +143,7 @@
 //! * read from any datasource ([`TableProvider`])
 //! * define your own catalogs, schemas, and table lists ([`CatalogProvider`])
 //! * build your own query langue or plans using the ([`LogicalPlanBuilder`])
-//! * declare and use user-defined scalar functions ([`ScalarUDF`])
-//! * declare and use user-defined aggregate functions ([`AggregateUDF`])
+//! * declare and use user-defined functions ([`ScalarUDF`] and [`AggregateUDF`])
 //! * add custom optimizer rewrite passes ([`OptimizerRule`] and [`PhysicalOptimizerRule`])
 //! * extend the planner to use user-defined logical and physical nodes ([`QueryPlanner`])
 //!
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 4202b9bea9..1e8299bfbd 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -33,7 +33,7 @@ use datafusion::{
     assert_batches_eq,
     error::Result,
     logical_expr::{
-        AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
+        AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
         StateTypeFunction, TypeSignature, Volatility,
     },
     physical_plan::Accumulator,
@@ -296,7 +296,7 @@ impl TimeSum {
         let signature = Signature::exact(vec![timestamp_type], volatility);
 
         let captured_state = Arc::clone(&test_state);
-        let accumulator: AccumulatorFunctionImplementation =
+        let accumulator: AccumulatorFactoryFunction =
             Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
 
         let name = "time_sum";
@@ -396,7 +396,7 @@ impl FirstSelector {
         // Possible input signatures
         let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
 
-        let accumulator: AccumulatorFunctionImplementation =
+        let accumulator: AccumulatorFactoryFunction =
             Arc::new(|_| Ok(Box::new(Self::new())));
 
         let volatility = Volatility::Immutable;
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index c448ed4235..9ac3dac5e5 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -21,49 +21,161 @@ use arrow::array::ArrayRef;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use std::fmt::Debug;
 
-/// Accumulates an aggregate's state.
+/// Describes an aggregate function's state.
 ///
-/// `Accumulator`s are stateful objects that lives throughout the
+/// `Accumulator`s are stateful objects that live throughout the
 /// evaluation of multiple rows and aggregate multiple values together
 /// into a final output aggregate.
 ///
 /// An accumulator knows how to:
-/// * update its state from inputs via `update_batch`
-/// * (optionally) retract an update to its state from given inputs via `retract_batch`
-/// * convert its internal state to a vector of aggregate values
-/// * update its state from multiple accumulators' states via `merge_batch`
-/// * compute the final value from its internal state via `evaluate`
+/// * update its state from inputs via [`update_batch`]
+///
+/// * compute the final value from its internal state via [`evaluate`]
+///
+/// * retract an update to its state from given inputs via
+/// [`retract_batch`] (when used as a window aggregate; see [window
+/// function])
+///
+/// * convert its internal state to a vector of aggregate values via
+/// [`state`] and combine the state from multiple accumulators
+/// via [`merge_batch`], as part of efficient multi-phase grouping.
+///
+/// [`update_batch`]: Self::update_batch
+/// [`retract_batch`]: Self::retract_batch
+/// [`state`]: Self::state
+/// [`evaluate`]: Self::evaluate
+/// [`merge_batch`]: Self::merge_batch
+/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
 pub trait Accumulator: Send + Sync + Debug {
-    /// Returns the partial intermediate state of the accumulator. This
-    /// partial state is serialied as `Arrays` and then combined with
-    /// other partial states from different instances of this
-    /// accumulator (that ran on different partitions, for
-    /// example).
+    /// Updates the accumulator's state from its input.
     ///
-    /// The state can be and often is a different type than the output
-    /// type of the [`Accumulator`].
+    /// `values` contains the arguments to this aggregate function.
+    ///
+    /// For example, the `SUM` accumulator maintains a running sum,
+    /// and `update_batch` adds each of the input values to the
+    /// running sum.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+
+    /// Returns the final aggregate value.
+    ///
+    /// For example, the `SUM` accumulator maintains a running sum,
+    /// and `evaluate` will produce that running sum as its output.
+    fn evaluate(&self) -> Result<ScalarValue>;
+
+    /// Returns the allocated size required for this accumulator, in
+    /// bytes, including `Self`.
     ///
-    /// See [`Self::merge_batch`] for more details on the merging process.
+    /// This value is used to calculate the memory used during
+    /// execution so DataFusion can stay within its allotted limit.
+    ///
+    /// "Allocated" means that for internal containers such as `Vec`,
+    /// the `capacity` should be used not the `len`.
+    fn size(&self) -> usize;
+
+    /// Returns the intermediate state of the accumulator.
+    ///
+    /// Intermediate state is used for "multi-phase" grouping in
+    /// DataFusion, where an aggregate is computed in parallel with
+    /// multiple `Accumulator` instances, as illustrated below:
+    ///
+    /// # MultiPhase Grouping
+    ///
+    /// ```text
+    ///                               ▲
+    ///                               │                   evaluate() is called to
+    ///                               │                   produce the final aggregate
+    ///                               │                   value per group
+    ///                               │
+    ///                  ┌─────────────────────────┐
+    ///                  │GroupBy                  │
+    ///                  │(AggregateMode::Final)   │      state() is called for each
+    ///                  │                         │      group and the resulting
+    ///                  └─────────────────────────┘      RecordBatches are passed to the
+    ///                               ▲                   Final GroupBy via merge_batch()
+    ///                               │
+    ///              ┌────────────────┴───────────────┐
+    ///              │                                │
+    ///              │                                │
+    /// ┌─────────────────────────┐      ┌─────────────────────────┐
+    /// │         GroupBy         │      │         GroupBy         │
+    /// │(AggregateMode::Partial) │      │(AggregateMode::Partial) │
+    /// └─────────────────────────┘      └────────────▲────────────┘
+    ///              ▲                                │
+    ///              │                                │    update_batch() is called for
+    ///              │                                │    each input RecordBatch
+    ///         .─────────.                      .─────────.
+    ///      ,─'           '─.                ,─'           '─.
+    ///     ;      Input      :              ;      Input      :
+    ///     :   Partition 0   ;              :   Partition 1   ;
+    ///      ╲               ╱                ╲               ╱
+    ///       '─.         ,─'                  '─.         ,─'
+    ///          `───────'                        `───────'
+    /// ```
+    ///
+    /// The partial state is serialized as `Arrays` and then combined
+    /// with other partial states from different instances of this
+    /// Accumulator (that ran on different partitions, for example).
+    ///
+    /// The state can be and often is a different type than the output
+    /// type of the [`Accumulator`] and needs different merge
+    /// operations (for example, the partial state for `COUNT` needs
+    /// to be summed together)
     ///
     /// Some accumulators can return multiple values for their
     /// intermediate states. For example average, tracks `sum` and
     ///  `n`, and this function should return
     /// a vector of two values, sum and n.
     ///
-    /// `ScalarValue::List` can also be used to pass multiple values
-    /// if the number of intermediate values is not known at planning
-    /// time (e.g. median)
+    /// Note that [`ScalarValue::List`] can be used to pass multiple
+    /// values if the number of intermediate values is not known at
+    /// planning time (e.g. for `MEDIAN`)
     fn state(&self) -> Result<Vec<ScalarValue>>;
 
-    /// Updates the accumulator's state from a vector of arrays.
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+    /// Updates the accumulator's state from an `Array` containing one
+    /// or more intermediate values.
+    ///
+    /// For some aggregates (such as `SUM`), `merge_batch` is the same
+    /// as `update_batch`, but for some aggregates (such as `COUNT`)
+    /// the operations differ. See [`Self::state`] for more details on how
+    /// state is used and merged.
+    ///
+    /// The `states` array passed was formed by concatenating the
+    /// results of calling [`Self::state`] on zero or more other
+    /// `Accumulator` instances.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
 
-    /// Retracts an update (caused by the given inputs) to
+    /// Retracts (removes) an update (caused by the given inputs) to
     /// accumulator's state.
     ///
     /// This is the inverse operation of [`Self::update_batch`] and is used
-    /// to incrementally calculate window aggregates where the OVER
+    /// to incrementally calculate window aggregates where the `OVER`
     /// clause defines a bounded window.
+    ///
+    /// # Example
+    ///
+    /// For example, given the following input partition
+    ///
+    /// ```text
+    ///                     │      current      │
+    ///                            window
+    ///                     │                   │
+    ///                ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
+    ///     Input      │ A  │ B  │ C  │ D  │ E  │ F  │ G  │ H  │ I  │
+    ///   partition    └────┴────┴────┴────┼────┴────┴────┴────┼────┘
+    ///
+    ///                                    │         next      │
+    ///                                             window
+    /// ```
+    ///
+    /// First, [`Self::evaluate`] will be called to produce the output
+    /// for the current window.
+    ///
+    /// Then, to advance to the next window:
+    ///
+    /// First, [`Self::retract_batch`] will be called with the values
+    /// that are leaving the window, `[B, C, D]` and then
+    /// [`Self::update_batch`] will be called with the values that are
+    /// entering the window, `[F, G, H]`.
     fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
         // TODO add retract for all accumulators
         Err(DataFusionError::Internal(
@@ -80,22 +192,4 @@ pub trait Accumulator: Send + Sync + Debug {
     fn supports_retract_batch(&self) -> bool {
         false
     }
-
-    /// Updates the accumulator's state from an `Array` containing one
-    /// or more intermediate values.
-    ///
-    /// The `states` array passed was formed by concatenating the
-    /// results of calling `[state]` on zero or more other accumulator
-    /// instances.
-    ///
-    /// `states`  is an array of the same types as returned by [`Self::state`]
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
-
-    /// Returns the final aggregate value based on its current state.
-    fn evaluate(&self) -> Result<ScalarValue>;
-
-    /// Allocated size required for this accumulator, in bytes, including `Self`.
-    /// Allocated means that for internal containers such as `Vec`, the `capacity` should be used
-    /// not the `len`
-    fn size(&self) -> usize;
 }
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 64d0280203..a3fa7c6ac3 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -23,7 +23,7 @@ use crate::expr::{
 };
 use crate::{
     aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
-    logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
+    logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
     BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
     ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
 };
@@ -777,7 +777,7 @@ pub fn create_udaf(
     input_type: DataType,
     return_type: Arc<DataType>,
     volatility: Volatility,
-    accumulator: AccumulatorFunctionImplementation,
+    accumulator: AccumulatorFactoryFunction,
     state_type: Arc<Vec<DataType>>,
 ) -> AggregateUDF {
     let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index bd242c493e..71f87d538d 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -42,7 +42,7 @@ pub type ReturnTypeFunction =
 
 /// Factory that returns an accumulator for the given aggregate, given
 /// its return datatype.
-pub type AccumulatorFunctionImplementation =
+pub type AccumulatorFactoryFunction =
     Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
 
 /// Factory that returns the types used by an aggregator to serialize
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index ccb9728877..caebf45a7b 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -64,7 +64,7 @@ pub use expr::{
 pub use expr_fn::*;
 pub use expr_schema::ExprSchemable;
 pub use function::{
-    AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation,
+    AccumulatorFactoryFunction, ReturnTypeFunction, ScalarFunctionImplementation,
     StateTypeFunction,
 };
 pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 1b455a0985..3cb6d1b91a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -19,20 +19,30 @@
 
 use crate::Expr;
 use crate::{
-    AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction,
+    AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
 };
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
 
-/// Logical representation of a user-defined aggregate function (UDAF).
+/// Logical representation of a user-defined [aggregate function] (UDAF).
 ///
-/// A UDAF is different from a user-defined scalar function (UDF) in
-/// that it is stateful across batches. UDAFs can be used as normal
-/// aggregate functions as well as window functions (the `OVER` clause)
+/// An aggregate function combines the values from multiple input rows
+/// into a single output "aggregate" (summary) row. It is different
+/// from a scalar function because it is stateful across batches. User
+/// defined aggregate functions can be used as normal SQL aggregate
+/// functions (`GROUP BY` clause) as well as window functions (`OVER`
+/// clause).
 ///
-/// For more information, please see [the examples]
+/// `AggregateUDF` provides DataFusion the information needed to plan
+/// and call aggregate functions, including name, type information,
+/// and a factory function to create [`Accumulator`]s, which perform the
+/// actual aggregation.
+///
+/// For more information, please see [the examples].
 ///
 /// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
+/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
+/// [`Accumulator`]: crate::Accumulator
 #[derive(Clone)]
 pub struct AggregateUDF {
     /// name
@@ -42,7 +52,7 @@ pub struct AggregateUDF {
     /// Return type
     pub return_type: ReturnTypeFunction,
     /// actual implementation
-    pub accumulator: AccumulatorFunctionImplementation,
+    pub accumulator: AccumulatorFactoryFunction,
     /// the accumulator's state's description as a function of the return type
     pub state_type: StateTypeFunction,
 }
@@ -78,7 +88,7 @@ impl AggregateUDF {
         name: &str,
         signature: &Signature,
         return_type: &ReturnTypeFunction,
-        accumulator: &AccumulatorFunctionImplementation,
+        accumulator: &AccumulatorFactoryFunction,
         state_type: &StateTypeFunction,
     ) -> Self {
         Self {
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 8f551e5d80..412abbfae6 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -783,10 +783,9 @@ mod test {
     use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue};
     use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
     use datafusion_expr::{
-        cast, col, concat, concat_ws, create_udaf, is_true,
-        AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, BinaryExpr,
-        BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Filter, Operator,
-        StateTypeFunction, Subquery,
+        cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
+        AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
+        ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery,
     };
     use datafusion_expr::{
         lit,
@@ -941,7 +940,7 @@ mod test {
             Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
         let state_type: StateTypeFunction =
             Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
-        let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
+        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
             Ok(Box::new(AvgAccumulator::try_new(
                 &DataType::Float64,
                 &DataType::Float64,
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0f63ecc2cc..42318d4181 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -773,7 +773,7 @@ mod test {
         avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
     };
     use datafusion_expr::{
-        AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
+        AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
         StateTypeFunction, Volatility,
     };
 
@@ -898,8 +898,7 @@ mod test {
             assert_eq!(inputs, &[DataType::UInt32]);
             Ok(Arc::new(DataType::UInt32))
         });
-        let accumulator: AccumulatorFunctionImplementation =
-            Arc::new(|_| unimplemented!());
+        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
         let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
         let udf_agg = |inner: Expr| {
             Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new(
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 3c14981355..904cf08f7b 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1297,7 +1297,7 @@ mod roundtrip_tests {
     };
     use datafusion_common::Result;
     use datafusion_expr::{
-        Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
+        Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction,
         Signature, StateTypeFunction,
     };
 
@@ -1484,8 +1484,7 @@ mod roundtrip_tests {
 
         let rt_func: ReturnTypeFunction =
             Arc::new(move |_| Ok(Arc::new(DataType::Int64)));
-        let accumulator: AccumulatorFunctionImplementation =
-            Arc::new(|_| Ok(Box::new(Example)));
+        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example)));
         let st_func: StateTypeFunction =
             Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));