You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/12 16:08:43 UTC

[arrow-datafusion] branch main updated: Prioritize UDF over scalar built-in function in case of function name… (#6601)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c96949ea8c Prioritize UDF over scalar built-in function in case of function name… (#6601)
c96949ea8c is described below

commit c96949ea8c69d3a43ed95468cf3fef3a505b9633
Author: epsio-banay <12...@users.noreply.github.com>
AuthorDate: Mon Jun 12 19:08:37 2023 +0300

    Prioritize UDF over scalar built-in function in case of function name… (#6601)
    
    * Prioritize UDF over scalar built-in function in case of function name collision
    
    * Remove prioritize_udf config flag (assume true by default)
    
    * Add test scalar_udf_override_built_in_scalar_function
---
 datafusion/core/tests/sql/udf.rs    | 33 +++++++++++++++++++++++++++++++++
 datafusion/sql/src/expr/function.rs | 16 ++++++++--------
 2 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs
index a31028fd71..0ecd5d0fde 100644
--- a/datafusion/core/tests/sql/udf.rs
+++ b/datafusion/core/tests/sql/udf.rs
@@ -179,6 +179,39 @@ async fn scalar_udf_zero_params() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
+    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![Arc::new(Int32Array::from(vec![-100]))],
+    )?;
+    let ctx = SessionContext::new();
+
+    ctx.register_batch("t", batch)?;
+    // register a UDF that has the same name as a builtin function (abs) and just returns 1 regardless of input
+    ctx.register_udf(create_udf(
+        "abs",
+        vec![DataType::Int32],
+        Arc::new(DataType::Int32),
+        Volatility::Immutable,
+        Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
+    ));
+
+    // Make sure that the UDF is used instead of the built-in function
+    let result = plan_and_collect(&ctx, "select abs(a) a from t").await?;
+    let expected = vec![
+        "+---+", //
+        "| a |", //
+        "+---+", //
+        "| 1 |", //
+        "+---+", //
+    ];
+    assert_batches_eq!(expected, &result);
+    Ok(())
+}
+
 /// tests the creation, registration and usage of a UDAF
 #[tokio::test]
 async fn simple_udaf() -> Result<()> {
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index 0fb6b75547..0289e80411 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -47,6 +47,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             crate::utils::normalize_ident(function.name.0[0].clone())
         };
 
+        // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function
+        if let Some(fm) = self.schema_provider.get_function_meta(&name) {
+            let args =
+                self.function_args_to_expr(function.args, schema, planner_context)?;
+            return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
+        }
+
         // next, scalar built-in
         if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
             let args =
@@ -139,14 +146,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 )));
             };
 
-            // finally, user-defined functions (UDF) and UDAF
-            if let Some(fm) = self.schema_provider.get_function_meta(&name) {
-                let args =
-                    self.function_args_to_expr(function.args, schema, planner_context)?;
-                return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
-            }
-
-            // User defined aggregate functions
+            // User defined aggregate functions (UDAF)
             if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
                 let args =
                     self.function_args_to_expr(function.args, schema, planner_context)?;