You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by th...@apache.org on 2023/04/19 17:28:30 UTC

[arrow-datafusion] 01/01: Add support for UDAF in physical plan serialization

This is an automated email from the ASF dual-hosted git repository.

thinkharderdev pushed a commit to branch issue-6062
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 15532a0123f5d0c9237ffc623d4ef6090c82b4ef
Author: Dan Harris <da...@thinkharder.dev>
AuthorDate: Wed Apr 19 13:28:13 2023 -0400

    Add support for user-defined aggregate functions (UDAFs) in physical plan serialization
---
 datafusion/core/src/physical_plan/udaf.rs      |   7 ++
 datafusion/proto/proto/datafusion.proto        |   5 +-
 datafusion/proto/src/generated/pbjson.rs       |  57 +++++++----
 datafusion/proto/src/generated/prost.rs        |  17 +++-
 datafusion/proto/src/physical_plan/mod.rs      | 129 ++++++++++++++++++++-----
 datafusion/proto/src/physical_plan/to_proto.rs |  33 +++++--
 6 files changed, 196 insertions(+), 52 deletions(-)

diff --git a/datafusion/core/src/physical_plan/udaf.rs b/datafusion/core/src/physical_plan/udaf.rs
index cbbb851865..07e5cc3e6d 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -65,6 +65,13 @@ pub struct AggregateFunctionExpr {
     name: String,
 }
 
