You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/14 14:53:58 UTC
[arrow-datafusion] branch master updated: Remove `AggregateState` wrapper (#4582)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 5d424ef07 Remove `AggregateState` wrapper (#4582)
5d424ef07 is described below
commit 5d424ef07e4bb4202043f6ab3d7d52248ba852e7
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Dec 14 09:53:52 2022 -0500
Remove `AggregateState` wrapper (#4582)
* Remove AggregateState wrapper
* Remove more unwrap
* Fix logical conflicts
* Remove unnecessary array
---
datafusion-examples/examples/simple_udaf.rs | 7 ++-
.../core/src/physical_plan/aggregates/hash.rs | 2 +-
datafusion/core/src/physical_plan/windows/mod.rs | 8 ++--
datafusion/core/tests/sql/udf.rs | 4 +-
datafusion/core/tests/user_defined_aggregates.rs | 9 +---
datafusion/expr/src/accumulator.rs | 51 +++++-----------------
datafusion/expr/src/lib.rs | 2 +-
.../physical-expr/src/aggregate/approx_distinct.rs | 6 +--
.../src/aggregate/approx_percentile_cont.rs | 11 ++---
.../approx_percentile_cont_with_weight.rs | 4 +-
.../physical-expr/src/aggregate/array_agg.rs | 6 +--
.../src/aggregate/array_agg_distinct.rs | 8 ++--
datafusion/physical-expr/src/aggregate/average.rs | 9 ++--
.../physical-expr/src/aggregate/correlation.rs | 16 +++----
datafusion/physical-expr/src/aggregate/count.rs | 8 ++--
.../physical-expr/src/aggregate/count_distinct.rs | 13 +++---
.../physical-expr/src/aggregate/covariance.rs | 12 ++---
datafusion/physical-expr/src/aggregate/median.rs | 6 +--
datafusion/physical-expr/src/aggregate/min_max.rs | 10 ++---
datafusion/physical-expr/src/aggregate/stddev.rs | 10 ++---
datafusion/physical-expr/src/aggregate/sum.rs | 9 ++--
.../physical-expr/src/aggregate/sum_distinct.rs | 11 +++--
datafusion/physical-expr/src/aggregate/utils.rs | 25 +++--------
datafusion/physical-expr/src/aggregate/variance.rs | 10 ++---
datafusion/proto/src/lib.rs | 4 +-
25 files changed, 97 insertions(+), 164 deletions(-)
diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs
index bb3a42b8b..f4e0d3dd9 100644
--- a/datafusion-examples/examples/simple_udaf.rs
+++ b/datafusion-examples/examples/simple_udaf.rs
@@ -21,7 +21,6 @@ use datafusion::arrow::{
array::ArrayRef, array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
};
use datafusion::from_slice::FromSlice;
-use datafusion::logical_expr::AggregateState;
use datafusion::{error::Result, physical_plan::Accumulator};
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
use datafusion_common::cast::as_float64_array;
@@ -108,10 +107,10 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.prod)),
- AggregateState::Scalar(ScalarValue::from(self.n)),
+ ScalarValue::from(self.prod),
+ ScalarValue::from(self.n),
])
}
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 4ab2a0a06..0d35f5b0d 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -519,7 +519,7 @@ fn create_batch_from_map(
accumulators.group_states.iter().map(|group_state| {
group_state.accumulator_set[x]
.state()
- .and_then(|x| x[y].as_scalar().map(|v| v.clone()))
+ .map(|x| x[y].clone())
.expect("unexpected accumulator state in hash aggregate")
}),
)?;
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 5cd0a1a9c..76d39a199 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -178,7 +178,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_primitive_array;
- use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility};
+ use datafusion_expr::{create_udaf, Accumulator, Volatility};
use futures::FutureExt;
fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
@@ -193,10 +193,8 @@ mod tests {
struct MyCount(i64);
impl Accumulator for MyCount {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
- self.0,
- )))])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(Some(self.0))])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs
index 37a869f93..b71a13cdd 100644
--- a/datafusion/core/tests/sql/udf.rs
+++ b/datafusion/core/tests/sql/udf.rs
@@ -22,7 +22,7 @@ use datafusion::{
physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
};
use datafusion_common::{cast::as_int32_array, ScalarValue};
-use datafusion_expr::{create_udaf, Accumulator, AggregateState, LogicalPlanBuilder};
+use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder};
/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -175,7 +175,7 @@ fn udaf_as_window_func() -> Result<()> {
struct MyAccumulator;
impl Accumulator for MyAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
unimplemented!()
}
diff --git a/datafusion/core/tests/user_defined_aggregates.rs b/datafusion/core/tests/user_defined_aggregates.rs
index fd8ddb832..eec424fc8 100644
--- a/datafusion/core/tests/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined_aggregates.rs
@@ -28,7 +28,6 @@ use datafusion::{
},
assert_batches_eq,
error::Result,
- logical_expr::AggregateState,
logical_expr::{
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
@@ -210,12 +209,8 @@ impl FirstSelector {
}
impl Accumulator for FirstSelector {
- fn state(&self) -> Result<Vec<AggregateState>> {
- let state = self
- .to_state()
- .into_iter()
- .map(AggregateState::Scalar)
- .collect::<Vec<_>>();
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let state = self.to_state().into_iter().collect::<Vec<_>>();
Ok(state)
}
diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs
index 5b8269ee2..4a6b7ea3d 100644
--- a/datafusion/expr/src/accumulator.rs
+++ b/datafusion/expr/src/accumulator.rs
@@ -37,14 +37,20 @@ pub trait Accumulator: Send + Sync + Debug {
/// accumulator (that ran on different partitions, for
/// example).
///
- /// The state can be a different type than the output of the
- /// [`Accumulator`]
+ /// The state can be and often is a different type than the output
+ /// type of the [`Accumulator`].
///
/// See [`merge_batch`] for more details on the merging process.
///
- /// For example, in the case of an average, for which we track `sum` and `n`,
- /// this function should return a vector of two values, sum and n.
- fn state(&self) -> Result<Vec<AggregateState>>;
+ /// Some accumulators can return multiple values for their
+ /// intermediate states. For example average, tracks `sum` and
+ /// `n`, and this function should return
+ /// a vector of two values, sum and n.
+ ///
+ /// `ScalarValue::List` can also be used to pass multiple values
+ /// if the number of intermediate values is not known at planning
+ /// time (e.g. median)
+ fn state(&self) -> Result<Vec<ScalarValue>>;
/// Updates the accumulator's state from a vector of arrays.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -80,38 +86,3 @@ pub trait Accumulator: Send + Sync + Debug {
/// not the `len`
fn size(&self) -> usize;
}
-
-/// Representation of internal accumulator state. Accumulators can potentially have a mix of
-/// scalar and array values. It may be desirable to add custom aggregator states here as well
-/// in the future (perhaps `Custom(Box<dyn Any>)`?).
-#[derive(Debug)]
-pub enum AggregateState {
- /// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
- /// values around
- Scalar(ScalarValue),
- /// Arrays can be used instead of `ScalarValue::List` and could potentially have better
- /// performance with large data sets, although this has not been verified. It also allows
- /// for use of arrow kernels with less overhead.
- Array(ArrayRef),
-}
-
-impl AggregateState {
- /// Access the aggregate state as a scalar value. An error will occur if the
- /// state is not a scalar value.
- pub fn as_scalar(&self) -> Result<&ScalarValue> {
- match &self {
- Self::Scalar(v) => Ok(v),
- _ => Err(DataFusionError::Internal(
- "AggregateState is not a scalar aggregate".to_string(),
- )),
- }
- }
-
- /// Access the aggregate state as an array value.
- pub fn to_array(&self) -> ArrayRef {
- match &self {
- Self::Scalar(v) => v.to_array(),
- Self::Array(array) => array.clone(),
- }
- }
-}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 3c18b0481..eb943cf3e 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -52,7 +52,7 @@ pub mod utils;
pub mod window_frame;
pub mod window_function;
-pub use accumulator::{Accumulator, AggregateState};
+pub use accumulator::Accumulator;
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::ColumnarValue;
diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
index a2ac7f093..698f619a9 100644
--- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_distinct.rs
@@ -30,7 +30,7 @@ use arrow::datatypes::{
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::any::Any;
use std::convert::TryFrom;
use std::convert::TryInto;
@@ -231,8 +231,8 @@ macro_rules! default_accumulator_impl {
Ok(())
}
- fn state(&self) -> Result<Vec<AggregateState>> {
- let value = AggregateState::Scalar(ScalarValue::from(&self.hll));
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let value = ScalarValue::from(&self.hll);
Ok(vec![value])
}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 3cbf9c9cd..006688a66 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -29,7 +29,7 @@ use arrow::{
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{downcast_value, ScalarValue};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::{any::Any, iter, sync::Arc};
/// APPROX_PERCENTILE_CONT aggregate expression
@@ -357,13 +357,8 @@ impl ApproxPercentileAccumulator {
}
impl Accumulator for ApproxPercentileAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(self
- .digest
- .to_scalar_state()
- .into_iter()
- .map(AggregateState::Scalar)
- .collect())
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(self.digest.to_scalar_state().into_iter().collect())
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
index 41f195f38..71fb3d242 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
@@ -26,7 +26,7 @@ use arrow::{
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::{any::Any, sync::Arc};
@@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator {
}
impl Accumulator for ApproxPercentileWithWeightAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
self.approx_percentile_cont_accumulator.state()
}
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index c56a8adb7..e436789ee 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -23,7 +23,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
@@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator {
})
}
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(self.evaluate()?)])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.evaluate()?])
}
fn evaluate(&self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index bb176e98c..a9dd3fe35 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -29,7 +29,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
#[derive(Debug)]
@@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator {
}
impl Accumulator for DistinctArrayAggAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(ScalarValue::new_list(
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::new_list(
Some(self.values.clone().into_iter().collect()),
self.datatype.clone(),
- ))])
+ )])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index da70252d4..2d2ee5a87 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -33,7 +33,7 @@ use arrow::{
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;
/// AVG aggregate expression
@@ -150,11 +150,8 @@ impl AvgAccumulator {
}
impl Accumulator for AvgAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.count)),
- AggregateState::Scalar(self.sum.clone()),
- ])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs
index a013bd43a..55389e199 100644
--- a/datafusion/physical-expr/src/aggregate/correlation.rs
+++ b/datafusion/physical-expr/src/aggregate/correlation.rs
@@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr};
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
@@ -133,14 +133,14 @@ impl CorrelationAccumulator {
}
impl Accumulator for CorrelationAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.covar.get_count())),
- AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())),
- AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())),
- AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())),
- AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())),
- AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())),
+ ScalarValue::from(self.covar.get_count()),
+ ScalarValue::from(self.covar.get_mean1()),
+ ScalarValue::from(self.stddev1.get_m2()),
+ ScalarValue::from(self.covar.get_mean2()),
+ ScalarValue::from(self.stddev2.get_m2()),
+ ScalarValue::from(self.covar.get_algo_const()),
])
}
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 4721bf8f2..1f8e03e55 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -29,7 +29,7 @@ use arrow::datatypes::DataType;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;
use crate::expressions::format_state_name;
@@ -119,10 +119,8 @@ impl CountAccumulator {
}
impl Accumulator for CountAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
- self.count,
- )))])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(Some(self.count))])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 5484c8608..06966f712 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -28,7 +28,7 @@ use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct DistinctScalarValues(Vec<ScalarValue>);
@@ -177,7 +177,7 @@ impl Accumulator for DistinctCountAccumulator {
self.merge(&v)
})
}
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
let mut cols_out = self
.state_data_types
.iter()
@@ -206,7 +206,7 @@ impl Accumulator for DistinctCountAccumulator {
)
});
- Ok(cols_out.into_iter().map(AggregateState::Scalar).collect())
+ Ok(cols_out.into_iter().collect())
}
fn evaluate(&self) -> Result<ScalarValue> {
@@ -243,7 +243,6 @@ impl Accumulator for DistinctCountAccumulator {
#[cfg(test)]
mod tests {
use super::*;
- use crate::aggregate::utils::get_accum_scalar_values;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
@@ -362,7 +361,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;
- Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
+ Ok((accum.state()?, accum.evaluate()?))
}
fn run_update(
@@ -393,7 +392,7 @@ mod tests {
accum.update_batch(&arrays)?;
- Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
+ Ok((accum.state()?, accum.evaluate()?))
}
fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
@@ -411,7 +410,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.merge_batch(arrays)?;
- Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
+ Ok((accum.state()?, accum.evaluate()?))
}
// Used trait to create associated constant for f32 and f64
diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs
index b5343059e..74a8b3c79 100644
--- a/datafusion/physical-expr/src/aggregate/covariance.rs
+++ b/datafusion/physical-expr/src/aggregate/covariance.rs
@@ -30,7 +30,7 @@ use arrow::{
};
use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use crate::aggregate::stats::StatsType;
use crate::expressions::format_state_name;
@@ -237,12 +237,12 @@ impl CovarianceAccumulator {
}
impl Accumulator for CovarianceAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.count)),
- AggregateState::Scalar(ScalarValue::from(self.mean1)),
- AggregateState::Scalar(ScalarValue::from(self.mean2)),
- AggregateState::Scalar(ScalarValue::from(self.algo_const)),
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean1),
+ ScalarValue::from(self.mean2),
+ ScalarValue::from(self.algo_const),
])
}
diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs
index abde3702f..f2c56515c 100644
--- a/datafusion/physical-expr/src/aggregate/median.rs
+++ b/datafusion/physical-expr/src/aggregate/median.rs
@@ -23,7 +23,7 @@ use arrow::array::{Array, ArrayRef, UInt32Array};
use arrow::compute::sort_to_indices;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use std::any::Any;
use std::sync::Arc;
@@ -101,10 +101,10 @@ struct MedianAccumulator {
}
impl Accumulator for MedianAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
let state =
ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone());
- Ok(vec![AggregateState::Scalar(state)])
+ Ok(vec![state])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index 73f898c42..5762dead7 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -37,7 +37,7 @@ use arrow::{
};
use datafusion_common::ScalarValue;
use datafusion_common::{downcast_value, DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::expressions::format_state_name;
@@ -564,8 +564,8 @@ impl Accumulator for MaxAccumulator {
self.update_batch(states)
}
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(self.max.clone())])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.max.clone()])
}
fn evaluate(&self) -> Result<ScalarValue> {
@@ -721,8 +721,8 @@ impl MinAccumulator {
}
impl Accumulator for MinAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![AggregateState::Scalar(self.min.clone())])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.min.clone()])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs
index 94ec2418b..18a69478b 100644
--- a/datafusion/physical-expr/src/aggregate/stddev.rs
+++ b/datafusion/physical-expr/src/aggregate/stddev.rs
@@ -27,7 +27,7 @@ use crate::{AggregateExpr, PhysicalExpr};
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression
#[derive(Debug)]
@@ -180,11 +180,11 @@ impl StddevAccumulator {
}
impl Accumulator for StddevAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.variance.get_count())),
- AggregateState::Scalar(ScalarValue::from(self.variance.get_mean())),
- AggregateState::Scalar(ScalarValue::from(self.variance.get_m2())),
+ ScalarValue::from(self.variance.get_count()),
+ ScalarValue::from(self.variance.get_mean()),
+ ScalarValue::from(self.variance.get_m2()),
])
}
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index b330455a1..e72f48ca3 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -32,7 +32,7 @@ use arrow::{
datatypes::Field,
};
use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::expressions::format_state_name;
@@ -242,11 +242,8 @@ pub(crate) fn add_to_row(
}
impl Accumulator for SumAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
- Ok(vec![
- AggregateState::Scalar(self.sum.clone()),
- AggregateState::Scalar(ScalarValue::from(self.count)),
- ])
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.sum.clone(), ScalarValue::from(self.count)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index c835419a5..b9f8759b6 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -28,7 +28,7 @@ use std::collections::HashSet;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
/// Expression for a SUM(DISTINCT) aggregation.
#[derive(Debug)]
@@ -127,7 +127,7 @@ impl DistinctSumAccumulator {
}
impl Accumulator for DistinctSumAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
let state_out = {
@@ -135,10 +135,10 @@ impl Accumulator for DistinctSumAccumulator {
self.hash_values
.iter()
.for_each(|distinct_value| distinct_values.push(distinct_value.clone()));
- vec![AggregateState::Scalar(ScalarValue::new_list(
+ vec![ScalarValue::new_list(
Some(distinct_values),
self.data_type.clone(),
- ))]
+ )]
};
Ok(state_out)
}
@@ -187,7 +187,6 @@ impl Accumulator for DistinctSumAccumulator {
#[cfg(test)]
mod tests {
use super::*;
- use crate::aggregate::utils::get_accum_scalar_values;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
use arrow::record_batch::RecordBatch;
@@ -203,7 +202,7 @@ mod tests {
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;
- Ok((get_accum_scalar_values(accum.as_ref())?, accum.evaluate()?))
+ Ok((accum.state()?, accum.evaluate()?))
}
macro_rules! generic_test_sum_distinct {
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index 1cac5b98a..a63c5e208 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -18,31 +18,16 @@
//! Utilities used in aggregates
use arrow::array::ArrayRef;
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::Result;
use datafusion_expr::Accumulator;
-/// Extract scalar values from an accumulator. This can return an error if the accumulator
-/// has any non-scalar values.
-pub fn get_accum_scalar_values(accum: &dyn Accumulator) -> Result<Vec<ScalarValue>> {
- accum
- .state()?
- .iter()
- .map(|agg| agg.as_scalar().map(|v| v.clone()))
- .collect::<Result<Vec<_>>>()
-}
-
-/// Convert scalar values from an accumulator into arrays. This can return an error if the
-/// accumulator has any non-scalar values.
+/// Convert scalar values from an accumulator into arrays.
pub fn get_accum_scalar_values_as_arrays(
accum: &dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
- accum
+ Ok(accum
.state()?
.iter()
- .map(|v| {
- v.as_scalar()
- .map(|s| vec![s.clone()])
- .and_then(ScalarValue::iter_to_array)
- })
- .collect::<Result<Vec<_>>>()
+ .map(|s| s.to_array_of_size(1))
+ .collect::<Vec<_>>())
}
diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs
index 8af810a9e..512166385 100644
--- a/datafusion/physical-expr/src/aggregate/variance.rs
+++ b/datafusion/physical-expr/src/aggregate/variance.rs
@@ -33,7 +33,7 @@ use arrow::{
use datafusion_common::downcast_value;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, AggregateState};
+use datafusion_expr::Accumulator;
/// VAR and VAR_SAMP aggregate expression
#[derive(Debug)]
@@ -211,11 +211,11 @@ impl VarianceAccumulator {
}
impl Accumulator for VarianceAccumulator {
- fn state(&self) -> Result<Vec<AggregateState>> {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
- AggregateState::Scalar(ScalarValue::from(self.count)),
- AggregateState::Scalar(ScalarValue::from(self.mean)),
- AggregateState::Scalar(ScalarValue::from(self.m2)),
+ ScalarValue::from(self.count),
+ ScalarValue::from(self.mean),
+ ScalarValue::from(self.m2),
])
}
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 7847ae068..5f5975a90 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -73,7 +73,7 @@ mod roundtrip_tests {
use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
- col, lit, Accumulator, AggregateFunction, AggregateState,
+ col, lit, Accumulator, AggregateFunction,
BuiltinScalarFunction::{Sqrt, Substr},
Expr, LogicalPlan, Operator, Volatility,
};
@@ -1209,7 +1209,7 @@ mod roundtrip_tests {
struct Dummy {}
impl Accumulator for Dummy {
- fn state(&self) -> datafusion::error::Result<Vec<AggregateState>> {
+ fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![])
}