You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ak...@apache.org on 2023/06/01 07:03:09 UTC

[arrow-datafusion] branch main updated: Bug fix, first multiple batches. Add unit test (#6503)

This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e8bce3f9bc Bug fix, first multiple batches. Add unit test (#6503)
e8bce3f9bc is described below

commit e8bce3f9bc66eaa56a6fe66f0fe9408f3792bbf7
Author: Mustafa Akur <10...@users.noreply.github.com>
AuthorDate: Thu Jun 1 10:03:01 2023 +0300

    Bug fix, first multiple batches. Add unit test (#6503)
---
 .../physical-expr/src/aggregate/first_last.rs      | 54 ++++++++++++++++++++--
 1 file changed, 51 insertions(+), 3 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index f65360c751..5dd9620ce0 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -112,25 +112,35 @@ impl PartialEq<dyn Any> for FirstValue {
 #[derive(Debug)]
 struct FirstValueAccumulator {
     first: ScalarValue,
+    // Initially `is_set` is `false`, meaning no first value has been seen yet.
+    // Once the first value is seen (`is_set = true`), `first` is no longer updated.
+    is_set: bool,
 }
 
 impl FirstValueAccumulator {
     /// Creates a new `FirstValueAccumulator` for the given `data_type`.
     pub fn try_new(data_type: &DataType) -> Result<Self> {
-        ScalarValue::try_from(data_type).map(|value| Self { first: value })
+        ScalarValue::try_from(data_type).map(|value| Self {
+            first: value,
+            is_set: false,
+        })
     }
 }
 
 impl Accumulator for FirstValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![self.first.clone()])
+        Ok(vec![
+            self.first.clone(),
+            ScalarValue::Boolean(Some(self.is_set)),
+        ])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         // If we have seen first value, we shouldn't update it
         let values = &values[0];
-        if !values.is_empty() {
+        if !values.is_empty() && !self.is_set {
             self.first = ScalarValue::try_from_array(values, 0)?;
+            self.is_set = true;
         }
         Ok(())
     }
@@ -270,3 +280,41 @@ impl Accumulator for LastValueAccumulator {
         std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + self.last.size()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator};
+    use arrow_array::{ArrayRef, Int64Array};
+    use arrow_schema::DataType;
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::Accumulator;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_first_last_value_value() -> Result<()> {
+        let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64)?;
+        let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64)?;
+        // first value in the tuple is start of the range (inclusive),
+        // second value in the tuple is end of the range (exclusive)
+        let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
+        // Create 3 ArrayRefs, one per range, e.g. from 0 to 9, 1 to 10, 2 to 12
+        let arrs = ranges
+            .into_iter()
+            .map(|(start, end)| {
+                Arc::new(Int64Array::from((start..end).collect::<Vec<_>>())) as ArrayRef
+            })
+            .collect::<Vec<_>>();
+        for arr in arrs {
+            // Once first_value is set, accumulator should remember it.
+            // It shouldn't update first_value for each new batch
+            first_accumulator.update_batch(&[arr.clone()])?;
+            // last_value should be updated for each new batch.
+            last_accumulator.update_batch(&[arr])?;
+        }
+        // First Value comes from the first value of the first batch which is 0
+        assert_eq!(first_accumulator.evaluate()?, ScalarValue::Int64(Some(0)));
+        // Last value comes from the last value of the last batch which is 12
+        assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
+        Ok(())
+    }
+}