+impl AggregateFunctionExpr {
+    /// Returns the [`AggregateUDF`] held by this [`AggregateFunctionExpr`]
+    pub fn fun(&self) -> &AggregateUDF {
+        &self.fun
+    }
+}
+
 impl AggregateExpr for AggregateFunctionExpr {
     /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 3023cbc264..7d02fda86c 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1062,7 +1062,10 @@ message PhysicalScalarUdfNode {
 }
 
 message PhysicalAggregateExprNode {
-  AggregateFunction aggr_function = 1;
+  oneof AggregateFunction {
+    AggregateFunction aggr_function = 1;
+    string user_defined_aggr_function = 4;
+  }
   repeated PhysicalExprNode expr = 2;
   bool distinct = 3;
 }
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 6a416d37c9..553f3f2911 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12661,27 +12661,34 @@ impl serde::Serialize for PhysicalAggregateExprNode {
     {
         use serde::ser::SerializeStruct;
         let mut len = 0;
-        if self.aggr_function != 0 {
-            len += 1;
-        }
         if !self.expr.is_empty() {
             len += 1;
         }
         if self.distinct {
             len += 1;
         }
-        let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?;
-        if self.aggr_function != 0 {
-            let v = AggregateFunction::from_i32(self.aggr_function)
-                .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?;
-            struct_ser.serialize_field("aggrFunction", &v)?;
+        if self.aggregate_function.is_some() {
+            len += 1;
         }
+        let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?;
         if !self.expr.is_empty() {
             struct_ser.serialize_field("expr", &self.expr)?;
         }
         if self.distinct {
             struct_ser.serialize_field("distinct", &self.distinct)?;
         }
+        if let Some(v) = self.aggregate_function.as_ref() {
+            match v {
+                physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => {
+                    let v = AggregateFunction::from_i32(*v)
+                        .ok_or_else(|| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?;
+                    struct_ser.serialize_field("aggrFunction", &v)?;
+                }
+                physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => {
+                    struct_ser.serialize_field("userDefinedAggrFunction", v)?;
+                }
+            }
+        }
         struct_ser.end()
     }
 }
@@ -12692,17 +12699,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
         D: serde::Deserializer<'de>,
     {
         const FIELDS: &[&str] = &[
-            "aggr_function",
-            "aggrFunction",
             "expr",
             "distinct",
+            "aggr_function",
+            "aggrFunction",
+            "user_defined_aggr_function",
+            "userDefinedAggrFunction",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
-            AggrFunction,
             Expr,
             Distinct,
+            AggrFunction,
+            UserDefinedAggrFunction,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -12724,9 +12734,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
                         E: serde::de::Error,
                     {
                         match value {
-                            "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction),
                             "expr" => Ok(GeneratedField::Expr),
                             "distinct" => Ok(GeneratedField::Distinct),
+                            "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction),
+                            "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -12746,17 +12757,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
                 where
                     V: serde::de::MapAccess<'de>,
             {
-                let mut aggr_function__ = None;
                 let mut expr__ = None;
                 let mut distinct__ = None;
+                let mut aggregate_function__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
-                        GeneratedField::AggrFunction => {
-                            if aggr_function__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("aggrFunction"));
-                            }
-                            aggr_function__ = Some(map.next_value::<AggregateFunction>()? as i32);
-                        }
                         GeneratedField::Expr => {
                             if expr__.is_some() {
                                 return Err(serde::de::Error::duplicate_field("expr"));
@@ -12769,12 +12774,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode {
                             }
                             distinct__ = Some(map.next_value()?);
                         }
+                        GeneratedField::AggrFunction => {
+                            if aggregate_function__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("aggrFunction"));
+                            }
+                            aggregate_function__ = map.next_value::<::std::option::Option<AggregateFunction>>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32));
+                        }
+                        GeneratedField::UserDefinedAggrFunction => {
+                            if aggregate_function__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction"));
+                            }
+                            aggregate_function__ = map.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction);
+                        }
                     }
                 }
                 Ok(PhysicalAggregateExprNode {
-                    aggr_function: aggr_function__.unwrap_or_default(),
                     expr: expr__.unwrap_or_default(),
                     distinct: distinct__.unwrap_or_default(),
+                    aggregate_function: aggregate_function__,
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 8ec16070ee..fd3cdc1292 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1487,12 +1487,25 @@ pub struct PhysicalScalarUdfNode {
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct PhysicalAggregateExprNode {
-    #[prost(enumeration = "AggregateFunction", tag = "1")]
-    pub aggr_function: i32,
     #[prost(message, repeated, tag = "2")]
     pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
     #[prost(bool, tag = "3")]
     pub distinct: bool,
+    #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")]
+    pub aggregate_function: ::core::option::Option<
+        physical_aggregate_expr_node::AggregateFunction,
+    >,
+}
+/// Nested message and enum types in `PhysicalAggregateExprNode`.
+pub mod physical_aggregate_expr_node {
+    #[allow(clippy::derive_partial_eq_without_eq)]
+    #[derive(Clone, PartialEq, ::prost::Oneof)]
+    pub enum AggregateFunction {
+        #[prost(enumeration = "super::AggregateFunction", tag = "1")]
+        AggrFunction(i32),
+        #[prost(string, tag = "4")]
+        UserDefinedAggrFunction(::prost::alloc::string::String),
+    }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 381073dec0..2c35428e86 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -45,7 +45,7 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge
 use datafusion::physical_plan::union::UnionExec;
 use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
 use datafusion::physical_plan::{
-    AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
+    udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
 };
 use datafusion_common::{DataFusionError, Result};
 use prost::bytes::BufMut;
@@ -56,6 +56,7 @@ use crate::common::{csv_delimiter_to_string, str_to_byte};
 use crate::physical_plan::from_proto::{
     parse_physical_expr, parse_protobuf_file_scan_config,
 };
+use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
 use crate::protobuf::physical_expr_node::ExprType;
 use crate::protobuf::physical_plan_node::PhysicalPlanType;
 use crate::protobuf::repartition_exec_node::PartitionMethod;
@@ -427,29 +428,55 @@ impl AsExecutionPlan for PhysicalPlanNode {
 
                         match expr_type {
                             ExprType::AggregateExpr(agg_node) => {
-                                let aggr_function =
-                                    protobuf::AggregateFunction::from_i32(
-                                        agg_node.aggr_function,
-                                    )
-                                        .ok_or_else(
-                                            || {
-                                                proto_error(format!(
-                                                    "Received an unknown aggregate function: {}",
-                                                    agg_node.aggr_function
-                                                ))
-                                            },
-                                        )?;
-
-                                let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
-                                    .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect();
+                                // Propagate malformed argument expressions as errors
+                                // instead of panicking while decoding the plan.
+                                let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node
+                                    .expr
+                                    .iter()
+                                    .map(|e| parse_physical_expr(e, registry, &physical_schema))
+                                    .collect::<Result<Vec<_>>>()?;
 
-                                Ok(create_aggregate_expr(
-                                    &aggr_function.into(),
-                                    agg_node.distinct,
-                                    input_phy_expr.as_slice(),
-                                    &physical_schema,
-                                    name.to_string(),
-                                )?)
+                                match agg_node.aggregate_function.as_ref() {
+                                    Some(AggregateFunction::AggrFunction(i)) => {
+                                        let aggr_function = protobuf::AggregateFunction::from_i32(*i)
+                                            .ok_or_else(|| {
+                                                proto_error(format!(
+                                                    "Received an unknown aggregate function: {}",
+                                                    i
+                                                ))
+                                            })?;
+
+                                        create_aggregate_expr(
+                                            &aggr_function.into(),
+                                            agg_node.distinct,
+                                            input_phy_expr.as_slice(),
+                                            &physical_schema,
+                                            name.to_string(),
+                                        )
+                                    }
+                                    Some(AggregateFunction::UserDefinedAggrFunction(udaf_name)) => {
+                                        // `udaf::create_aggregate_expr` takes no DISTINCT flag, so reject
+                                        // `distinct = true` instead of silently dropping it.
+                                        if agg_node.distinct {
+                                            Err(proto_error(format!(
+                                                "DISTINCT is not supported for user-defined aggregate function {udaf_name}"
+                                            )))
+                                        } else {
+                                            // UDAFs are serialized by name only; resolve the
+                                            // implementation from the function registry.
+                                            let agg_udf = registry.udaf(udaf_name)?;
+                                            udaf::create_aggregate_expr(
+                                                agg_udf.as_ref(),
+                                                &input_phy_expr,
+                                                &physical_schema,
+                                                name,
+                                            )
+                                        }
+                                    }
+                                    None => Err(proto_error(
+                                        "Invalid AggregateExpr, missing aggregate_function",
+                                    )),
+                                }
                             }
                             _ => Err(DataFusionError::Internal(
                                 "Invalid aggregate expression for AggregateExec"
@@ -1238,9 +1248,9 @@ mod roundtrip_tests {
     use datafusion::physical_expr::ScalarFunctionExpr;
     use datafusion::physical_plan::aggregates::PhysicalGroupBy;
     use datafusion::physical_plan::expressions::{like, BinaryExpr, GetIndexedFieldExpr};
-    use datafusion::physical_plan::functions;
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::physical_plan::projection::ProjectionExec;
+    use datafusion::physical_plan::{functions, udaf};
     use datafusion::{
         arrow::{
             compute::kernels::sort::SortOptions,
@@ -1264,6 +1274,10 @@ mod roundtrip_tests {
         scalar::ScalarValue,
     };
     use datafusion_common::Result;
+    use datafusion_expr::{
+        Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
+        Signature, StateTypeFunction,
+    };
 
     fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
         let ctx = SessionContext::new();
@@ -1419,6 +1433,81 @@ mod roundtrip_tests {
         )?))
     }
 
+    /// Round-trips an `AggregateExec` whose only aggregate is a UDAF.
+    ///
+    /// UDAFs are serialized by name only, so the UDAF is registered on the
+    /// `SessionContext` handed to `roundtrip_test_with_context` for decoding.
+    #[test]
+    fn roundtrip_aggregate_udaf() -> Result<()> {
+        let field_a = Field::new("a", DataType::Int64, false);
+        let field_b = Field::new("b", DataType::Int64, false);
+        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+        #[derive(Debug)]
+        struct Example;
+        impl Accumulator for Example {
+            fn state(&self) -> Result<Vec<ScalarValue>> {
+                Ok(vec![ScalarValue::Int64(Some(0))])
+            }
+
+            fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
+                Ok(())
+            }
+
+            fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
+                Ok(())
+            }
+
+            fn evaluate(&self) -> Result<ScalarValue> {
+                Ok(ScalarValue::Int64(Some(0)))
+            }
+
+            fn size(&self) -> usize {
+                0
+            }
+        }
+
+        let rt_func: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(DataType::Int64)));
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| Ok(Box::new(Example)));
+        let st_func: StateTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));
+
+        let udaf = AggregateUDF::new(
+            "example",
+            &Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+            &rt_func,
+            &accumulator,
+            &st_func,
+        );
+
+        let ctx = SessionContext::new();
+        ctx.register_udaf(udaf.clone());
+
+        let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+            vec![(col("a", &schema)?, "unused".to_string())];
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![udaf::create_aggregate_expr(
+            &udaf,
+            &[col("b", &schema)?],
+            &schema,
+            "example_agg",
+        )?];
+
+        roundtrip_test_with_context(
+            Arc::new(AggregateExec::try_new(
+                AggregateMode::Final,
+                PhysicalGroupBy::new_single(groups),
+                aggregates,
+                vec![None],
+                Arc::new(EmptyExec::new(false, schema.clone())),
+                schema,
+            )?),
+            ctx,
+        )
+    }
+
     #[test]
     fn roundtrip_filter_with_not_and_in_list() -> Result<()> {
         let field_a = Field::new("a", DataType::Boolean, false);
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index e18932575c..9495c841be 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -42,11 +42,12 @@ use datafusion::physical_plan::expressions::{
 use datafusion::physical_plan::{AggregateExpr, PhysicalExpr};
 
 use crate::protobuf;
-use crate::protobuf::{PhysicalSortExprNode, ScalarValue};
+use crate::protobuf::{physical_aggregate_expr_node, PhysicalSortExprNode, ScalarValue};
 use datafusion::logical_expr::BuiltinScalarFunction;
 use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr};
 use datafusion::physical_expr::ScalarFunctionExpr;
 use datafusion::physical_plan::joins::utils::JoinSide;
+use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 use datafusion_common::{DataFusionError, Result};
 
 impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
@@ -56,6 +57,12 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
         use datafusion::physical_plan::expressions;
         use protobuf::AggregateFunction;
 
+        let expressions: Vec<protobuf::PhysicalExprNode> = a
+            .expressions()
+            .iter()
+            .map(|e| e.clone().try_into())
+            .collect::<Result<Vec<_>>>()?;
+
         let mut distinct = false;
         let aggr_function = if a.as_any().downcast_ref::<Avg>().is_some() {
             Ok(AggregateFunction::Avg.into())
@@ -131,19 +138,31 @@ impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode {
         {
             Ok(AggregateFunction::ApproxMedian.into())
         } else {
+            if let Some(a) = a.as_any().downcast_ref::<AggregateFunctionExpr>() {
+                return Ok(protobuf::PhysicalExprNode {
+                    expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+                        protobuf::PhysicalAggregateExprNode {
+                            aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(a.fun().name.clone())),
+                            expr: expressions,
+                            distinct,
+                        },
+                    )),
+                });
+            }
+
             Err(DataFusionError::NotImplemented(format!(
                 "Aggregate function not supported: {a:?}"
             )))
         }?;
-        let expressions: Vec<protobuf::PhysicalExprNode> = a
-            .expressions()
-            .iter()
-            .map(|e| e.clone().try_into())
-            .collect::<Result<Vec<_>>>()?;
+
         Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
                 protobuf::PhysicalAggregateExprNode {
-                    aggr_function,
+                    aggregate_function: Some(
+                        physical_aggregate_expr_node::AggregateFunction::AggrFunction(
+                            aggr_function,
+                        ),
+                    ),
                     expr: expressions,
                     distinct,
                 },