You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/05/08 20:33:48 UTC
[arrow-datafusion] branch main updated: refactor: Expr::ScalarUDF to use a struct (#6284)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6f53ff24c3 refactor: Expr::ScalarUDF to use a struct (#6284)
6f53ff24c3 is described below
commit 6f53ff24c3a2dc97952fa68be87ba101777aa9c0
Author: jakevin <ja...@gmail.com>
AuthorDate: Tue May 9 04:33:41 2023 +0800
refactor: Expr::ScalarUDF to use a struct (#6284)
---
datafusion/core/src/datasource/listing/helpers.rs | 3 +-
datafusion/core/src/physical_plan/planner.rs | 4 +--
datafusion/expr/src/expr.rs | 32 +++++++++++++++-------
datafusion/expr/src/expr_schema.rs | 8 +++---
datafusion/expr/src/tree_node/expr.rs | 11 ++++----
datafusion/expr/src/udf.rs | 5 +---
datafusion/expr/src/utils.rs | 2 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 25 ++++++++---------
datafusion/optimizer/src/push_down_filter.rs | 2 +-
.../src/simplify_expressions/expr_simplifier.rs | 30 ++++++++------------
datafusion/physical-expr/src/planner.rs | 4 +--
datafusion/proto/src/logical_plan/from_proto.rs | 9 +++---
datafusion/proto/src/logical_plan/mod.rs | 9 +++---
datafusion/proto/src/logical_plan/to_proto.rs | 4 +--
datafusion/sql/src/expr/function.rs | 4 +--
datafusion/sql/src/utils.rs | 17 ++++++------
16 files changed, 84 insertions(+), 85 deletions(-)
diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 2a09445665..f4995f7411 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -41,6 +41,7 @@ use datafusion_common::{
cast::{as_date64_array, as_string_array, as_uint64_array},
Column, DataFusionError,
};
+use datafusion_expr::expr::ScalarUDF;
use datafusion_expr::{Expr, Volatility};
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
@@ -105,7 +106,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
}
}
}
- Expr::ScalarUDF { fun, .. } => {
+ Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature.volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 88e0d6e3d0..03f981f54a 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -63,7 +63,7 @@ use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{
self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet,
- Like, TryCast, WindowFunction,
+ Like, ScalarUDF, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
@@ -187,7 +187,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
Expr::ScalarFunction(func) => {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
- Expr::ScalarUDF { fun, args, .. } => {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(&fun.name, false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index f508efb2ab..3674c1779d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -26,7 +26,6 @@ use crate::window_frame;
use crate::window_function;
use crate::AggregateUDF;
use crate::Operator;
-use crate::ScalarUDF;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_common::{plan_err, Column};
@@ -154,12 +153,7 @@ pub enum Expr {
/// Represents the call of a built-in scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
/// Represents the call of a user-defined scalar function with arguments.
- ScalarUDF {
- /// The function
- fun: Arc<ScalarUDF>,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- },
+ ScalarUDF(ScalarUDF),
/// Represents the call of an aggregate built-in function with arguments.
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
@@ -364,6 +358,22 @@ impl ScalarFunction {
}
}
+/// ScalarUDF expression
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct ScalarUDF {
+ /// The function
+ pub fun: Arc<crate::ScalarUDF>,
+ /// List of expressions to feed to the functions as arguments
+ pub args: Vec<Expr>,
+}
+
+impl ScalarUDF {
+ /// Create a new ScalarUDF expression
+ pub fn new(fun: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
+ Self { fun, args }
+ }
+}
+
/// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct GetIndexedField {
@@ -605,7 +615,7 @@ impl Expr {
Expr::QualifiedWildcard { .. } => "QualifiedWildcard",
Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
- Expr::ScalarUDF { .. } => "ScalarUDF",
+ Expr::ScalarUDF(..) => "ScalarUDF",
Expr::ScalarVariable(..) => "ScalarVariable",
Expr::Sort { .. } => "Sort",
Expr::TryCast { .. } => "TryCast",
@@ -941,7 +951,7 @@ impl fmt::Debug for Expr {
Expr::ScalarFunction(func) => {
fmt_function(f, &func.fun.to_string(), false, &func.args, false)
}
- Expr::ScalarUDF { fun, ref args, .. } => {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, &fun.name, false, args, false)
}
Expr::WindowFunction(WindowFunction {
@@ -1300,7 +1310,9 @@ fn create_name(e: &Expr) -> Result<String> {
Expr::ScalarFunction(func) => {
create_function_name(&func.fun.to_string(), false, &func.args)
}
- Expr::ScalarUDF { fun, args, .. } => create_function_name(&fun.name, false, args),
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
+ create_function_name(&fun.name, false, args)
+ }
Expr::WindowFunction(WindowFunction {
fun,
args,
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 4cdf8debca..fdb8a34aba 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -17,8 +17,8 @@
use super::{Between, Expr, Like};
use crate::expr::{
- AggregateFunction, BinaryExpr, Cast, GetIndexedField, ScalarFunction, Sort, TryCast,
- WindowFunction,
+ AggregateFunction, BinaryExpr, Cast, GetIndexedField, ScalarFunction, ScalarUDF,
+ Sort, TryCast, WindowFunction,
};
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::get_result_type;
@@ -95,7 +95,7 @@ impl ExprSchemable for Expr {
}
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
- Expr::ScalarUDF { fun, args } => {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
@@ -218,7 +218,7 @@ impl ExprSchemable for Expr {
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
- | Expr::ScalarUDF { .. }
+ | Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::AggregateUDF { .. } => Ok(true),
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 37df3ce201..5a3442394d 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -19,7 +19,7 @@
use crate::expr::{
AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet,
- Like, ScalarFunction, Sort, TryCast, WindowFunction,
+ Like, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction,
};
use crate::Expr;
use datafusion_common::tree_node::VisitRecursion;
@@ -51,7 +51,7 @@ impl TreeNode for Expr {
}
Expr::GroupingSet(GroupingSet::Rollup(exprs))
| Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(),
- Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF { args, .. } => {
+ Expr::ScalarFunction (ScalarFunction{ args, .. } )| Expr::ScalarUDF(ScalarUDF { args, .. }) => {
args.clone()
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
@@ -270,10 +270,9 @@ impl TreeNode for Expr {
Expr::ScalarFunction(ScalarFunction { args, fun }) => Expr::ScalarFunction(
ScalarFunction::new(fun, transform_vec(args, &mut transform)?),
),
- Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
- args: transform_vec(args, &mut transform)?,
- fun,
- },
+ Expr::ScalarUDF(ScalarUDF { args, fun }) => {
+ Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut transform)?))
+ }
Expr::WindowFunction(WindowFunction {
args,
fun,
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index ffb98bc7a1..be6c90aa59 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -87,9 +87,6 @@ impl ScalarUDF {
/// creates a logical expression with a call of the UDF
/// This utility allows using the UDF without requiring access to the registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
- Expr::ScalarUDF {
- fun: Arc::new(self.clone()),
- args,
- }
+ Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args))
}
}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 41c40352d9..7698297b79 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -296,7 +296,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
| Expr::TryCast { .. }
| Expr::Sort { .. }
| Expr::ScalarFunction(..)
- | Expr::ScalarUDF { .. }
+ | Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 759275dccc..0a411f828a 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Case, Like, ScalarFunction, WindowFunction,
+ self, Between, BinaryExpr, Case, Like, ScalarFunction, ScalarUDF, WindowFunction,
};
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
@@ -369,16 +369,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let case = coerce_case_expression(case, &self.schema)?;
Ok(Expr::Case(case))
}
- Expr::ScalarUDF { fun, args } => {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
- let expr = Expr::ScalarUDF {
- fun,
- args: new_expr,
- };
+ let expr = Expr::ScalarUDF(ScalarUDF::new(fun, new_expr));
Ok(expr)
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
@@ -814,15 +811,15 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation =
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
- let udf = Expr::ScalarUDF {
- fun: Arc::new(ScalarUDF::new(
+ let udf = Expr::ScalarUDF(expr::ScalarUDF::new(
+ Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
&return_type,
&fun,
)),
- args: vec![lit(123_i32)],
- };
+ vec![lit(123_i32)],
+ ));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
let expected =
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation";
@@ -835,15 +832,15 @@ mod test {
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
- let udf = Expr::ScalarUDF {
- fun: Arc::new(ScalarUDF::new(
+ let udf = Expr::ScalarUDF(expr::ScalarUDF::new(
+ Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Int32], Volatility::Stable),
&return_type,
&fun,
)),
- args: vec![lit("Apple")],
- };
+ vec![lit("Apple")],
+ ));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "")
.err()
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 33ea7d1867..2f3524a8d2 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -160,7 +160,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
| Expr::InSubquery { .. }
| Expr::ScalarSubquery(_)
| Expr::OuterReferenceColumn(_, _)
- | Expr::ScalarUDF { .. } => {
+ | Expr::ScalarUDF(..) => {
is_evaluate = false;
Ok(VisitRecursion::Stop)
}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5e8248c2b8..a0c5ff6636 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -30,7 +30,7 @@ use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
- and, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like,
+ and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like,
Volatility,
};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
@@ -269,7 +269,9 @@ impl<'a> ConstEvaluator<'a> {
Expr::ScalarFunction(ScalarFunction { fun, .. }) => {
Self::volatility_ok(fun.volatility())
}
- Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility),
+ Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => {
+ Self::volatility_ok(fun.signature.volatility)
+ }
Expr::Literal(_)
| Expr::BinaryExpr { .. }
| Expr::Not(_)
@@ -1426,32 +1428,24 @@ mod tests {
// immutable UDF should get folded
// udf_add(1+2, 30+40) --> 73
- let expr = Expr::ScalarUDF {
- args: args.clone(),
- fun: make_udf_add(Volatility::Immutable),
- };
+ let expr = Expr::ScalarUDF(expr::ScalarUDF::new(
+ make_udf_add(Volatility::Immutable),
+ args.clone(),
+ ));
test_evaluate(expr, lit(73));
// stable UDF should be entirely folded
// udf_add(1+2, 30+40) --> 73
let fun = make_udf_add(Volatility::Stable);
- let expr = Expr::ScalarUDF {
- args: args.clone(),
- fun: Arc::clone(&fun),
- };
+ let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args.clone()));
test_evaluate(expr, lit(73));
// volatile UDF should have args folded
// udf_add(1+2, 30+40) --> udf_add(3, 70)
let fun = make_udf_add(Volatility::Volatile);
- let expr = Expr::ScalarUDF {
- args,
- fun: Arc::clone(&fun),
- };
- let expected_expr = Expr::ScalarUDF {
- args: folded_args,
- fun: Arc::clone(&fun),
- };
+ let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), args));
+ let expected_expr =
+ Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun), folded_args));
test_evaluate(expr, expected_expr);
}
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index a1fdce4f3a..c90313c80d 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -27,7 +27,7 @@ use crate::{
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::{Cast, ScalarFunction};
+use datafusion_expr::expr::{Cast, ScalarFunction, ScalarUDF};
use datafusion_expr::{
binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast,
};
@@ -397,7 +397,7 @@ pub fn create_physical_expr(
execution_props,
)
}
- Expr::ScalarUDF { fun, args } => {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
let mut physical_args = vec![];
for e in args {
physical_args.push(create_physical_expr(
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 23fae1cd1d..65407bb555 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1367,13 +1367,12 @@ pub fn parse_expr(
}
ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args }) => {
let scalar_fn = registry.udf(fun_name.as_str())?;
- Ok(Expr::ScalarUDF {
- fun: scalar_fn,
- args: args
- .iter()
+ Ok(Expr::ScalarUDF(expr::ScalarUDF::new(
+ scalar_fn,
+ args.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, Error>>()?,
- })
+ )))
}
ExprType::AggregateUdfExpr(pb) => {
let agg_fn = registry.udaf(pb.fun_name.as_str())?;
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 04e30ed8d8..23a6766cc3 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -1416,7 +1416,8 @@ mod roundtrip_tests {
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Case, Cast, GroupingSet, Like, ScalarFunction, Sort,
+ self, Between, BinaryExpr, Case, Cast, GroupingSet, Like, ScalarFunction,
+ ScalarUDF, Sort,
};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
@@ -2627,10 +2628,8 @@ mod roundtrip_tests {
scalar_fn,
);
- let test_expr = Expr::ScalarUDF {
- fun: Arc::new(udf.clone()),
- args: vec![lit("")],
- };
+ let test_expr =
+ Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()), vec![lit("")]));
let ctx = SessionContext::new();
ctx.register_udf(udf);
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 64735568b4..d35d214513 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -37,7 +37,7 @@ use arrow::datatypes::{
use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference, ScalarValue};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, ScalarFunction,
- Sort,
+ ScalarUDF, Sort,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -695,7 +695,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
)),
}
}
- Expr::ScalarUDF { fun, args } => Self {
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => Self {
expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
fun_name: fun.name.clone(),
args: args
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index 996221a61e..fedb7eaacd 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -17,7 +17,7 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{DFSchema, DataFusionError, Result};
-use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr::{ScalarFunction, ScalarUDF};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
@@ -121,7 +121,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Some(fm) = self.schema_provider.get_function_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
- return Ok(Expr::ScalarUDF { fun: fm, args });
+ return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
}
// User defined aggregate functions
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 11fc0a58cf..df8891093a 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -23,7 +23,7 @@ use sqlparser::ast::Ident;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
AggregateFunction, Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like,
- ScalarFunction, WindowFunction,
+ ScalarFunction, ScalarUDF, WindowFunction,
};
use datafusion_expr::expr::{Cast, Sort};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
@@ -303,13 +303,14 @@ where
.collect::<Result<Vec<Expr>>>()?,
)))
}
- Expr::ScalarUDF { fun, args } => Ok(Expr::ScalarUDF {
- fun: fun.clone(),
- args: args
- .iter()
- .map(|arg| clone_with_replacement(arg, replacement_fn))
- .collect::<Result<Vec<Expr>>>()?,
- }),
+ Expr::ScalarUDF(ScalarUDF { fun, args }) => {
+ Ok(Expr::ScalarUDF(ScalarUDF::new(
+ fun.clone(),
+ args.iter()
+ .map(|arg| clone_with_replacement(arg, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ )))
+ }
Expr::Negative(nested_expr) => Ok(Expr::Negative(Box::new(
clone_with_replacement(nested_expr, replacement_fn)?,
))),