You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/05/30 17:39:30 UTC
[arrow-datafusion] branch main updated: Substrait: Support Expr::ScalarFunction (#6471)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fd9bdad453 Substrait: Support Expr::ScalarFunction (#6471)
fd9bdad453 is described below
commit fd9bdad45397d195334d2cb797ca50f3cc5928c3
Author: Jay Zhan <ja...@gmail.com>
AuthorDate: Wed May 31 01:39:23 2023 +0800
Substrait: Support Expr::ScalarFunction (#6471)
* Add support for abs
Signed-off-by: jayzhan211 <ja...@gmail.com>
* fmt
Signed-off-by: jayzhan211 <ja...@gmail.com>
* Add Op or ScalarFunction for two arguments cases
Signed-off-by: jayzhan211 <ja...@gmail.com>
* refactor name
Signed-off-by: jayzhan211 <ja...@gmail.com>
* address comment
Signed-off-by: jayzhan211 <ja...@gmail.com>
---------
Signed-off-by: jayzhan211 <ja...@gmail.com>
---
datafusion/substrait/src/logical_plan/consumer.rs | 129 ++++++++++++++++-----
datafusion/substrait/src/logical_plan/producer.rs | 28 ++++-
.../substrait/tests/roundtrip_logical_plan.rs | 15 +++
3 files changed, 143 insertions(+), 29 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 9343d31130..f914b62a14 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,8 +19,8 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{
- aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr,
- LogicalPlan, Operator,
+ aggregate_function, window_function::find_df_window_func, BinaryExpr,
+ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
};
use datafusion::logical_expr::{build_join_schema, Extension, LogicalPlanBuilder};
use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
@@ -64,6 +64,11 @@ use crate::variation_const::{
TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
};
+enum ScalarFunctionType {
+ Builtin(BuiltinScalarFunction),
+ Op(Operator),
+}
+
pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
@@ -97,6 +102,20 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
}
}
+fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
+ if let Ok(op) = name_to_op(name) {
+ return Ok(ScalarFunctionType::Op(op));
+ }
+
+ if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
+ return Ok(ScalarFunctionType::Builtin(fun));
+ }
+
+ Err(DataFusionError::NotImplemented(format!(
+ "Unsupported function name: {name:?}"
+ )))
+}
+
/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
@@ -727,38 +746,92 @@ pub async fn from_substrait_rex(
else_expr,
})))
}
- Some(RexType::ScalarFunction(f)) => {
- assert!(f.arguments.len() == 2);
- let op = match extensions.get(&f.function_reference) {
- Some(fname) => name_to_op(fname),
- None => Err(DataFusionError::NotImplemented(format!(
- "Aggregated function not found: function reference = {:?}",
- f.function_reference
- ))),
- };
- match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
+ Some(RexType::ScalarFunction(f)) => match f.arguments.len() {
+ // BinaryExpr or ScalarFunction
+ 2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
- Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
- left: Box::new(
- from_substrait_rex(l, input_schema, extensions)
- .await?
- .as_ref()
- .clone(),
- ),
- op: op?,
- right: Box::new(
- from_substrait_rex(r, input_schema, extensions)
- .await?
- .as_ref()
- .clone(),
- ),
- })))
+ let op_or_fun = match extensions.get(&f.function_reference) {
+ Some(fname) => name_to_op_or_scalar_function(fname),
+ None => Err(DataFusionError::NotImplemented(format!(
+ "Aggregated function not found: function reference = {:?}",
+ f.function_reference
+ ))),
+ };
+ match op_or_fun {
+ Ok(ScalarFunctionType::Op(op)) => {
+ return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(
+ from_substrait_rex(l, input_schema, extensions)
+ .await?
+ .as_ref()
+ .clone(),
+ ),
+ op,
+ right: Box::new(
+ from_substrait_rex(r, input_schema, extensions)
+ .await?
+ .as_ref()
+ .clone(),
+ ),
+ })))
+ }
+ Ok(ScalarFunctionType::Builtin(fun)) => {
+ Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+ fun,
+ args: vec![
+ from_substrait_rex(l, input_schema, extensions)
+ .await?
+ .as_ref()
+ .clone(),
+ from_substrait_rex(r, input_schema, extensions)
+ .await?
+ .as_ref()
+ .clone(),
+ ],
+ })))
+ }
+ Err(e) => Err(e),
+ }
}
(l, r) => Err(DataFusionError::NotImplemented(format!(
"Invalid arguments for binary expression: {l:?} and {r:?}"
))),
+ },
+ // ScalarFunction
+ _ => {
+ let fun = match extensions.get(&f.function_reference) {
+ Some(fname) => BuiltinScalarFunction::from_str(fname),
+ None => Err(DataFusionError::NotImplemented(format!(
+ "Aggregated function not found: function reference = {:?}",
+ f.function_reference
+ ))),
+ };
+
+ let mut args: Vec<Expr> = vec![];
+ for arg in f.arguments.iter() {
+ match &arg.arg_type {
+ Some(ArgType::Value(e)) => {
+ args.push(
+ from_substrait_rex(e, input_schema, extensions)
+ .await?
+ .as_ref()
+ .clone(),
+ );
+ }
+ e => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Invalid arguments for scalar function: {e:?}"
+ )))
+ }
+ }
+ }
+
+ Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+ fun: fun?,
+ args,
+ })))
}
- }
+ },
Some(RexType::Literal(lit)) => {
let scalar_value = from_substrait_literal(lit)?;
Ok(Arc::new(Expr::Literal(scalar_value)))
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 872760390a..785bfa4ea6 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -28,7 +28,9 @@ use datafusion::{
use datafusion::common::DFSchemaRef;
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
-use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort, WindowFunction};
+use datafusion::logical_expr::expr::{
+ BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction,
+};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::{binary_expr, Expr};
use prost_types::Any as ProtoAny;
@@ -564,6 +566,7 @@ pub fn make_binary_op_scalar_func(
}
/// Convert DataFusion Expr to Substrait Rex
+#[allow(deprecated)]
pub fn to_substrait_rex(
expr: &Expr,
schema: &DFSchemaRef,
@@ -573,6 +576,29 @@ pub fn to_substrait_rex(
),
) -> Result<Expression> {
match expr {
+ Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument {
+ arg_type: Some(ArgType::Value(to_substrait_rex(
+ arg,
+ schema,
+ extension_info,
+ )?)),
+ });
+ }
+ let function_name = fun.to_string().to_lowercase();
+ let function_anchor = _register_function(function_name, extension_info);
+ Ok(Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments,
+ output_type: None,
+ args: vec![],
+ options: vec![],
+ })),
+ })
+ }
Expr::Between(Between {
expr,
negated,
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 77a02cbf47..8cdf89b294 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -278,6 +278,21 @@ mod tests {
.await
}
+ #[tokio::test]
+ async fn simple_scalar_function_abs() -> Result<()> {
+ roundtrip("SELECT ABS(a) FROM data").await
+ }
+
+ #[tokio::test]
+ async fn simple_scalar_function_pow() -> Result<()> {
+ roundtrip("SELECT POW(a, 2) FROM data").await
+ }
+
+ #[tokio::test]
+ async fn simple_scalar_function_substr() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
+ }
+
#[tokio::test]
async fn case_without_base_expression() -> Result<()> {
roundtrip(