You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/14 21:09:24 UTC

[arrow-datafusion] branch main updated: Minor: Add tests for User Defined Aggregate functions (#6669)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8eb51089dd Minor: Add tests for User Defined Aggregate functions (#6669)
8eb51089dd is described below

commit 8eb51089ddebe6643bec1ece470415471b007b57
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Jun 14 17:09:18 2023 -0400

    Minor: Add tests for User Defined Aggregate functions (#6669)
    
    * Add more tests for User Defined Aggregate functions
    
    * Apply suggestions from code review
    
    Co-authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 datafusion/core/tests/user_defined_aggregates.rs | 278 +++++++++++++++++++----
 datafusion/expr/src/signature.rs                 |   6 +-
 2 files changed, 238 insertions(+), 46 deletions(-)

diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index 1047f73df4..7c95b9a2d4 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -18,8 +18,11 @@
 //! This module contains end to end demonstrations of creating
 //! user defined aggregate functions
 
-use arrow::datatypes::Fields;
-use std::sync::Arc;
+use arrow::{array::AsArray, datatypes::Fields};
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 
 use datafusion::{
     arrow::{
@@ -39,10 +42,74 @@ use datafusion::{
 };
 use datafusion_common::cast::as_primitive_array;
 
+/// Basic user defined aggregate
+#[tokio::test]
+async fn test_udaf() {
+    let TestContext { ctx, counters } = TestContext::new();
+    assert!(!counters.update_batch());
+    let sql = "SELECT time_sum(time) from t";
+    let expected = vec![
+        "+----------------------------+",
+        "| time_sum(t.time)           |",
+        "+----------------------------+",
+        "| 1970-01-01T00:00:00.000019 |",
+        "+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    // normal aggregates call update_batch
+    assert!(counters.update_batch());
+    assert!(!counters.retract_batch());
+}
+
+/// User defined aggregate used as a window function
 #[tokio::test]
+async fn test_udaf_as_window() {
+    let TestContext { ctx, counters } = TestContext::new();
+    let sql = "SELECT time_sum(time) OVER() as time_sum from t";
+    let expected = vec![
+        "+----------------------------+",
+        "| time_sum                   |",
+        "+----------------------------+",
+        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    // aggregates over the entire window call update_batch
+    assert!(counters.update_batch());
+    assert!(!counters.retract_batch());
+}
+
+/// User defined aggregate used as a window function with a window frame
+#[tokio::test]
+async fn test_udaf_as_window_with_frame() {
+    let TestContext { ctx, counters } = TestContext::new();
+    let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
+    let expected = vec![
+        "+----------------------------+",
+        "| time_sum                   |",
+        "+----------------------------+",
+        "| 1970-01-01T00:00:00.000005 |",
+        "| 1970-01-01T00:00:00.000009 |",
+        "| 1970-01-01T00:00:00.000014 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "| 1970-01-01T00:00:00.000019 |",
+        "+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    // user defined aggregates with window frame should be calling retract batch
+    // but doesn't yet: https://github.com/apache/arrow-datafusion/issues/6611
+    assert!(counters.update_batch());
+    assert!(!counters.retract_batch());
+}
+
 /// Basic query for with a udaf returning a structure
-async fn test_udf_returning_struct() {
-    let ctx = udaf_struct_context();
+#[tokio::test]
+async fn test_udaf_returning_struct() {
+    let TestContext { ctx, counters: _ } = TestContext::new();
     let sql = "SELECT first(value, time) from t";
     let expected = vec![
         "+------------------------------------------------+",
@@ -54,10 +121,10 @@ async fn test_udf_returning_struct() {
     assert_batches_eq!(expected, &execute(&ctx, sql).await);
 }
 
+/// Demonstrate extracting the fields from a structure using a subquery
 #[tokio::test]
-/// Demonstrate extracting the fields from the a structure using a subquery
-async fn test_udf_returning_struct_sq() {
-    let ctx = udaf_struct_context();
+async fn test_udaf_returning_struct_subquery() {
+    let TestContext { ctx, counters: _ } = TestContext::new();
     let sql = "select sq.first['value'], sq.first['time'] from (SELECT first(value, time) as first from t) as sq";
     let expected = vec![
         "+-----------------+----------------------------+",
@@ -73,7 +140,8 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
     ctx.sql(sql).await.unwrap().collect().await.unwrap()
 }
 
-/// Returns an context with a table "t" and the "first" aggregate registered.
+/// Returns a context with a table "t" and the "first" and "time_sum"
+/// aggregate functions registered.
 ///
 /// "t" contains this data:
 ///
@@ -82,56 +150,151 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
 ///  3.0  | 1970-01-01T00:00:00.000003
 ///  2.0  | 1970-01-01T00:00:00.000002
 ///  1.0  | 1970-01-01T00:00:00.000004
+///  5.0  | 1970-01-01T00:00:00.000005
+///  5.0  | 1970-01-01T00:00:00.000005
 /// ```
-fn udaf_struct_context() -> SessionContext {
-    let value: Float64Array = vec![3.0, 2.0, 1.0].into_iter().map(Some).collect();
-    let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000]);
+struct TestContext {
+    ctx: SessionContext,
+    counters: Arc<TestCounters>,
+}
+
+impl TestContext {
+    fn new() -> Self {
+        let counters = Arc::new(TestCounters::new());
+
+        let value = Float64Array::from(vec![3.0, 2.0, 1.0, 5.0, 5.0]);
+        let time = TimestampNanosecondArray::from(vec![3000, 2000, 4000, 5000, 5000]);
+
+        let batch = RecordBatch::try_from_iter(vec![
+            ("value", Arc::new(value) as _),
+            ("time", Arc::new(time) as _),
+        ])
+        .unwrap();
 
-    let batch = RecordBatch::try_from_iter(vec![
-        ("value", Arc::new(value) as _),
-        ("time", Arc::new(time) as _),
-    ])
-    .unwrap();
+        let mut ctx = SessionContext::new();
 
-    let mut ctx = SessionContext::new();
-    ctx.register_batch("t", batch).unwrap();
+        ctx.register_batch("t", batch).unwrap();
 
-    // Tell datafusion about the "first" function
-    register_aggregate(&mut ctx);
+        // Tell DataFusion about the "first" function
+        FirstSelector::register(&mut ctx);
+        // Tell DataFusion about the "time_sum" function
+        TimeSum::register(&mut ctx, Arc::clone(&counters));
 
-    ctx
+        Self { ctx, counters }
+    }
+}
+
+#[derive(Debug, Default)]
+struct TestCounters {
+    /// was update_batch called?
+    update_batch: AtomicBool,
+    /// was retract_batch called?
+    retract_batch: AtomicBool,
 }
 
-fn register_aggregate(ctx: &mut SessionContext) {
-    let return_type = Arc::new(FirstSelector::output_datatype());
-    let state_type = Arc::new(FirstSelector::state_datatypes());
+impl TestCounters {
+    fn new() -> Self {
+        Default::default()
+    }
+
+    /// Has `update_batch` been called?
+    fn update_batch(&self) -> bool {
+        self.update_batch.load(Ordering::SeqCst)
+    }
+
+    /// Has `retract_batch` been called?
+    fn retract_batch(&self) -> bool {
+        self.retract_batch.load(Ordering::SeqCst)
+    }
+}
+
+/// Models a user defined aggregate function that computes the sum
+/// of timestamps (not a quantity that has much real world meaning)
+#[derive(Debug)]
+struct TimeSum {
+    sum: i64,
+    counters: Arc<TestCounters>,
+}
+
+impl TimeSum {
+    fn new(counters: Arc<TestCounters>) -> Self {
+        Self { sum: 0, counters }
+    }
 
-    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
-    let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
+    fn register(ctx: &mut SessionContext, counters: Arc<TestCounters>) {
+        let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
 
-    // Possible input signatures
-    let signatures = vec![TypeSignature::Exact(FirstSelector::input_datatypes())];
+        // Returns the same type as its input
+        let return_type = Arc::new(timestamp_type.clone());
+        let return_type: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::clone(&return_type)));
 
-    let accumulator: AccumulatorFunctionImplementation =
-        Arc::new(|_| Ok(Box::new(FirstSelector::new())));
+        let state_type = Arc::new(vec![timestamp_type.clone()]);
+        let state_type: StateTypeFunction =
+            Arc::new(move |_| Ok(Arc::clone(&state_type)));
 
-    let volatility = Volatility::Immutable;
+        let volatility = Volatility::Immutable;
 
-    let name = "first";
+        let signature = Signature::exact(vec![timestamp_type], volatility);
 
-    let first = AggregateUDF::new(
-        name,
-        &Signature::one_of(signatures, volatility),
-        &return_type,
-        &accumulator,
-        &state_type,
-    );
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&counters)))));
 
-    // register the selector as "first"
-    ctx.register_udaf(first)
+        let name = "time_sum";
+
+        let time_sum =
+            AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type);
+
+        // register the aggregate as "time_sum"
+        ctx.register_udaf(time_sum)
+    }
 }
 
-/// This structureg models a specialized timeseries aggregate function
+impl Accumulator for TimeSum {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![self.evaluate()?])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.counters.update_batch.store(true, Ordering::SeqCst);
+        assert_eq!(values.len(), 1);
+        let arr = &values[0];
+        let arr = arr.as_primitive::<TimestampNanosecondType>();
+
+        for v in arr.values().iter() {
+            self.sum += v;
+        }
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // merge and update are the same for time sum
+        self.update_batch(states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
+    }
+
+    fn size(&self) -> usize {
+        // accurate size estimates are not important for this example
+        42
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        self.counters.retract_batch.store(true, Ordering::SeqCst);
+        assert_eq!(values.len(), 1);
+        let arr = &values[0];
+        let arr = arr.as_primitive::<TimestampNanosecondType>();
+
+        for v in arr.values().iter() {
+            self.sum -= v;
+        }
+        Ok(())
+    }
+}
+
+/// Models a specialized timeseries aggregate function
 /// called a "selector" in InfluxQL and Flux.
 ///
 /// It returns the value and corresponding timestamp of the
@@ -151,6 +314,35 @@ impl FirstSelector {
         }
     }
 
+    fn register(ctx: &mut SessionContext) {
+        let return_type = Arc::new(Self::output_datatype());
+        let state_type = Arc::new(Self::state_datatypes());
+
+        let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
+        let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
+
+        // Possible input signatures
+        let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
+
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| Ok(Box::new(Self::new())));
+
+        let volatility = Volatility::Immutable;
+
+        let name = "first";
+
+        let first = AggregateUDF::new(
+            name,
+            &Signature::one_of(signatures, volatility),
+            &return_type,
+            &accumulator,
+            &state_type,
+        );
+
+        // register the selector as "first"
+        ctx.register_udaf(first)
+    }
+
     /// Return the schema fields
     fn fields() -> Fields {
         vec![
@@ -164,12 +356,10 @@ impl FirstSelector {
         .into()
     }
 
-    // output data type
     fn output_datatype() -> DataType {
         DataType::Struct(Self::fields())
     }
 
-    // input argument data types
     fn input_datatypes() -> Vec<DataType> {
         vec![
             DataType::Float64,
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index e4ffd74d8d..988fe7c91d 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -49,8 +49,10 @@ pub enum TypeSignature {
     /// arbitrary number of arguments with arbitrary types
     VariadicAny,
     /// fixed number of arguments of an arbitrary but equal type out of a list of valid types
-    // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
-    // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
+    ///
+    /// # Examples
+    /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
+    /// 2. A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
     Uniform(usize, Vec<DataType>),
     /// exact number of arguments of an exact type
     Exact(Vec<DataType>),