You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/22 14:29:24 UTC
[arrow-datafusion] branch main updated: Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new eb290a0bb9 Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)
eb290a0bb9 is described below
commit eb290a0bb93eefce7f901152d77d22e39648fb67
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jun 22 10:29:18 2023 -0400
Update documentation for creating User Defined Aggregates (AggregateUDF) (#6729)
* Update documentation for creating User Defined Aggregates (AggregateUDF)
* Fix other references
---
datafusion/core/src/lib.rs | 9 +-
datafusion/core/tests/user_defined_aggregates.rs | 6 +-
datafusion/expr/src/accumulator.rs | 174 ++++++++++++++++-----
datafusion/expr/src/expr_fn.rs | 4 +-
datafusion/expr/src/function.rs | 2 +-
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/udaf.rs | 26 ++-
datafusion/optimizer/src/analyzer/type_coercion.rs | 9 +-
.../optimizer/src/common_subexpr_eliminate.rs | 5 +-
datafusion/proto/src/physical_plan/mod.rs | 5 +-
10 files changed, 171 insertions(+), 71 deletions(-)
diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 8d090b0862..0accf77399 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -17,9 +17,9 @@
#![warn(missing_docs, clippy::needless_borrow)]
//! [DataFusion] is an extensible query engine written in Rust that
-//! uses [Apache Arrow] as its in-memory format. DataFusion's [use
-//! cases] include building very fast database and analytic systems,
-//! customized to particular workloads.
+//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use
+//! cases] help developers build very fast and feature rich database
+//! and analytic systems, customized to particular workloads.
//!
//! "Out of the box," DataFusion quickly runs complex [SQL] and
//! [`DataFrame`] queries using a sophisticated query planner, a columnar,
@@ -143,8 +143,7 @@
//! * read from any datasource ([`TableProvider`])
//! * define your own catalogs, schemas, and table lists ([`CatalogProvider`])
//! * build your own query langue or plans using the ([`LogicalPlanBuilder`])
-//! * declare and use user-defined scalar functions ([`ScalarUDF`])
-//! * declare and use user-defined aggregate functions ([`AggregateUDF`])
+//! * declare and use user-defined functions ([`ScalarUDF`] and [`AggregateUDF`])
//! * add custom optimizer rewrite passes ([`OptimizerRule`] and [`PhysicalOptimizerRule`])
//! * extend the planner to use user-defined logical and physical nodes ([`QueryPlanner`])
//!
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 4202b9bea9..1e8299bfbd 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -33,7 +33,7 @@ use datafusion::{
assert_batches_eq,
error::Result,
logical_expr::{
- AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
+ AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
},
physical_plan::Accumulator,
@@ -296,7 +296,7 @@ impl TimeSum {
let signature = Signature::exact(vec![timestamp_type], volatility);
let captured_state = Arc::clone(&test_state);
- let accumulator: AccumulatorFunctionImplementation =
+ let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
let name = "time_sum";
@@ -396,7 +396,7 @@ impl FirstSelector {
// Possible input signatures
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
- let accumulator: AccumulatorFunctionImplementation =
+ let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::new(Self::new())));
let volatility = Volatility::Immutable;
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index c448ed4235..9ac3dac5e5 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -21,49 +21,161 @@ use arrow::array::ArrayRef;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::fmt::Debug;
-/// Accumulates an aggregate's state.
+/// Describes an aggregate function's state.
///
-/// `Accumulator`s are stateful objects that lives throughout the
+/// `Accumulator`s are stateful objects that live throughout the
/// evaluation of multiple rows and aggregate multiple values together
/// into a final output aggregate.
///
/// An accumulator knows how to:
-/// * update its state from inputs via `update_batch`
-/// * (optionally) retract an update to its state from given inputs via `retract_batch`
-/// * convert its internal state to a vector of aggregate values
-/// * update its state from multiple accumulators' states via `merge_batch`
-/// * compute the final value from its internal state via `evaluate`
+/// * update its state from inputs via [`update_batch`]
+///
+/// * compute the final value from its internal state via [`evaluate`]
+///
+/// * retract an update to its state from given inputs via
+/// [`retract_batch`] (when used as a window aggregate [window
+/// function])
+///
+/// * convert its internal state to a vector of aggregate values via
+/// [`state`] and combine the state from multiple accumulators
+/// via [`merge_batch`], as part of efficient multi-phase grouping.
+///
+/// [`update_batch`]: Self::update_batch
+/// [`retract_batch`]: Self::retract_batch
+/// [`state`]: Self::state
+/// [`evaluate`]: Self::evaluate
+/// [`merge_batch`]: Self::merge_batch
+/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
pub trait Accumulator: Send + Sync + Debug {
- /// Returns the partial intermediate state of the accumulator. This
- /// partial state is serialied as `Arrays` and then combined with
- /// other partial states from different instances of this
- /// accumulator (that ran on different partitions, for
- /// example).
+ /// Updates the accumulator's state from its input.
///
- /// The state can be and often is a different type than the output
- /// type of the [`Accumulator`].
+ /// `values` contains the arguments to this aggregate function.
+ ///
+ /// For example, the `SUM` accumulator maintains a running sum,
+ /// and `update_batch` adds each of the input values to the
+ /// running sum.
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+
+ /// Returns the final aggregate value.
+ ///
+ /// For example, the `SUM` accumulator maintains a running sum,
+ /// and `evaluate` will produce that running sum as its output.
+ fn evaluate(&self) -> Result<ScalarValue>;
+
+ /// Returns the allocated size required for this accumulator, in
+ /// bytes, including `Self`.
///
- /// See [`Self::merge_batch`] for more details on the merging process.
+ /// This value is used to calculate the memory used during
+ /// execution so DataFusion can stay within its allotted limit.
+ ///
+ /// "Allocated" means that for internal containers such as `Vec`,
+ /// the `capacity` should be used not the `len`.
+ fn size(&self) -> usize;
+
+ /// Returns the intermediate state of the accumulator.
+ ///
+ /// Intermediate state is used for "multi-phase" grouping in
+ /// DataFusion, where an aggregate is computed in parallel with
+ /// multiple `Accumulator` instances, as illustrated below:
+ ///
+ /// # MultiPhase Grouping
+ ///
+ /// ```text
+ /// ▲
+ /// │ evaluate() is called to
+ /// │ produce the final aggregate
+ /// │ value per group
+ /// │
+ /// ┌─────────────────────────┐
+ /// │GroupBy │
+ /// │(AggregateMode::Final) │ state() is called for each
+ /// │ │ group and the resulting
+ /// └─────────────────────────┘ RecordBatches passed to the
+ /// ▲
+ /// │
+ /// ┌────────────────┴───────────────┐
+ /// │ │
+ /// │ │
+ /// ┌─────────────────────────┐ ┌─────────────────────────┐
+/// │ GroupBy │ │ GroupBy │
+ /// │(AggregateMode::Partial) │ │(AggregateMode::Partial) │
+ /// └─────────────────────────┘ └────────────▲────────────┘
+ /// ▲ │
+ /// │ │ update_batch() is called for
+ /// │ │ each input RecordBatch
+ /// .─────────. .─────────.
+ /// ,─' '─. ,─' '─.
+ /// ; Input : ; Input :
+ /// : Partition 0 ; : Partition 1 ;
+ /// ╲ ╱ ╲ ╱
+ /// '─. ,─' '─. ,─'
+ /// `───────' `───────'
+ /// ```
+ ///
+ /// The partial state is serialized as `Arrays` and then combined
+ /// with other partial states from different instances of this
+ /// Accumulator (that ran on different partitions, for example).
+ ///
+ /// The state can be and often is a different type than the output
+ /// type of the [`Accumulator`] and needs different merge
+ /// operations (for example, the partial state for `COUNT` needs
+ /// to be summed together)
///
/// Some accumulators can return multiple values for their
/// intermediate states. For example average, tracks `sum` and
/// `n`, and this function should return
/// a vector of two values, sum and n.
///
- /// `ScalarValue::List` can also be used to pass multiple values
- /// if the number of intermediate values is not known at planning
- /// time (e.g. median)
+ /// Note that [`ScalarValue::List`] can be used to pass multiple
+ /// values if the number of intermediate values is not known at
+ /// planning time (e.g. for `MEDIAN`)
fn state(&self) -> Result<Vec<ScalarValue>>;
- /// Updates the accumulator's state from a vector of arrays.
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+ /// Updates the accumulator's state from an `Array` containing one
+ /// or more intermediate values.
+ ///
+ /// For some aggregates (such as `SUM`), `merge_batch` is the same
+ /// as `update_batch`, but for some aggregates (such as `COUNT`)
+ /// the operations differ. See [`Self::state`] for more details on how
+ /// state is used and merged.
+ ///
+ /// The `states` array passed was formed by concatenating the
+ /// results of calling [`Self::state`] on zero or more other
+ /// `Accumulator` instances.
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
- /// Retracts an update (caused by the given inputs) to
+ /// Retracts (removes) an update (caused by the given inputs) to
/// accumulator's state.
///
/// This is the inverse operation of [`Self::update_batch`] and is used
- /// to incrementally calculate window aggregates where the OVER
+ /// to incrementally calculate window aggregates where the `OVER`
/// clause defines a bounded window.
+ ///
+ /// # Example
+ ///
+ /// For example, given the following input partition
+ ///
+ /// ```text
+ /// │ current │
+ /// window
+ /// │ │
+ /// ┌────┬────┬────┬────┬────┬────┬────┬────┬────┐
+ /// Input │ A │ B │ C │ D │ E │ F │ G │ H │ I │
+ /// partition └────┴────┴────┴────┼────┴────┴────┴────┼────┘
+ ///
+ /// │ next │
+ /// window
+ /// ```
+ ///
+ /// First, [`Self::evaluate`] will be called to produce the output
+ /// for the current window.
+ ///
+ /// Then, to advance to the next window:
+ ///
+ /// First, [`Self::retract_batch`] will be called with the values
+ /// that are leaving the window, `[B, C, D]` and then
+ /// [`Self::update_batch`] will be called with the values that are
+ /// entering the window, `[F, G, H]`.
fn retract_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
// TODO add retract for all accumulators
Err(DataFusionError::Internal(
@@ -80,22 +192,4 @@ pub trait Accumulator: Send + Sync + Debug {
fn supports_retract_batch(&self) -> bool {
false
}
-
- /// Updates the accumulator's state from an `Array` containing one
- /// or more intermediate values.
- ///
- /// The `states` array passed was formed by concatenating the
- /// results of calling `[state]` on zero or more other accumulator
- /// instances.
- ///
- /// `states` is an array of the same types as returned by [`Self::state`]
- fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
-
- /// Returns the final aggregate value based on its current state.
- fn evaluate(&self) -> Result<ScalarValue>;
-
- /// Allocated size required for this accumulator, in bytes, including `Self`.
- /// Allocated means that for internal containers such as `Vec`, the `capacity` should be used
- /// not the `len`
- fn size(&self) -> usize;
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 64d0280203..a3fa7c6ac3 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -23,7 +23,7 @@ use crate::expr::{
};
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
- logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
+ logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
};
@@ -777,7 +777,7 @@ pub fn create_udaf(
input_type: DataType,
return_type: Arc<DataType>,
volatility: Volatility,
- accumulator: AccumulatorFunctionImplementation,
+ accumulator: AccumulatorFactoryFunction,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index bd242c493e..71f87d538d 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -42,7 +42,7 @@ pub type ReturnTypeFunction =
/// Factory that returns an accumulator for the given aggregate, given
/// its return datatype.
-pub type AccumulatorFunctionImplementation =
+pub type AccumulatorFactoryFunction =
Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
/// Factory that returns the types used by an aggregator to serialize
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index ccb9728877..caebf45a7b 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -64,7 +64,7 @@ pub use expr::{
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
pub use function::{
- AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation,
+ AccumulatorFactoryFunction, ReturnTypeFunction, ScalarFunctionImplementation,
StateTypeFunction,
};
pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 1b455a0985..3cb6d1b91a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -19,20 +19,30 @@
use crate::Expr;
use crate::{
- AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction,
+ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction,
};
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
-/// Logical representation of a user-defined aggregate function (UDAF).
+/// Logical representation of a user-defined [aggregate function] (UDAF).
///
-/// A UDAF is different from a user-defined scalar function (UDF) in
-/// that it is stateful across batches. UDAFs can be used as normal
-/// aggregate functions as well as window functions (the `OVER` clause)
+/// An aggregate function combines the values from multiple input rows
+/// into a single output "aggregate" (summary) row. It is different
+/// from a scalar function because it is stateful across batches. User
+/// defined aggregate functions can be used as normal SQL aggregate
+/// functions (`GROUP BY` clause) as well as window functions (`OVER`
+/// clause).
///
-/// For more information, please see [the examples]
+/// `AggregateUDF` provides DataFusion the information needed to plan
+/// and call aggregate functions, including name, type information,
+/// and a factory function to create [`Accumulator`]s, which perform the
+/// actual aggregation.
+///
+/// For more information, please see [the examples].
///
/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
+/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
+/// [`Accumulator`]: crate::Accumulator
#[derive(Clone)]
pub struct AggregateUDF {
/// name
@@ -42,7 +52,7 @@ pub struct AggregateUDF {
/// Return type
pub return_type: ReturnTypeFunction,
/// actual implementation
- pub accumulator: AccumulatorFunctionImplementation,
+ pub accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
pub state_type: StateTypeFunction,
}
@@ -78,7 +88,7 @@ impl AggregateUDF {
name: &str,
signature: &Signature,
return_type: &ReturnTypeFunction,
- accumulator: &AccumulatorFunctionImplementation,
+ accumulator: &AccumulatorFactoryFunction,
state_type: &StateTypeFunction,
) -> Self {
Self {
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 8f551e5d80..412abbfae6 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -783,10 +783,9 @@ mod test {
use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
use datafusion_expr::{
- cast, col, concat, concat_ws, create_udaf, is_true,
- AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, BinaryExpr,
- BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, Filter, Operator,
- StateTypeFunction, Subquery,
+ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
+ AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
+ ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery,
};
use datafusion_expr::{
lit,
@@ -941,7 +940,7 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
let state_type: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
- let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
+ let accumulator: AccumulatorFactoryFunction = Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0f63ecc2cc..42318d4181 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -773,7 +773,7 @@ mod test {
avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
};
use datafusion_expr::{
- AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
+ AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, Volatility,
};
@@ -898,8 +898,7 @@ mod test {
assert_eq!(inputs, &[DataType::UInt32]);
Ok(Arc::new(DataType::UInt32))
});
- let accumulator: AccumulatorFunctionImplementation =
- Arc::new(|_| unimplemented!());
+ let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
let udf_agg = |inner: Expr| {
Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new(
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 3c14981355..904cf08f7b 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1297,7 +1297,7 @@ mod roundtrip_tests {
};
use datafusion_common::Result;
use datafusion_expr::{
- Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
+ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction,
Signature, StateTypeFunction,
};
@@ -1484,8 +1484,7 @@ mod roundtrip_tests {
let rt_func: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Int64)));
- let accumulator: AccumulatorFunctionImplementation =
- Arc::new(|_| Ok(Box::new(Example)));
+ let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example)));
let st_func: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));