You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/14 21:09:24 UTC
[arrow-datafusion] branch main updated: Minor: Add tests for User Defined Aggregate functions (#6669)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8eb51089dd Minor: Add tests for User Defined Aggregate functions (#6669)
8eb51089dd is described below
commit 8eb51089ddebe6643bec1ece470415471b007b57
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Jun 14 17:09:18 2023 -0400
Minor: Add tests for User Defined Aggregate functions (#6669)
* Add more tests for User Defined Aggregate functions
* Apply suggestions from code review
Co-authored-by: Liang-Chi Hsieh <vi...@gmail.com>
---------
Co-authored-by: Liang-Chi Hsieh <vi...@gmail.com>
---
datafusion/core/tests/user_defined_aggregates.rs | 278 +++++++++++++++++++----
datafusion/expr/src/signature.rs | 6 +-
2 files changed, 238 insertions(+), 46 deletions(-)
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 1047f73df4..7c95b9a2d4 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -18,8 +18,11 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions
-use arrow::datatypes::Fields;
-use std::sync::Arc;
+use arrow::{array::AsArray, datatypes::Fields};
+use std::sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+};
use datafusion::{
arrow::{
@@ -39,10 +42,74 @@ use datafusion::{
};
use datafusion_common::cast::as_primitive_array;
+/// Basic user defined aggregate
+#[tokio::test]
+async fn test_udaf() {
+ let TestContext { ctx, counters } = TestContext::new();
+ assert!(!counters.update_batch());
+ let sql = "SELECT time_sum(time) from t";
+ let expected = vec![
+ "+----------------------------+",
+ "| time_sum(t.time) |",
+ "+----------------------------+",
+ "| 1970-01-01T00:00:00.000019 |",
+ "+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ // normal aggregates call update_batch
+ assert!(counters.update_batch());
+ assert!(!counters.retract_batch());
+}
+
+/// User defined aggregate used as a window function
#[tokio::test]
+async fn test_udaf_as_window() {
+ let TestContext { ctx, counters } = TestContext::new();
+ let sql = "SELECT time_sum(time) OVER() as time_sum from t";
+ let expected = vec![
+ "+----------------------------+",
+ "| time_sum |",
+ "+----------------------------+",
+ "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ // an aggregate evaluated over the whole window (OVER ()) calls update_batch
+ assert!(counters.update_batch());
+ assert!(!counters.retract_batch());
+}
+
+/// User defined aggregate used as a window function with a window frame
+#[tokio::test]
+async fn test_udaf_as_window_with_frame() {
+ let TestContext { ctx, counters } = TestContext::new();
+ let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
+ let expected = vec![
+ "+----------------------------+",
+ "| time_sum |",
+ "+----------------------------+",
+ "| 1970-01-01T00:00:00.000005 |",
+ "| 1970-01-01T00:00:00.000009 |",
+ "| 1970-01-01T00:00:00.000014 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "| 1970-01-01T00:00:00.000019 |",
+ "+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &execute(&ctx, sql).await);
+ // user defined aggregates with a sliding window frame should call retract_batch,
+ // but don't yet, so the expected sums above are cumulative: https://github.com/apache/arrow-datafusion/issues/6611
+ assert!(counters.update_batch());
+ assert!(!counters.retract_batch());
+}
+
+/// Basic query with a udaf returning a structure
-async fn test_udf_returning_struct() {
- let ctx = udaf_struct_context();
+#[tokio::test]
+async fn test_udaf_returning_struct() {
+ let TestContext { ctx, counters: _ } = TestContext::new();
let sql = "SELECT first(value, time) from t";
let expected = vec![
"+------------------------------------------------+",
@@ -54,10 +121,10 @@ async fn test_udf_returning_struct() {
assert_batches_eq!(expected, &execute(&ctx, sql).await);
}
+/// Demonstrate extracting the fields from a structure using a subquery
#[tokio::test]
-/// Demonstrate extracting the fields from the a structure using a subquery
-async fn test_udf_returning_struct_sq() {
- let ctx = udaf_struct_context();
+async fn test_udaf_returning_struct_subquery() {
+ let TestContext { ctx, counters: _ } = TestContext::new();
let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq";
let expected = vec![
"+-----------------+----------------------------+",
@@ -73,7 +140,8 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
ctx.sql(sql).await.unwrap().collect().await.unwrap()
}
-/// Returns an context with a table "t" and the "first" aggregate registered.
+/// Test fixture: a context with a table "t" and the "first" and "time_sum"
+/// aggregate functions registered, plus the counters that "time_sum" updates.
///
/// "t" contains this data:
///
@@ -82,56 +150,151 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
/// 3.0 | 1970-01-01T00:00:00.000003
/// 2.0 | 1970-01-01T00:00:00.000002
/// 1.0 | 1970-01-01T00:00:00.000004
+/// 5.0 | 1970-01-01T00:00:00.000005
+/// 5.0 | 1970-01-01T00:00:00.000005
/// ```
-fn udaf_struct_context() -> SessionContext {
- let value: Float64Array = vec![3.0, 2.0, 1.0].into_iter().map(Some).collect();
- let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000]);
+struct TestContext {
+ ctx: SessionContext,
+ counters: Arc<TestCounters>,
+}
+
+impl TestContext {
+ fn new() -> Self {
+ let counters = Arc::new(TestCounters::new());
+
+ let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]);
+ let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]);
+
+ let batch = RecordBatch::try_from_iter(vec![
+ ("value", Arc::new(value) as _),
+ ("time", Arc::new(time) as _),
+ ])
+ .unwrap();
- let batch = RecordBatch::try_from_iter(vec![
- ("value", Arc::new(value) as _),
- ("time", Arc::new(time) as _),
- ])
- .unwrap();
+ let mut ctx = SessionContext::new();
- let mut ctx = SessionContext::new();
- ctx.register_batch("t", batch).unwrap();
+ ctx.register_batch("t", batch).unwrap();
- // Tell datafusion about the "first" function
- register_aggregate(&mut ctx);
+ // Tell DataFusion about the "first" function
+ FirstSelector::register(&mut ctx);
+ // Tell DataFusion about the "time_sum" function
+ TimeSum::register(&mut ctx, Arc::clone(&counters));
- ctx
+ Self { ctx, counters }
+ }
+}
+
+#[derive(Debug, Default)]
+struct TestCounters {
+ /// was update_batch called?
+ update_batch: AtomicBool,
+ /// was retract_batch called?
+ retract_batch: AtomicBool,
}
-fn register_aggregate(ctx: &mut SessionContext) {
- let return_type = Arc::new(FirstSelector::output_datatype());
- let state_type = Arc::new(FirstSelector::state_datatypes());
+impl TestCounters {
+ fn new() -> Self {
+ Default::default()
+ }
+
+ /// Has `update_batch` been called?
+ fn update_batch(&self) -> bool {
+ self.update_batch.load(Ordering::SeqCst)
+ }
+
+ /// Has `retract_batch` been called?
+ fn retract_batch(&self) -> bool {
+ self.retract_batch.load(Ordering::SeqCst)
+ }
+}
+
+/// Models a user defined aggregate function that computes the sum
+/// of timestamps (not a quantity that has much real world meaning)
+#[derive(Debug)]
+struct TimeSum {
+ sum: i64,
+ counters: Arc<TestCounters>,
+}
+
+impl TimeSum {
+ fn new(counters: Arc<TestCounters>) -> Self {
+ Self { sum: 0, counters }
+ }
- let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
- let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
+ fn register(ctx: &mut SessionContext, counters: Arc<TestCounters>) {
+ let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
- // Possible input signatures
- let signatures = vec![TypeSignature::Exact(FirstSelector::input_datatypes())];
+ // Returns the same type as its input
+ let return_type = Arc::new(timestamp_type.clone());
+ let return_type: ReturnTypeFunction =
+ Arc::new(move |_| Ok(Arc::clone(&return_type)));
- let accumulator: AccumulatorFunctionImplementation =
- Arc::new(|_| Ok(Box::new(FirstSelector::new())));
+ let state_type = Arc::new(vec![timestamp_type.clone()]);
+ let state_type: StateTypeFunction =
+ Arc::new(move |_| Ok(Arc::clone(&state_type)));
- let volatility = Volatility::Immutable;
+ let volatility = Volatility::Immutable;
- let name = "first";
+ let signature = Signature::exact(vec![timestamp_type], volatility);
- let first = AggregateUDF::new(
- name,
- &Signature::one_of(signatures, volatility),
- &return_type,
- &accumulator,
- &state_type,
- );
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&counters)))));
- // register the selector as "first"
- ctx.register_udaf(first)
+ let name = "time_sum";
+
+ let time_sum =
+ AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type);
+
+ // register the aggregate as "time_sum"
+ ctx.register_udaf(time_sum)
+ }
}
-/// This structureg models a specialized timeseries aggregate function
+impl Accumulator for TimeSum {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.evaluate()?])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ self.counters.update_batch.store(true, Ordering::SeqCst);
+ assert_eq!(values.len(), 1);
+ let arr = &values[0];
+ let arr = arr.as_primitive::<TimestampNanosecondType>();
+
+ for v in arr.values().iter() {
+ self.sum += v;
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ // merge and update is the same for time sum
+ self.update_batch(states)
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
+ }
+
+ fn size(&self) -> usize {
+ // accurate size estimates are not important for this example
+ 42
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ self.counters.retract_batch.store(true, Ordering::SeqCst);
+ assert_eq!(values.len(), 1);
+ let arr = &values[0];
+ let arr = arr.as_primitive::<TimestampNanosecondType>();
+
+ for v in arr.values().iter() {
+ self.sum -= v;
+ }
+ Ok(())
+ }
+}
+
+/// Models a specialized timeseries aggregate function
/// called a "selector" in InfluxQL and Flux.
///
/// It returns the value and corresponding timestamp of the
@@ -151,6 +314,35 @@ impl FirstSelector {
}
}
+ fn register(ctx: &mut SessionContext) {
+ let return_type = Arc::new(Self::output_datatype());
+ let state_type = Arc::new(Self::state_datatypes());
+
+ let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
+ let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
+
+ // Possible input signatures
+ let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
+
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(|_| Ok(Box::new(Self::new())));
+
+ let volatility = Volatility::Immutable;
+
+ let name = "first";
+
+ let first = AggregateUDF::new(
+ name,
+ &Signature::one_of(signatures, volatility),
+ &return_type,
+ &accumulator,
+ &state_type,
+ );
+
+ // register the selector as "first"
+ ctx.register_udaf(first)
+ }
+
/// Return the schema fields
fn fields() -> Fields {
vec![
@@ -164,12 +356,10 @@ impl FirstSelector {
.into()
}
- // output data type
fn output_datatype() -> DataType {
DataType::Struct(Self::fields())
}
- // input argument data types
fn input_datatypes() -> Vec<DataType> {
vec![
DataType::Float64,
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index e4ffd74d8d..988fe7c91d 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -49,8 +49,10 @@ pub enum TypeSignature {
/// arbitrary number of arguments with arbitrary types
VariadicAny,
/// fixed number of arguments of an arbitrary but equal type out of a list of valid types
- // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
- // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
+ ///
+ /// # Examples
+ /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
+ /// 2. A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
Uniform(usize, Vec<DataType>),
/// exact number of arguments of an exact type
Exact(Vec<DataType>),