You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/30 22:06:10 UTC

(arrow-datafusion) branch main updated: Implement Aliases for ScalarUDF (#8360)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d45cf00771 Implement Aliases for ScalarUDF (#8360)
d45cf00771 is described below

commit d45cf00771064ffea934c476eb12b86eb4ad75b1
Author: Tan Wei <co...@tanweime.com>
AuthorDate: Fri Dec 1 06:06:05 2023 +0800

    Implement Aliases for ScalarUDF (#8360)
    
    * Implement Aliases for ScalarUDF
    
    Signed-off-by: veeupup <co...@tanweime.com>
    
    * fix comments
    
    Signed-off-by: veeupup <co...@tanweime.com>
    
    ---------
    
    Signed-off-by: veeupup <co...@tanweime.com>
---
 datafusion/core/src/execution/context/mod.rs       | 11 +++++--
 .../user_defined/user_defined_scalar_functions.rs  | 37 ++++++++++++++++++++++
 datafusion/expr/src/udf.rs                         | 18 +++++++++++
 3 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 46388f990a..dbebedce3c 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -810,9 +810,16 @@ impl SessionContext {
     ///
     /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
     /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
+    /// Any function already registered under this UDF's name or any of its aliases will be overwritten by this new function.
     pub fn register_udf(&self, f: ScalarUDF) {
-        self.state
-            .write()
+        let mut state = self.state.write();
+        let aliases = f.aliases();
+        for alias in aliases {
+            state
+                .scalar_functions
+                .insert(alias.to_string(), Arc::new(f.clone()));
+        }
+        state
             .scalar_functions
             .insert(f.name().to_string(), Arc::new(f));
     }
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 1c7e713729..985b0bd5bc 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -341,6 +341,43 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_user_defined_functions_with_alias() -> Result<()> {
+    let ctx = SessionContext::new();
+    let arr = Int32Array::from(vec![1]);
+    let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+    ctx.register_batch("t", batch).unwrap();
+
+    let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
+    let myfunc = make_scalar_function(myfunc);
+
+    let udf = create_udf(
+        "dummy",
+        vec![DataType::Int32],
+        Arc::new(DataType::Int32),
+        Volatility::Immutable,
+        myfunc,
+    )
+    .with_aliases(vec!["dummy_alias"]);
+
+    ctx.register_udf(udf);
+
+    let expected = [
+        "+------------+",
+        "| dummy(t.i) |",
+        "+------------+",
+        "| 1          |",
+        "+------------+",
+    ];
+    let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
+    assert_batches_eq!(expected, &result);
+
+    let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?;
+    assert_batches_eq!(expected, &alias_result);
+
+    Ok(())
+}
+
 fn create_udf_context() -> SessionContext {
     let ctx = SessionContext::new();
     // register a custom UDF
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index bc910b928a..3a18ca2d25 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -49,6 +49,8 @@ pub struct ScalarUDF {
     /// the batch's row count (so that the generative zero-argument function can know
     /// the result array size).
     fun: ScalarFunctionImplementation,
+    /// Optional aliases for the function. This list should NOT include the value of `name` itself.
+    aliases: Vec<String>,
 }
 
 impl Debug for ScalarUDF {
@@ -89,9 +91,20 @@ impl ScalarUDF {
             signature: signature.clone(),
             return_type: return_type.clone(),
             fun: fun.clone(),
+            aliases: vec![],
         }
     }
 
+    /// Registers extra names (aliases) that can be used to invoke this function, in addition to `name`.
+    pub fn with_aliases(
+        mut self,
+        aliases: impl IntoIterator<Item = &'static str>,
+    ) -> Self {
+        self.aliases
+            .extend(aliases.into_iter().map(|s| s.to_string()));
+        self
+    }
+
     /// creates a logical expression with a call of the UDF
     /// This utility allows using the UDF without requiring access to the registry.
     pub fn call(&self, args: Vec<Expr>) -> Expr {
@@ -106,6 +119,11 @@ impl ScalarUDF {
         &self.name
     }
 
+    /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
+    pub fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
     /// Returns this function's signature (what input types are accepted)
     pub fn signature(&self) -> &Signature {
         &self.signature