Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/06 18:21:49 UTC

[arrow-datafusion] branch master updated: Serialize scalar UDFs in physical plan (#2130)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e5e8125a1 Serialize scalar UDFs in physical plan (#2130)
e5e8125a1 is described below

commit e5e8125a12cb92a8d859009e650e5e4a9afe6a47
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Wed Apr 6 19:21:45 2022 +0100

    Serialize scalar UDFs in physical plan (#2130)
    
    * Serialize scalar UDFs in physical plan
    
    * Remove unused ScalarFunction enum case
---
 ballista/rust/core/proto/ballista.proto            |   8 +
 .../core/src/serde/physical_plan/from_proto.rs     | 288 ++++++++++++---------
 ballista/rust/core/src/serde/physical_plan/mod.rs  | 160 ++++++++++--
 .../rust/core/src/serde/physical_plan/to_proto.rs  |  40 ++-
 ballista/rust/executor/src/execution_loop.rs       |   6 +-
 ballista/rust/executor/src/executor_server.rs      |   6 +-
 6 files changed, 355 insertions(+), 153 deletions(-)
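
In short, the change adds a PhysicalScalarUdfNode message to the physical expression protobuf and threads a FunctionRegistry through physical plan deserialization, so scalar UDFs referenced by a serialized plan are resolved by name on the decoding side. Below is a minimal sketch (not part of the commit) of the registration/lookup contract this relies on; it reuses the "dummy" identity UDF from the roundtrip_scalar_udf test added further down and assumes SessionContext implements FunctionRegistry, which is what that test leans on when it passes &ctx as the registry.

    use std::sync::Arc;
    use datafusion::arrow::array::ArrayRef;
    use datafusion::arrow::datatypes::DataType;
    use datafusion::execution::context::SessionContext;
    use datafusion::logical_plan::{create_udf, FunctionRegistry};
    use datafusion::physical_plan::functions::{make_scalar_function, Volatility};

    fn main() -> datafusion::error::Result<()> {
        // Identity UDF, same shape as the one in roundtrip_scalar_udf below.
        let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);
        let udf = create_udf(
            "dummy",
            vec![DataType::Int64],
            Arc::new(DataType::Int64),
            Volatility::Immutable,
            make_scalar_function(fn_impl),
        );

        let mut ctx = SessionContext::new();
        ctx.register_udf(udf);

        // Deserialization resolves UDFs by name through the FunctionRegistry,
        // so the same name must be registered wherever the plan is decoded.
        let resolved = ctx.udf("dummy")?;
        assert_eq!(resolved.name, "dummy");
        Ok(())
    }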

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index ae938c6a7..1c50db6fd 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -306,9 +306,17 @@ message PhysicalExprNode {
 
     // window expressions
     PhysicalWindowExprNode window_expr = 15;
+
+    PhysicalScalarUdfNode scalar_udf = 16;
   }
 }
 
+message PhysicalScalarUdfNode {
+  string name = 1;
+  repeated PhysicalExprNode args = 2;
+  datafusion.ArrowType return_type = 4;
+}
+
 message PhysicalAggregateExprNode {
   datafusion.AggregateFunction aggr_function = 1;
   PhysicalExprNode expr = 2;
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 13f98535e..cc7f866e9 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -18,12 +18,13 @@
 //! Serde code to convert from protocol buffers to Rust data structures.
 
 use std::convert::{TryFrom, TryInto};
+use std::ops::Deref;
 use std::sync::Arc;
 
 use crate::error::BallistaError;
 
+use crate::convert_required;
 use crate::serde::{from_proto_binary_op, proto_error, protobuf};
-use crate::{convert_box_required, convert_required};
 use chrono::{TimeZone, Utc};
 
 use datafusion::datafusion_data_access::{
@@ -31,6 +32,7 @@ use datafusion::datafusion_data_access::{
 };
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::execution::context::ExecutionProps;
+use datafusion::logical_plan::FunctionRegistry;
 
 use datafusion::physical_plan::file_format::FileScanConfig;
 
@@ -54,125 +56,174 @@ impl From<&protobuf::PhysicalColumn> for Column {
     }
 }
 
-impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
-    type Error = BallistaError;
+pub(crate) fn parse_physical_expr(
+    proto: &protobuf::PhysicalExprNode,
+    registry: &dyn FunctionRegistry,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+    let expr_type = proto
+        .expr_type
+        .as_ref()
+        .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+
+    let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
+        ExprType::Column(c) => {
+            let pcol: Column = c.into();
+            Arc::new(pcol)
+        }
+        ExprType::Literal(scalar) => {
+            Arc::new(Literal::new(convert_required!(scalar.value)?))
+        }
+        ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
+            parse_required_physical_box_expr(&binary_expr.l, registry, "left")?,
+            from_proto_binary_op(&binary_expr.op)?,
+            parse_required_physical_box_expr(&binary_expr.r, registry, "right")?,
+        )),
+        ExprType::AggregateExpr(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert aggregate expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::WindowExpr(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert window expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::Sort(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert sort expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::IsNullExpr(e) => Arc::new(IsNullExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::IsNotNullExpr(e) => Arc::new(IsNotNullExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_box_expr(
+            &e.expr, registry, "expr",
+        )?)),
+        ExprType::Negative(e) => Arc::new(NegativeExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::InList(e) => Arc::new(InListExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            e.list
+                .iter()
+                .map(|x| parse_physical_expr(x, registry))
+                .collect::<Result<Vec<_>, _>>()?,
+            e.negated,
+        )),
+        ExprType::Case(e) => Arc::new(CaseExpr::try_new(
+            e.expr
+                .as_ref()
+                .map(|e| parse_physical_expr(e.as_ref(), registry))
+                .transpose()?,
+            e.when_then_expr
+                .iter()
+                .map(|e| {
+                    Ok((
+                        parse_required_physical_expr(
+                            &e.when_expr,
+                            registry,
+                            "when_expr",
+                        )?,
+                        parse_required_physical_expr(
+                            &e.then_expr,
+                            registry,
+                            "then_expr",
+                        )?,
+                    ))
+                })
+                .collect::<Result<Vec<_>, BallistaError>>()?
+                .as_slice(),
+            e.else_expr
+                .as_ref()
+                .map(|e| parse_physical_expr(e.as_ref(), registry))
+                .transpose()?,
+        )?),
+        ExprType::Cast(e) => Arc::new(CastExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            convert_required!(e.arrow_type)?,
+            DEFAULT_DATAFUSION_CAST_OPTIONS,
+        )),
+        ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            convert_required!(e.arrow_type)?,
+        )),
+        ExprType::ScalarFunction(e) => {
+            let scalar_function = datafusion_proto::protobuf::ScalarFunction::from_i32(
+                e.fun,
+            )
+            .ok_or_else(|| {
+                proto_error(format!("Received an unknown scalar function: {}", e.fun,))
+            })?;
+
+            let args = e
+                .args
+                .iter()
+                .map(|x| parse_physical_expr(x, registry))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            // TODO Do not create new the ExecutionProps
+            let execution_props = ExecutionProps::new();
+
+            let fun_expr = functions::create_physical_fun(
+                &(&scalar_function).into(),
+                &execution_props,
+            )?;
+
+            Arc::new(ScalarFunctionExpr::new(
+                &e.name,
+                fun_expr,
+                args,
+                &convert_required!(e.return_type)?,
+            ))
+        }
+        ExprType::ScalarUdf(e) => {
+            let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun;
 
-    fn try_from(expr: &protobuf::PhysicalExprNode) -> Result<Self, Self::Error> {
-        let expr_type = expr
-            .expr_type
-            .as_ref()
-            .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+            let args = e
+                .args
+                .iter()
+                .map(|x| parse_physical_expr(x, registry))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            Arc::new(ScalarFunctionExpr::new(
+                e.name.as_str(),
+                scalar_fun,
+                args,
+                &convert_required!(e.return_type)?,
+            ))
+        }
+    };
 
-        let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
-            ExprType::Column(c) => {
-                let pcol: Column = c.into();
-                Arc::new(pcol)
-            }
-            ExprType::Literal(scalar) => {
-                Arc::new(Literal::new(convert_required!(scalar.value)?))
-            }
-            ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
-                convert_box_required!(&binary_expr.l)?,
-                from_proto_binary_op(&binary_expr.op)?,
-                convert_box_required!(&binary_expr.r)?,
-            )),
-            ExprType::AggregateExpr(_) => {
-                return Err(BallistaError::General(
-                    "Cannot convert aggregate expr node to physical expression"
-                        .to_owned(),
-                ));
-            }
-            ExprType::WindowExpr(_) => {
-                return Err(BallistaError::General(
-                    "Cannot convert window expr node to physical expression".to_owned(),
-                ));
-            }
-            ExprType::Sort(_) => {
-                return Err(BallistaError::General(
-                    "Cannot convert sort expr node to physical expression".to_owned(),
-                ));
-            }
-            ExprType::IsNullExpr(e) => {
-                Arc::new(IsNullExpr::new(convert_box_required!(e.expr)?))
-            }
-            ExprType::IsNotNullExpr(e) => {
-                Arc::new(IsNotNullExpr::new(convert_box_required!(e.expr)?))
-            }
-            ExprType::NotExpr(e) => {
-                Arc::new(NotExpr::new(convert_box_required!(e.expr)?))
-            }
-            ExprType::Negative(e) => {
-                Arc::new(NegativeExpr::new(convert_box_required!(e.expr)?))
-            }
-            ExprType::InList(e) => Arc::new(InListExpr::new(
-                convert_box_required!(e.expr)?,
-                e.list
-                    .iter()
-                    .map(|x| x.try_into())
-                    .collect::<Result<Vec<_>, _>>()?,
-                e.negated,
-            )),
-            ExprType::Case(e) => Arc::new(CaseExpr::try_new(
-                e.expr.as_ref().map(|e| e.as_ref().try_into()).transpose()?,
-                e.when_then_expr
-                    .iter()
-                    .map(|e| {
-                        Ok((
-                            convert_required!(e.when_expr)?,
-                            convert_required!(e.then_expr)?,
-                        ))
-                    })
-                    .collect::<Result<Vec<_>, BallistaError>>()?
-                    .as_slice(),
-                e.else_expr
-                    .as_ref()
-                    .map(|e| e.as_ref().try_into())
-                    .transpose()?,
-            )?),
-            ExprType::Cast(e) => Arc::new(CastExpr::new(
-                convert_box_required!(e.expr)?,
-                convert_required!(e.arrow_type)?,
-                DEFAULT_DATAFUSION_CAST_OPTIONS,
-            )),
-            ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
-                convert_box_required!(e.expr)?,
-                convert_required!(e.arrow_type)?,
-            )),
-            ExprType::ScalarFunction(e) => {
-                let scalar_function =
-                    datafusion_proto::protobuf::ScalarFunction::from_i32(e.fun)
-                        .ok_or_else(|| {
-                            proto_error(format!(
-                                "Received an unknown scalar function: {}",
-                                e.fun,
-                            ))
-                        })?;
-
-                let args = e
-                    .args
-                    .iter()
-                    .map(|x| x.try_into())
-                    .collect::<Result<Vec<_>, _>>()?;
-
-                // TODO Do not create new the ExecutionProps
-                let execution_props = ExecutionProps::new();
-
-                let fun_expr = functions::create_physical_fun(
-                    &(&scalar_function).into(),
-                    &execution_props,
-                )?;
-
-                Arc::new(ScalarFunctionExpr::new(
-                    &e.name,
-                    fun_expr,
-                    args,
-                    &convert_required!(e.return_type)?,
-                ))
-            }
-        };
+    Ok(pexpr)
+}
 
-        Ok(pexpr)
-    }
+fn parse_required_physical_box_expr(
+    expr: &Option<Box<protobuf::PhysicalExprNode>>,
+    registry: &dyn FunctionRegistry,
+    field: &str,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+    expr.as_ref()
+        .map(|e| parse_physical_expr(e.as_ref(), registry))
+        .transpose()?
+        .ok_or_else(|| {
+            BallistaError::General(format!("Missing required field {:?}", field))
+        })
+}
+
+fn parse_required_physical_expr(
+    expr: &Option<protobuf::PhysicalExprNode>,
+    registry: &dyn FunctionRegistry,
+    field: &str,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+    expr.as_ref()
+        .map(|e| parse_physical_expr(e, registry))
+        .transpose()?
+        .ok_or_else(|| {
+            BallistaError::General(format!("Missing required field {:?}", field))
+        })
 }
 
 impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction {
@@ -210,13 +261,14 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun
 
 pub fn parse_protobuf_hash_partitioning(
     partitioning: Option<&protobuf::PhysicalHashRepartition>,
+    registry: &dyn FunctionRegistry,
 ) -> Result<Option<Partitioning>, BallistaError> {
     match partitioning {
         Some(hash_part) => {
             let expr = hash_part
                 .hash_expr
                 .iter()
-                .map(|e| e.try_into())
+                .map(|e| parse_physical_expr(e, registry))
                 .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
 
             Ok(Some(Partitioning::Hash(
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 0361df677..3abb0713d 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -19,7 +19,9 @@ use crate::error::BallistaError;
 use crate::execution_plans::{
     ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
 };
-use crate::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use crate::serde::physical_plan::from_proto::{
+    parse_physical_expr, parse_protobuf_hash_partitioning,
+};
 use crate::serde::protobuf::physical_expr_node::ExprType;
 use crate::serde::protobuf::physical_plan_node::PhysicalPlanType;
 use crate::serde::protobuf::repartition_exec_node::PartitionMethod;
@@ -30,7 +32,7 @@ use crate::serde::{
     byte_to_string, proto_error, protobuf, str_to_byte, AsExecutionPlan,
     PhysicalExtensionCodec,
 };
-use crate::{convert_box_required, convert_required, into_physical_plan, into_required};
+use crate::{convert_required, into_physical_plan, into_required};
 use datafusion::arrow::compute::SortOptions;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::datafusion_data_access::object_store::local::LocalFileSystem;
@@ -112,7 +114,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     .expr
                     .iter()
                     .zip(projection.expr_name.iter())
-                    .map(|(expr, name)| Ok((expr.try_into()?, name.to_string())))
+                    .map(|(expr, name)| Ok((parse_physical_expr(expr,registry)?, name.to_string())))
                     .collect::<Result<Vec<(Arc<dyn PhysicalExpr>, String)>, BallistaError>>(
                     )?;
                 Ok(Arc::new(ProjectionExec::try_new(exprs, input)?))
@@ -127,13 +129,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 let predicate = filter
                     .expr
                     .as_ref()
+                    .map(|expr| parse_physical_expr(expr, registry))
+                    .transpose()?
                     .ok_or_else(|| {
                         BallistaError::General(
                             "filter (FilterExecNode) in PhysicalPlanNode is missing."
                                 .to_owned(),
                         )
-                    })?
-                    .try_into()?;
+                    })?;
                 Ok(Arc::new(FilterExec::try_new(predicate, input)?))
             }
             PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new(
@@ -184,7 +187,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         let expr = hash_part
                             .hash_expr
                             .iter()
-                            .map(|e| e.try_into())
+                            .map(|e| parse_physical_expr(e, registry))
                             .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
 
                         Ok(Arc::new(RepartitionExec::try_new(
@@ -255,15 +258,29 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         })?;
 
                         match expr_type {
-                            ExprType::WindowExpr(window_node) => Ok(create_window_expr(
-                                &convert_required!(window_node.window_function)?,
-                                name.to_owned(),
-                                &[convert_box_required!(window_node.expr)?],
-                                &[],
-                                &[],
-                                Some(WindowFrame::default()),
-                                &physical_schema,
-                            )?),
+                            ExprType::WindowExpr(window_node) => {
+                                let window_node_expr = window_node
+                                    .expr
+                                    .as_ref()
+                                    .map(|e| parse_physical_expr(e.as_ref(), registry))
+                                    .transpose()?
+                                    .ok_or_else(|| {
+                                        proto_error(
+                                            "missing window_node expr expression"
+                                                .to_string(),
+                                        )
+                                    })?;
+
+                                Ok(create_window_expr(
+                                    &convert_required!(window_node.window_function)?,
+                                    name.to_owned(),
+                                    &[window_node_expr],
+                                    &[],
+                                    &[],
+                                    Some(WindowFrame::default()),
+                                    &physical_schema,
+                                )?)
+                            }
                             _ => Err(BallistaError::General(
                                 "Invalid expression for WindowAggrExec".to_string(),
                             )),
@@ -302,7 +319,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     .iter()
                     .zip(hash_agg.group_expr_name.iter())
                     .map(|(expr, name)| {
-                        expr.try_into().map(|expr| (expr, name.to_string()))
+                        parse_physical_expr(expr, registry)
+                            .map(|expr| (expr, name.to_string()))
                     })
                     .collect::<Result<Vec<_>, _>>()?;
 
@@ -342,10 +360,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
                                             },
                                         )?;
 
+                                let input_phy_expr = agg_node.expr.as_ref()
+                                    .map(|e| parse_physical_expr(e.as_ref(), registry))
+                                    .transpose()?
+                                    .ok_or_else(|| proto_error("missing aggregate expression".to_string()))?;
+
                                 Ok(create_aggregate_expr(
                                     &aggr_function.into(),
                                     false,
-                                    &[convert_box_required!(agg_node.expr)?],
+                                    &[input_phy_expr],
                                     &physical_schema,
                                     name.to_string(),
                                 )?)
@@ -453,6 +476,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
 
                 let output_partitioning = parse_protobuf_hash_partitioning(
                     shuffle_writer.output_partitioning.as_ref(),
+                    registry,
                 )?;
 
                 Ok(Arc::new(ShuffleWriterExec::try_new(
@@ -508,7 +532,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                                 })?
                                 .as_ref();
                             Ok(PhysicalSortExpr {
-                                expr: expr.try_into()?,
+                                expr: parse_physical_expr(expr,registry)?,
                                 options: SortOptions {
                                     descending: !sort_expr.asc,
                                     nulls_first: sort_expr.nulls_first,
@@ -1015,6 +1039,14 @@ mod roundtrip_tests {
     use std::sync::Arc;
 
     use crate::serde::{AsExecutionPlan, BallistaCodec};
+    use datafusion::arrow::array::ArrayRef;
+    use datafusion::execution::context::ExecutionProps;
+    use datafusion::logical_plan::create_udf;
+    use datafusion::physical_plan::functions;
+    use datafusion::physical_plan::functions::{
+        make_scalar_function, BuiltinScalarFunction, ScalarFunctionExpr, Volatility,
+    };
+    use datafusion::physical_plan::projection::ProjectionExec;
     use datafusion::{
         arrow::{
             compute::kernels::sort::SortOptions,
@@ -1069,6 +1101,33 @@ mod roundtrip_tests {
         Ok(())
     }
 
+    fn roundtrip_test_with_context(
+        exec_plan: Arc<dyn ExecutionPlan>,
+        ctx: SessionContext,
+    ) -> Result<()> {
+        let codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
+            BallistaCodec::default();
+        let proto: protobuf::PhysicalPlanNode =
+            protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec_plan.clone(),
+                codec.physical_extension_codec(),
+            )
+            .expect("to proto");
+        let runtime = ctx.runtime_env();
+        let result_exec_plan: Arc<dyn ExecutionPlan> = proto
+            .try_into_physical_plan(
+                &ctx,
+                runtime.deref(),
+                codec.physical_extension_codec(),
+            )
+            .expect("from proto");
+        assert_eq!(
+            format!("{:?}", exec_plan),
+            format!("{:?}", result_exec_plan)
+        );
+        Ok(())
+    }
+
     #[test]
     fn roundtrip_empty() -> Result<()> {
         roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))))
@@ -1241,4 +1300,69 @@ mod roundtrip_tests {
         let predicate = datafusion::prelude::col("col").eq(datafusion::prelude::lit("1"));
         roundtrip_test(Arc::new(ParquetExec::new(scan_config, Some(predicate))))
     }
+
+    #[test]
+    fn roundtrip_builtin_scalar_function() -> Result<()> {
+        let field_a = Field::new("a", DataType::Int64, false);
+        let field_b = Field::new("b", DataType::Int64, false);
+        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+        let input = Arc::new(EmptyExec::new(false, schema.clone()));
+
+        let execution_props = ExecutionProps::new();
+
+        let fun_expr = functions::create_physical_fun(
+            &BuiltinScalarFunction::Abs,
+            &execution_props,
+        )?;
+
+        let expr = ScalarFunctionExpr::new(
+            "abs",
+            fun_expr,
+            vec![col("a", &schema)?],
+            &DataType::Int64,
+        );
+
+        let project =
+            ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?;
+
+        roundtrip_test(Arc::new(project))
+    }
+
+    #[test]
+    fn roundtrip_scalar_udf() -> Result<()> {
+        let field_a = Field::new("a", DataType::Int64, false);
+        let field_b = Field::new("b", DataType::Int64, false);
+        let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+        let input = Arc::new(EmptyExec::new(false, schema.clone()));
+
+        let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);
+
+        let scalar_fn = make_scalar_function(fn_impl);
+
+        let udf = create_udf(
+            "dummy",
+            vec![DataType::Int64],
+            Arc::new(DataType::Int64),
+            Volatility::Immutable,
+            scalar_fn.clone(),
+        );
+
+        let expr = ScalarFunctionExpr::new(
+            "dummy",
+            scalar_fn,
+            vec![col("a", &schema)?],
+            &DataType::Int64,
+        );
+
+        let project =
+            ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?;
+
+        let mut ctx = SessionContext::new();
+
+        ctx.register_udf(udf);
+
+        roundtrip_test_with_context(Arc::new(project), ctx)
+    }
 }
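
One practical consequence of the round-trip test above: only the UDF's name, arguments, and return type are serialized, never its body, so the context used for deserialization must have the same UDF registered. A rough sketch of the failure mode when it is not (again assuming SessionContext implements FunctionRegistry):

    use datafusion::execution::context::SessionContext;
    use datafusion::logical_plan::FunctionRegistry;

    fn main() {
        let ctx = SessionContext::new();
        // Nothing named "dummy" is registered here, so the registry lookup that
        // from_proto.rs performs for ExprType::ScalarUdf would fail, and the
        // whole plan fails to deserialize rather than silently dropping the UDF.
        assert!(ctx.udf("dummy").is_err());
    }
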
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 9a63762a4..1a1276ec1 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -286,24 +286,38 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::PhysicalExprNode {
                 )),
             })
         } else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
-            let fun: BuiltinScalarFunction =
-                BuiltinScalarFunction::from_str(expr.name())?;
-            let fun: datafusion_proto::protobuf::ScalarFunction = (&fun).try_into()?;
             let args: Vec<protobuf::PhysicalExprNode> = expr
                 .args()
                 .iter()
                 .map(|e| e.to_owned().try_into())
                 .collect::<Result<Vec<_>, _>>()?;
-            Ok(protobuf::PhysicalExprNode {
-                expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction(
-                    protobuf::PhysicalScalarFunctionNode {
-                        name: expr.name().to_string(),
-                        fun: fun.into(),
-                        args,
-                        return_type: Some(expr.return_type().into()),
-                    },
-                )),
-            })
+            if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) {
+                let fun: datafusion_proto::protobuf::ScalarFunction =
+                    (&fun).try_into()?;
+
+                Ok(protobuf::PhysicalExprNode {
+                    expr_type: Some(
+                        protobuf::physical_expr_node::ExprType::ScalarFunction(
+                            protobuf::PhysicalScalarFunctionNode {
+                                name: expr.name().to_string(),
+                                fun: fun.into(),
+                                args,
+                                return_type: Some(expr.return_type().into()),
+                            },
+                        ),
+                    ),
+                })
+            } else {
+                Ok(protobuf::PhysicalExprNode {
+                    expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf(
+                        protobuf::PhysicalScalarUdfNode {
+                            name: expr.name().to_string(),
+                            args,
+                            return_type: Some(expr.return_type().into()),
+                        },
+                    )),
+                })
+            }
         } else {
             Err(BallistaError::General(format!(
                 "physical_plan::to_proto() unsupported expression {:?}",
diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs
index bb4edb6b9..06128f8db 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -161,8 +161,10 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
             )
         })?;
 
-    let shuffle_output_partitioning =
-        parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?;
+    let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
+        task.output_partitioning.as_ref(),
+        task_context.as_ref(),
+    )?;
 
     tokio::spawn(async move {
         let execution_result = executor
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index 81910f264..11a5c7552 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -216,8 +216,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
                 )
             })?;
 
-        let shuffle_output_partitioning =
-            parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?;
+        let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
+            task.output_partitioning.as_ref(),
+            task_context.as_ref(),
+        )?;
 
         let execution_result = self
             .executor