You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@arrow.apache.org by th...@apache.org on 2023/04/19 17:28:30 UTC
[arrow-datafusion] 01/01: Add support for UDAF in physical plan serialization
This is an automated email from the ASF dual-hosted git repository.
thinkharderdev pushed a commit to branch issue-6062
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 15532a0123f5d0c9237ffc623d4ef6090c82b4ef
Author: Dan Harris <da...@thinkharder.dev>
AuthorDate: Wed Apr 19 13:28:13 2023 -0400
Add support for UDAF in physical plan serialization
---
datafusion/core/src/physical_plan/udaf.rs | 7 ++
datafusion/proto/proto/datafusion.proto | 5 +-
datafusion/proto/src/generated/pbjson.rs | 57 +++++++----
datafusion/proto/src/generated/prost.rs | 17 +++-
datafusion/proto/src/physical_plan/mod.rs | 129 ++++++++++++++++++++-----
datafusion/proto/src/physical_plan/to_proto.rs | 33 +++++--
6 files changed, 196 insertions(+), 52 deletions(-)
diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs
index cbbb851865..07e5cc3e6d 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -65,6 +65,13 @@ pub struct AggregateFunctionExpr {
name: String,
}
+impl AggregateFunctionExpr {
+ /// Return the `AggregateUDF` used by this `AggregateFunctionExpr`
+ pub fn fun(&self) -> &AggregateUDF {
+ &self.fun
+ }
+}
+
impl AggregateExpr for AggregateFunctionExpr {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 3023cbc264..7d02fda86c 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1062,7 +1062,10 @@ message PhysicalScalarUdfNode {
}
message PhysicalAggregateExprNode {
- AggregateFunction aggr_function = 1;
+ oneof AggregateFunction {
+ AggregateFunction aggr_function = 1;
+ string user_defined_aggr_function = 4;
+ }
repeated PhysicalExprNode expr = 2;
bool distinct = 3;
}
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 6a416d37c9..553f3f2911 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12661,27 +12661,34 @@ impl serde::Serialize for PhysicalAggregateExprNode {
{
use serde::ser::SerializeStruct;
let mut len = 0;
- if self.aggr_function != 0 {
- len += 1;
- }
if !self.expr.is_empty() {
len += 1;
}
if self.distinct {
len += 1;
}
- let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?;
- if self.aggr_function != 0 {
- let v = AggregateFunction::from_i32(self.aggr_function)
- .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?;
- struct_ser.serialize_field("aggrFunction", &v)?;
+ if self.aggregate_function.is_some() {
+ len += 1;
}
+ let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?;
if !self.expr.is_empty() {
struct_ser.serialize_field("expr", &self.expr)?;
}
if self.distinct {
struct_ser.serialize_field("distinct", &self.distinct)?;
}
+ if let Some(v) = self.aggregate_function.as_ref() {
+ match v {
+ physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => {
+ let v = AggregateFunction::from_i32(*v)
+ .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?;
+ struct_ser.serialize_field("aggrFunction", &v)?;
+ }
+ physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => {
+ struct_ser.serialize_field("userDefinedAggrFunction", v)?;
+ }
+ }
+ }
struct_ser.end()
}
}
@@ -12692,17 +12699,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
- "aggr_function",
- "aggrFunction",
"expr",
"distinct",
+ "aggr_function",
+ "aggrFunction",
+ "user_defined_aggr_function",
+ "userDefinedAggrFunction",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
- AggrFunction,
Expr,
Distinct,
+ AggrFunction,
+ UserDefinedAggrFunction,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -12724,9 +12734,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
E: serde::de::Error,
{
match value {
- "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction),
"expr" => Ok(GeneratedField::Expr),
"distinct" => Ok(GeneratedField::Distinct),
+ "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction),
+ "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
@@ -12746,17 +12757,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
where
V: serde::de::MapAccess<'de>,
{
- let mut aggr_function__ = None;
let mut expr__ = None;
let mut distinct__ = None;
+ let mut aggregate_function__ = None;
while let Some(k) = map.next_key()? {
match k {
- GeneratedField::AggrFunction => {
- if aggr_function__.is_some() {
- return Err(serde::de::Error::duplicate_field("aggrFunction"));
- }
- aggr_function__ = Some(map.next_value::<AggregateFunction>()? as i32);
- }
GeneratedField::Expr => {
if expr__.is_some() {
return Err(serde::de::Error::duplicate_field("expr"));
@@ -12769,12 +12774,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
}
distinct__ = Some(map.next_value()?);
}
+ GeneratedField::AggrFunction => {
+ if aggregate_function__.is_some() {
+ return Err(serde::de::Error::duplicate_field("aggrFunction"));
+ }
+ aggregate_function__ = map.next_value::<::std::option::Option<AggregateFunction>>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32));
+ }
+ GeneratedField::UserDefinedAggrFunction => {
+ if aggregate_function__.is_some() {
+ return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction"));
+ }
+ aggregate_function__ = map.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction);
+ }
}
}
Ok(PhysicalAggregateExprNode {
- aggr_function: aggr_function__.unwrap_or_default(),
expr: expr__.unwrap_or_default(),
distinct: distinct__.unwrap_or_default(),
+ aggregate_function: aggregate_function__,
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 8ec16070ee..fd3cdc1292 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1487,12 +1487,25 @@ pub struct PhysicalScalarUdfNode {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PhysicalAggregateExprNode {
- #[prost(enumeration = "AggregateFunction", tag = "1")]
- pub aggr_function: i32,
#[prost(message, repeated, tag = "2")]
pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
#[prost(bool, tag = "3")]
pub distinct: bool,
+ #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")]
+ pub aggregate_function: ::core::option::Option<
+ physical_aggregate_expr_node::AggregateFunction,
+ >,
+}
+/// Nested message and enum types in `PhysicalAggregateExprNode`.
+pub mod physical_aggregate_expr_node {
+ #[allow(clippy::derive_partial_eq_without_eq)]
+ #[derive(Clone, PartialEq, ::prost::Oneof)]
+ pub enum AggregateFunction {
+ #[prost(enumeration = "super::AggregateFunction", tag = "1")]
+ AggrFunction(i32),
+ #[prost(string, tag = "4")]
+ UserDefinedAggrFunction(::prost::alloc::string::String),
+ }
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 381073dec0..2c35428e86 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -45,7 +45,7 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
use datafusion::physical_plan::{
- AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
+ udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
};
use datafusion_common::{DataFusionError, Result};
use prost::bytes::BufMut;
@@ -56,6 +56,7 @@ use crate::common::{csv_delimiter_to_string, str_to_byte};
use crate::physical_plan::from_proto::{
parse_physical_expr, parse_protobuf_file_scan_config,
};
+use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
use crate::protobuf::physical_expr_node::ExprType;
use crate::protobuf::physical_plan_node::PhysicalPlanType;
use crate::protobuf::repartition_exec_node::PartitionMethod;
@@ -427,29 +428,38 @@ impl AsExecutionPlan for PhysicalPlanNode {
match expr_type {
ExprType::AggregateExpr(agg_node) => {
- let aggr_function =
- protobuf::AggregateFunction::from_i32(
- agg_node.aggr_function,
- )
- .ok_or_else(
- || {
- proto_error(format!(
- "Received an unknown aggregate function: {}",
- agg_node.aggr_function
- ))
- },
- )?;
-
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
.map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect();
- Ok(create_aggregate_expr(
- &aggr_function.into(),
- agg_node.distinct,
- input_phy_expr.as_slice(),
- &physical_schema,
- name.to_string(),
- )?)
+ agg_node.aggregate_function.as_ref().map(|func| {
+ match func {
+ AggregateFunction::AggrFunction(i) => {
+ let aggr_function = protobuf::AggregateFunction::from_i32(*i)
+ .ok_or_else(
+ || {
+ proto_error(format!(
+ "Received an unknown aggregate function: {}",
+ i
+ ))
+ },
+ )?;
+
+ create_aggregate_expr(
+ &aggr_function.into(),
+ agg_node.distinct,
+ input_phy_expr.as_slice(),
+ &physical_schema,
+ name.to_string(),
+ )
+ }
+ AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
+ let agg_udf = registry.udaf(udaf_name)?;
+ udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name)
+ }
+ }
+ }).transpose()?.ok_or_else(|| {
+ proto_error("Invalid AggregateExpr, missing aggregate_function")
+ })
}
_ => Err(DataFusionError::Internal(
"Invalid aggregate expression for AggregateExec"
@@ -1238,9 +1248,9 @@ mod roundtrip_tests {
use datafusion::physical_expr::ScalarFunctionExpr;
use datafusion::physical_plan::aggregates::PhysicalGroupBy;
use datafusion::physical_plan::expressions::{like, BinaryExpr, GetIndexedFieldExpr};
- use datafusion::physical_plan::functions;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::projection::ProjectionExec;
+ use datafusion::physical_plan::{functions, udaf};
use datafusion::{
arrow::{
compute::kernels::sort::SortOptions,
@@ -1264,6 +1274,10 @@ mod roundtrip_tests {
scalar::ScalarValue,
};
use datafusion_common::Result;
+ use datafusion_expr::{
+ Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
+ Signature, StateTypeFunction,
+ };
fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
let ctx = SessionContext::new();
@@ -1419,6 +1433,77 @@ mod roundtrip_tests {
)?))
}
+ #[test]
+ fn roundtrip_aggregate_udaf() -> Result<()> {
+ let field_a = Field::new("a", DataType::Int64, false);
+ let field_b = Field::new("b", DataType::Int64, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+ #[derive(Debug)]
+ struct Example;
+ impl Accumulator for Example {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(Some(0))])
+ }
+
+ fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(0)))
+ }
+
+ fn size(&self) -> usize {
+ 0
+ }
+ }
+
+ let rt_func: ReturnTypeFunction =
+ Arc::new(move |_| Ok(Arc::new(DataType::Int64)));
+ let accumulator: AccumulatorFunctionImplementation =
+ Arc::new(|_| Ok(Box::new(Example)));
+ let st_func: StateTypeFunction =
+ Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));
+
+ let udaf = AggregateUDF::new(
+ "example",
+ &Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+ &rt_func,
+ &accumulator,
+ &st_func,
+ );
+
+ let ctx = SessionContext::new();
+ ctx.register_udaf(udaf.clone());
+
+ let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+ vec![(col("a", &schema)?, "unused".to_string())];
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![udaf::create_aggregate_expr(
+ &udaf,
+ &[col("b", &schema)?],
+ &schema,
+ "example_agg",
+ )?];
+
+ roundtrip_test_with_context(
+ Arc::new(AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::new_single(groups.clone()),
+ aggregates.clone(),
+ vec![None],
+ Arc::new(EmptyExec::new(false, schema.clone())),
+ schema,
+ )?),
+ ctx,
+ )
+ }
+
#[test]
fn roundtrip_filter_with_not_and_in_list() -> Result<()> {
let field_a = Field::new("a", DataType::Boolean, false);
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index e18932575c..9495c841be 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -42,11 +42,12 @@ use datafusion::physical_plan::expressions::{
use datafusion::physical_plan::{AggregateExpr, PhysicalExpr};
use crate::protobuf;
-use crate::protobuf::{PhysicalSortExprNode, ScalarValue};
+use crate::protobuf::{physical_aggregate_expr_node, PhysicalSortExprNode, ScalarValue};
use datafusion::logical_expr::BuiltinScalarFunction;
use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr};
use datafusion::physical_expr::ScalarFunctionExpr;
use datafusion::physical_plan::joins::utils::JoinSide;
+use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion_common::{DataFusionError, Result};
impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
@@ -56,6 +57,12 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
use datafusion::physical_plan::expressions;
use protobuf::AggregateFunction;
+ let expressions: Vec<protobuf::PhysicalExprNode> = a
+ .expressions()
+ .iter()
+ .map(|e| e.clone().try_into())
+ .collect::<Result<Vec<_>>>()?;
+
let mut distinct = false;
let aggr_function = if a.as_any().downcast_ref::<Avg>().is_some() {
Ok(AggregateFunction::Avg.into())
@@ -131,19 +138,31 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
{
Ok(AggregateFunction::ApproxMedian.into())
} else {
+ if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() {
+ return Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+ protobuf::PhysicalAggregateExprNode {
+ aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())),
+ expr: expressions,
+ distinct,
+ },
+ )),
+ });
+ }
+
Err(DataFusionError::NotImplemented(format!(
"Aggregate function not supported: {a:?}"
)))
}?;
- let expressions: Vec<protobuf::PhysicalExprNode> = a
- .expressions()
- .iter()
- .map(|e| e.clone().try_into())
- .collect::<Result<Vec<_>>>()?;
+
Ok(protobuf::PhysicalExprNode {
expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
protobuf::PhysicalAggregateExprNode {
- aggr_function,
+ aggregate_function: Some(
+ physical_aggregate_expr_node::AggregateFunction::AggrFunction(
+ aggr_function,
+ ),
+ ),
expr: expressions,
distinct,
},