You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/06 18:21:49 UTC
[arrow-datafusion] branch master updated: Serialize scalar UDFs in physical plan (#2130)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new e5e8125a1 Serialize scalar UDFs in physical plan (#2130)
e5e8125a1 is described below
commit e5e8125a12cb92a8d859009e650e5e4a9afe6a47
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Wed Apr 6 19:21:45 2022 +0100
Serialize scalar UDFs in physical plan (#2130)
* Serialize scalar UDFs in physical plan
* Remove unused ScalarFunction enum case
---
ballista/rust/core/proto/ballista.proto | 8 +
.../core/src/serde/physical_plan/from_proto.rs | 288 ++++++++++++---------
ballista/rust/core/src/serde/physical_plan/mod.rs | 160 ++++++++++--
.../rust/core/src/serde/physical_plan/to_proto.rs | 40 ++-
ballista/rust/executor/src/execution_loop.rs | 6 +-
ballista/rust/executor/src/executor_server.rs | 6 +-
6 files changed, 355 insertions(+), 153 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index ae938c6a7..1c50db6fd 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -306,9 +306,17 @@ message PhysicalExprNode {
// window expressions
PhysicalWindowExprNode window_expr = 15;
+
+ PhysicalScalarUdfNode scalar_udf = 16;
}
}
+message PhysicalScalarUdfNode {
+ string name = 1;
+ repeated PhysicalExprNode args = 2;
+ datafusion.ArrowType return_type = 4;
+}
+
message PhysicalAggregateExprNode {
datafusion.AggregateFunction aggr_function = 1;
PhysicalExprNode expr = 2;
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 13f98535e..cc7f866e9 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -18,12 +18,13 @@
//! Serde code to convert from protocol buffers to Rust data structures.
use std::convert::{TryFrom, TryInto};
+use std::ops::Deref;
use std::sync::Arc;
use crate::error::BallistaError;
+use crate::convert_required;
use crate::serde::{from_proto_binary_op, proto_error, protobuf};
-use crate::{convert_box_required, convert_required};
use chrono::{TimeZone, Utc};
use datafusion::datafusion_data_access::{
@@ -31,6 +32,7 @@ use datafusion::datafusion_data_access::{
};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::execution::context::ExecutionProps;
+use datafusion::logical_plan::FunctionRegistry;
use datafusion::physical_plan::file_format::FileScanConfig;
@@ -54,125 +56,174 @@ impl From<&protobuf::PhysicalColumn> for Column {
}
}
-impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
- type Error = BallistaError;
+pub(crate) fn parse_physical_expr(
+ proto: &protobuf::PhysicalExprNode,
+ registry: &dyn FunctionRegistry,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+ let expr_type = proto
+ .expr_type
+ .as_ref()
+ .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+
+ let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
+ ExprType::Column(c) => {
+ let pcol: Column = c.into();
+ Arc::new(pcol)
+ }
+ ExprType::Literal(scalar) => {
+ Arc::new(Literal::new(convert_required!(scalar.value)?))
+ }
+ ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
+ parse_required_physical_box_expr(&binary_expr.l, registry, "left")?,
+ from_proto_binary_op(&binary_expr.op)?,
+ parse_required_physical_box_expr(&binary_expr.r, registry, "right")?,
+ )),
+ ExprType::AggregateExpr(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert aggregate expr node to physical expression".to_owned(),
+ ));
+ }
+ ExprType::WindowExpr(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert window expr node to physical expression".to_owned(),
+ ));
+ }
+ ExprType::Sort(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert sort expr node to physical expression".to_owned(),
+ ));
+ }
+ ExprType::IsNullExpr(e) => Arc::new(IsNullExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ )),
+ ExprType::IsNotNullExpr(e) => Arc::new(IsNotNullExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ )),
+ ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_box_expr(
+ &e.expr, registry, "expr",
+ )?)),
+ ExprType::Negative(e) => Arc::new(NegativeExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ )),
+ ExprType::InList(e) => Arc::new(InListExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ e.list
+ .iter()
+ .map(|x| parse_physical_expr(x, registry))
+ .collect::<Result<Vec<_>, _>>()?,
+ e.negated,
+ )),
+ ExprType::Case(e) => Arc::new(CaseExpr::try_new(
+ e.expr
+ .as_ref()
+ .map(|e| parse_physical_expr(e.as_ref(), registry))
+ .transpose()?,
+ e.when_then_expr
+ .iter()
+ .map(|e| {
+ Ok((
+ parse_required_physical_expr(
+ &e.when_expr,
+ registry,
+ "when_expr",
+ )?,
+ parse_required_physical_expr(
+ &e.then_expr,
+ registry,
+ "then_expr",
+ )?,
+ ))
+ })
+ .collect::<Result<Vec<_>, BallistaError>>()?
+ .as_slice(),
+ e.else_expr
+ .as_ref()
+ .map(|e| parse_physical_expr(e.as_ref(), registry))
+ .transpose()?,
+ )?),
+ ExprType::Cast(e) => Arc::new(CastExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ convert_required!(e.arrow_type)?,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ )),
+ ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
+ parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+ convert_required!(e.arrow_type)?,
+ )),
+ ExprType::ScalarFunction(e) => {
+ let scalar_function = datafusion_proto::protobuf::ScalarFunction::from_i32(
+ e.fun,
+ )
+ .ok_or_else(|| {
+ proto_error(format!("Received an unknown scalar function: {}", e.fun,))
+ })?;
+
+ let args = e
+ .args
+ .iter()
+ .map(|x| parse_physical_expr(x, registry))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // TODO Do not create new the ExecutionProps
+ let execution_props = ExecutionProps::new();
+
+ let fun_expr = functions::create_physical_fun(
+ &(&scalar_function).into(),
+ &execution_props,
+ )?;
+
+ Arc::new(ScalarFunctionExpr::new(
+ &e.name,
+ fun_expr,
+ args,
+ &convert_required!(e.return_type)?,
+ ))
+ }
+ ExprType::ScalarUdf(e) => {
+ let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun;
- fn try_from(expr: &protobuf::PhysicalExprNode) -> Result<Self, Self::Error> {
- let expr_type = expr
- .expr_type
- .as_ref()
- .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+ let args = e
+ .args
+ .iter()
+ .map(|x| parse_physical_expr(x, registry))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Arc::new(ScalarFunctionExpr::new(
+ e.name.as_str(),
+ scalar_fun,
+ args,
+ &convert_required!(e.return_type)?,
+ ))
+ }
+ };
- let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
- ExprType::Column(c) => {
- let pcol: Column = c.into();
- Arc::new(pcol)
- }
- ExprType::Literal(scalar) => {
- Arc::new(Literal::new(convert_required!(scalar.value)?))
- }
- ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
- convert_box_required!(&binary_expr.l)?,
- from_proto_binary_op(&binary_expr.op)?,
- convert_box_required!(&binary_expr.r)?,
- )),
- ExprType::AggregateExpr(_) => {
- return Err(BallistaError::General(
- "Cannot convert aggregate expr node to physical expression"
- .to_owned(),
- ));
- }
- ExprType::WindowExpr(_) => {
- return Err(BallistaError::General(
- "Cannot convert window expr node to physical expression".to_owned(),
- ));
- }
- ExprType::Sort(_) => {
- return Err(BallistaError::General(
- "Cannot convert sort expr node to physical expression".to_owned(),
- ));
- }
- ExprType::IsNullExpr(e) => {
- Arc::new(IsNullExpr::new(convert_box_required!(e.expr)?))
- }
- ExprType::IsNotNullExpr(e) => {
- Arc::new(IsNotNullExpr::new(convert_box_required!(e.expr)?))
- }
- ExprType::NotExpr(e) => {
- Arc::new(NotExpr::new(convert_box_required!(e.expr)?))
- }
- ExprType::Negative(e) => {
- Arc::new(NegativeExpr::new(convert_box_required!(e.expr)?))
- }
- ExprType::InList(e) => Arc::new(InListExpr::new(
- convert_box_required!(e.expr)?,
- e.list
- .iter()
- .map(|x| x.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- e.negated,
- )),
- ExprType::Case(e) => Arc::new(CaseExpr::try_new(
- e.expr.as_ref().map(|e| e.as_ref().try_into()).transpose()?,
- e.when_then_expr
- .iter()
- .map(|e| {
- Ok((
- convert_required!(e.when_expr)?,
- convert_required!(e.then_expr)?,
- ))
- })
- .collect::<Result<Vec<_>, BallistaError>>()?
- .as_slice(),
- e.else_expr
- .as_ref()
- .map(|e| e.as_ref().try_into())
- .transpose()?,
- )?),
- ExprType::Cast(e) => Arc::new(CastExpr::new(
- convert_box_required!(e.expr)?,
- convert_required!(e.arrow_type)?,
- DEFAULT_DATAFUSION_CAST_OPTIONS,
- )),
- ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
- convert_box_required!(e.expr)?,
- convert_required!(e.arrow_type)?,
- )),
- ExprType::ScalarFunction(e) => {
- let scalar_function =
- datafusion_proto::protobuf::ScalarFunction::from_i32(e.fun)
- .ok_or_else(|| {
- proto_error(format!(
- "Received an unknown scalar function: {}",
- e.fun,
- ))
- })?;
-
- let args = e
- .args
- .iter()
- .map(|x| x.try_into())
- .collect::<Result<Vec<_>, _>>()?;
-
- // TODO Do not create new the ExecutionProps
- let execution_props = ExecutionProps::new();
-
- let fun_expr = functions::create_physical_fun(
- &(&scalar_function).into(),
- &execution_props,
- )?;
-
- Arc::new(ScalarFunctionExpr::new(
- &e.name,
- fun_expr,
- args,
- &convert_required!(e.return_type)?,
- ))
- }
- };
+ Ok(pexpr)
+}
- Ok(pexpr)
- }
+fn parse_required_physical_box_expr(
+ expr: &Option<Box<protobuf::PhysicalExprNode>>,
+ registry: &dyn FunctionRegistry,
+ field: &str,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+ expr.as_ref()
+ .map(|e| parse_physical_expr(e.as_ref(), registry))
+ .transpose()?
+ .ok_or_else(|| {
+ BallistaError::General(format!("Missing required field {:?}", field))
+ })
+}
+
+fn parse_required_physical_expr(
+ expr: &Option<protobuf::PhysicalExprNode>,
+ registry: &dyn FunctionRegistry,
+ field: &str,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+ expr.as_ref()
+ .map(|e| parse_physical_expr(e, registry))
+ .transpose()?
+ .ok_or_else(|| {
+ BallistaError::General(format!("Missing required field {:?}", field))
+ })
}
impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction {
@@ -210,13 +261,14 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun
pub fn parse_protobuf_hash_partitioning(
partitioning: Option<&protobuf::PhysicalHashRepartition>,
+ registry: &dyn FunctionRegistry,
) -> Result<Option<Partitioning>, BallistaError> {
match partitioning {
Some(hash_part) => {
let expr = hash_part
.hash_expr
.iter()
- .map(|e| e.try_into())
+ .map(|e| parse_physical_expr(e, registry))
.collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
Ok(Some(Partitioning::Hash(
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 0361df677..3abb0713d 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -19,7 +19,9 @@ use crate::error::BallistaError;
use crate::execution_plans::{
ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
};
-use crate::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use crate::serde::physical_plan::from_proto::{
+ parse_physical_expr, parse_protobuf_hash_partitioning,
+};
use crate::serde::protobuf::physical_expr_node::ExprType;
use crate::serde::protobuf::physical_plan_node::PhysicalPlanType;
use crate::serde::protobuf::repartition_exec_node::PartitionMethod;
@@ -30,7 +32,7 @@ use crate::serde::{
byte_to_string, proto_error, protobuf, str_to_byte, AsExecutionPlan,
PhysicalExtensionCodec,
};
-use crate::{convert_box_required, convert_required, into_physical_plan, into_required};
+use crate::{convert_required, into_physical_plan, into_required};
use datafusion::arrow::compute::SortOptions;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datafusion_data_access::object_store::local::LocalFileSystem;
@@ -112,7 +114,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
.expr
.iter()
.zip(projection.expr_name.iter())
- .map(|(expr, name)| Ok((expr.try_into()?, name.to_string())))
+ .map(|(expr, name)| Ok((parse_physical_expr(expr,registry)?, name.to_string())))
.collect::<Result<Vec<(Arc<dyn PhysicalExpr>, String)>, BallistaError>>(
)?;
Ok(Arc::new(ProjectionExec::try_new(exprs, input)?))
@@ -127,13 +129,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
let predicate = filter
.expr
.as_ref()
+ .map(|expr| parse_physical_expr(expr, registry))
+ .transpose()?
.ok_or_else(|| {
BallistaError::General(
"filter (FilterExecNode) in PhysicalPlanNode is missing."
.to_owned(),
)
- })?
- .try_into()?;
+ })?;
Ok(Arc::new(FilterExec::try_new(predicate, input)?))
}
PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new(
@@ -184,7 +187,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
let expr = hash_part
.hash_expr
.iter()
- .map(|e| e.try_into())
+ .map(|e| parse_physical_expr(e, registry))
.collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
Ok(Arc::new(RepartitionExec::try_new(
@@ -255,15 +258,29 @@ impl AsExecutionPlan for PhysicalPlanNode {
})?;
match expr_type {
- ExprType::WindowExpr(window_node) => Ok(create_window_expr(
- &convert_required!(window_node.window_function)?,
- name.to_owned(),
- &[convert_box_required!(window_node.expr)?],
- &[],
- &[],
- Some(WindowFrame::default()),
- &physical_schema,
- )?),
+ ExprType::WindowExpr(window_node) => {
+ let window_node_expr = window_node
+ .expr
+ .as_ref()
+ .map(|e| parse_physical_expr(e.as_ref(), registry))
+ .transpose()?
+ .ok_or_else(|| {
+ proto_error(
+ "missing window_node expr expression"
+ .to_string(),
+ )
+ })?;
+
+ Ok(create_window_expr(
+ &convert_required!(window_node.window_function)?,
+ name.to_owned(),
+ &[window_node_expr],
+ &[],
+ &[],
+ Some(WindowFrame::default()),
+ &physical_schema,
+ )?)
+ }
_ => Err(BallistaError::General(
"Invalid expression for WindowAggrExec".to_string(),
)),
@@ -302,7 +319,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
.iter()
.zip(hash_agg.group_expr_name.iter())
.map(|(expr, name)| {
- expr.try_into().map(|expr| (expr, name.to_string()))
+ parse_physical_expr(expr, registry)
+ .map(|expr| (expr, name.to_string()))
})
.collect::<Result<Vec<_>, _>>()?;
@@ -342,10 +360,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
},
)?;
+ let input_phy_expr = agg_node.expr.as_ref()
+ .map(|e| parse_physical_expr(e.as_ref(), registry))
+ .transpose()?
+ .ok_or_else(|| proto_error("missing aggregate expression".to_string()))?;
+
Ok(create_aggregate_expr(
&aggr_function.into(),
false,
- &[convert_box_required!(agg_node.expr)?],
+ &[input_phy_expr],
&physical_schema,
name.to_string(),
)?)
@@ -453,6 +476,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
let output_partitioning = parse_protobuf_hash_partitioning(
shuffle_writer.output_partitioning.as_ref(),
+ registry,
)?;
Ok(Arc::new(ShuffleWriterExec::try_new(
@@ -508,7 +532,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
})?
.as_ref();
Ok(PhysicalSortExpr {
- expr: expr.try_into()?,
+ expr: parse_physical_expr(expr,registry)?,
options: SortOptions {
descending: !sort_expr.asc,
nulls_first: sort_expr.nulls_first,
@@ -1015,6 +1039,14 @@ mod roundtrip_tests {
use std::sync::Arc;
use crate::serde::{AsExecutionPlan, BallistaCodec};
+ use datafusion::arrow::array::ArrayRef;
+ use datafusion::execution::context::ExecutionProps;
+ use datafusion::logical_plan::create_udf;
+ use datafusion::physical_plan::functions;
+ use datafusion::physical_plan::functions::{
+ make_scalar_function, BuiltinScalarFunction, ScalarFunctionExpr, Volatility,
+ };
+ use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::{
arrow::{
compute::kernels::sort::SortOptions,
@@ -1069,6 +1101,33 @@ mod roundtrip_tests {
Ok(())
}
+ fn roundtrip_test_with_context(
+ exec_plan: Arc<dyn ExecutionPlan>,
+ ctx: SessionContext,
+ ) -> Result<()> {
+ let codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
+ BallistaCodec::default();
+ let proto: protobuf::PhysicalPlanNode =
+ protobuf::PhysicalPlanNode::try_from_physical_plan(
+ exec_plan.clone(),
+ codec.physical_extension_codec(),
+ )
+ .expect("to proto");
+ let runtime = ctx.runtime_env();
+ let result_exec_plan: Arc<dyn ExecutionPlan> = proto
+ .try_into_physical_plan(
+ &ctx,
+ runtime.deref(),
+ codec.physical_extension_codec(),
+ )
+ .expect("from proto");
+ assert_eq!(
+ format!("{:?}", exec_plan),
+ format!("{:?}", result_exec_plan)
+ );
+ Ok(())
+ }
+
#[test]
fn roundtrip_empty() -> Result<()> {
roundtrip_test(Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))))
@@ -1241,4 +1300,69 @@ mod roundtrip_tests {
let predicate = datafusion::prelude::col("col").eq(datafusion::prelude::lit("1"));
roundtrip_test(Arc::new(ParquetExec::new(scan_config, Some(predicate))))
}
+
+ #[test]
+ fn roundtrip_builtin_scalar_function() -> Result<()> {
+ let field_a = Field::new("a", DataType::Int64, false);
+ let field_b = Field::new("b", DataType::Int64, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+ let input = Arc::new(EmptyExec::new(false, schema.clone()));
+
+ let execution_props = ExecutionProps::new();
+
+ let fun_expr = functions::create_physical_fun(
+ &BuiltinScalarFunction::Abs,
+ &execution_props,
+ )?;
+
+ let expr = ScalarFunctionExpr::new(
+ "abs",
+ fun_expr,
+ vec![col("a", &schema)?],
+ &DataType::Int64,
+ );
+
+ let project =
+ ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?;
+
+ roundtrip_test(Arc::new(project))
+ }
+
+ #[test]
+ fn roundtrip_scalar_udf() -> Result<()> {
+ let field_a = Field::new("a", DataType::Int64, false);
+ let field_b = Field::new("b", DataType::Int64, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
+ let input = Arc::new(EmptyExec::new(false, schema.clone()));
+
+ let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);
+
+ let scalar_fn = make_scalar_function(fn_impl);
+
+ let udf = create_udf(
+ "dummy",
+ vec![DataType::Int64],
+ Arc::new(DataType::Int64),
+ Volatility::Immutable,
+ scalar_fn.clone(),
+ );
+
+ let expr = ScalarFunctionExpr::new(
+ "dummy",
+ scalar_fn,
+ vec![col("a", &schema)?],
+ &DataType::Int64,
+ );
+
+ let project =
+ ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?;
+
+ let mut ctx = SessionContext::new();
+
+ ctx.register_udf(udf);
+
+ roundtrip_test_with_context(Arc::new(project), ctx)
+ }
}
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 9a63762a4..1a1276ec1 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -286,24 +286,38 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::PhysicalExprNode {
)),
})
} else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
- let fun: BuiltinScalarFunction =
- BuiltinScalarFunction::from_str(expr.name())?;
- let fun: datafusion_proto::protobuf::ScalarFunction = (&fun).try_into()?;
let args: Vec<protobuf::PhysicalExprNode> = expr
.args()
.iter()
.map(|e| e.to_owned().try_into())
.collect::<Result<Vec<_>, _>>()?;
- Ok(protobuf::PhysicalExprNode {
- expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction(
- protobuf::PhysicalScalarFunctionNode {
- name: expr.name().to_string(),
- fun: fun.into(),
- args,
- return_type: Some(expr.return_type().into()),
- },
- )),
- })
+ if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) {
+ let fun: datafusion_proto::protobuf::ScalarFunction =
+ (&fun).try_into()?;
+
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(
+ protobuf::physical_expr_node::ExprType::ScalarFunction(
+ protobuf::PhysicalScalarFunctionNode {
+ name: expr.name().to_string(),
+ fun: fun.into(),
+ args,
+ return_type: Some(expr.return_type().into()),
+ },
+ ),
+ ),
+ })
+ } else {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf(
+ protobuf::PhysicalScalarUdfNode {
+ name: expr.name().to_string(),
+ args,
+ return_type: Some(expr.return_type().into()),
+ },
+ )),
+ })
+ }
} else {
Err(BallistaError::General(format!(
"physical_plan::to_proto() unsupported expression {:?}",
diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs
index bb4edb6b9..06128f8db 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -161,8 +161,10 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
)
})?;
- let shuffle_output_partitioning =
- parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?;
+ let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
+ task.output_partitioning.as_ref(),
+ task_context.as_ref(),
+ )?;
tokio::spawn(async move {
let execution_result = executor
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index 81910f264..11a5c7552 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -216,8 +216,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
)
})?;
- let shuffle_output_partitioning =
- parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?;
+ let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
+ task.output_partitioning.as_ref(),
+ task_context.as_ref(),
+ )?;
let execution_result = self
.executor