You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/30 22:06:10 UTC
(arrow-datafusion) branch main updated: Implement Aliases for ScalarUDF (#8360)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d45cf00771 Implement Aliases for ScalarUDF (#8360)
d45cf00771 is described below
commit d45cf00771064ffea934c476eb12b86eb4ad75b1
Author: Tan Wei <co...@tanweime.com>
AuthorDate: Fri Dec 1 06:06:05 2023 +0800
Implement Aliases for ScalarUDF (#8360)
* Implement Aliases for ScalarUDF
Signed-off-by: veeupup <co...@tanweime.com>
* fix comments
Signed-off-by: veeupup <co...@tanweime.com>
---------
Signed-off-by: veeupup <co...@tanweime.com>
---
datafusion/core/src/execution/context/mod.rs | 11 +++++--
.../user_defined/user_defined_scalar_functions.rs | 37 ++++++++++++++++++++++
datafusion/expr/src/udf.rs | 18 +++++++++++
3 files changed, 64 insertions(+), 2 deletions(-)
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index 46388f990a..dbebedce3c 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -810,9 +810,16 @@ impl SessionContext {
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
+ /// Any function already registered under this UDF's name or any of its aliases will be overwritten by this new function.
pub fn register_udf(&self, f: ScalarUDF) {
- self.state
- .write()
+ let mut state = self.state.write();
+ let aliases = f.aliases();
+ for alias in aliases {
+ state
+ .scalar_functions
+ .insert(alias.to_string(), Arc::new(f.clone()));
+ }
+ state
.scalar_functions
.insert(f.name().to_string(), Arc::new(f));
}
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 1c7e713729..985b0bd5bc 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -341,6 +341,43 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_user_defined_functions_with_alias() -> Result<()> {
+ let ctx = SessionContext::new();
+ let arr = Int32Array::from(vec![1]);
+ let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+ ctx.register_batch("t", batch).unwrap();
+
+ let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
+ let myfunc = make_scalar_function(myfunc);
+
+ let udf = create_udf(
+ "dummy",
+ vec![DataType::Int32],
+ Arc::new(DataType::Int32),
+ Volatility::Immutable,
+ myfunc,
+ )
+ .with_aliases(vec!["dummy_alias"]);
+
+ ctx.register_udf(udf);
+
+ let expected = [
+ "+------------+",
+ "| dummy(t.i) |",
+ "+------------+",
+ "| 1 |",
+ "+------------+",
+ ];
+ let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
+ assert_batches_eq!(expected, &result);
+
+ let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?;
+ assert_batches_eq!(expected, &alias_result);
+
+ Ok(())
+}
+
fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index bc910b928a..3a18ca2d25 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -49,6 +49,8 @@ pub struct ScalarUDF {
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
+ /// Optional aliases for the function. This list should NOT include the value of `name` itself.
+ aliases: Vec<String>,
}
impl Debug for ScalarUDF {
@@ -89,9 +91,20 @@ impl ScalarUDF {
signature: signature.clone(),
return_type: return_type.clone(),
fun: fun.clone(),
+ aliases: vec![],
}
}
+ /// Adds additional names that can be used to invoke this function, in addition to `name`
+ pub fn with_aliases(
+ mut self,
+ aliases: impl IntoIterator<Item = &'static str>,
+ ) -> Self {
+ self.aliases
+ .extend(aliases.into_iter().map(|s| s.to_string()));
+ self
+ }
+
/// creates a logical expression with a call of the UDF
/// This utility allows using the UDF without requiring access to the registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
@@ -106,6 +119,11 @@ impl ScalarUDF {
&self.name
}
+ /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
+ pub fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+
/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature