You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2020/08/26 02:15:40 UTC

[arrow] branch master updated: ARROW-9849: [Rust] [DataFusion] Simplified argument types of ScalarFunctions.

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d02e166  ARROW-9849: [Rust] [DataFusion] Simplified argument types of ScalarFunctions.
d02e166 is described below

commit d02e16646cb3ee2a568dfeeaa68a0f75986592c8
Author: Jorge C. Leitao <jo...@gmail.com>
AuthorDate: Tue Aug 25 20:15:16 2020 -0600

    ARROW-9849: [Rust] [DataFusion] Simplified argument types of ScalarFunctions.
    
    Replaces "Field" with "DataType" as the argument type in the UDF declaration, since we are only using its type.
    
    This is a spin-off of #8032 with a much smaller scope, as the other one is getting too large to handle.
    
    Closes #8045 from jorgecarleitao/clean_args
    
    Authored-by: Jorge C. Leitao <jo...@gmail.com>
    Signed-off-by: Andy Grove <an...@gmail.com>
---
 rust/datafusion/src/execution/context.rs                       |  5 +----
 .../datafusion/src/execution/physical_plan/math_expressions.rs |  6 +++---
 rust/datafusion/src/execution/physical_plan/mod.rs             |  4 ++--
 rust/datafusion/src/execution/physical_plan/udf.rs             | 10 +++++-----
 rust/datafusion/src/optimizer/type_coercion.rs                 |  3 +--
 rust/datafusion/src/sql/planner.rs                             |  8 +++-----
 rust/datafusion/tests/sql.rs                                   |  2 +-
 7 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 7b0c8e9..e3d5820 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -996,10 +996,7 @@ mod tests {
 
         let my_add = ScalarFunction::new(
             "my_add",
-            vec![
-                Field::new("a", DataType::Int32, true),
-                Field::new("b", DataType::Int32, true),
-            ],
+            vec![DataType::Int32, DataType::Int32],
             DataType::Int32,
             myfunc,
         );
diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
index 97098d6..ea40ac5 100644
--- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
@@ -21,7 +21,7 @@ use crate::error::ExecutionError;
 use crate::execution::physical_plan::udf::ScalarFunction;
 
 use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder};
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::DataType;
 
 use std::sync::Arc;
 
@@ -29,7 +29,7 @@ macro_rules! math_unary_function {
     ($NAME:expr, $FUNC:ident) => {
         ScalarFunction::new(
             $NAME,
-            vec![Field::new("n", DataType::Float64, true)],
+            vec![DataType::Float64],
             DataType::Float64,
             Arc::new(|args: &[ArrayRef]| {
                 let n = &args[0].as_any().downcast_ref::<Float64Array>();
@@ -86,7 +86,7 @@ mod tests {
         execution::context::ExecutionContext,
         logicalplan::{col, sqrt, LogicalPlanBuilder},
     };
-    use arrow::datatypes::Schema;
+    use arrow::datatypes::{Field, Schema};
 
     #[test]
     fn cast_i8_input() -> Result<()> {
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index f79907d..780cdba 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -26,7 +26,7 @@ use crate::error::Result;
 use crate::execution::context::ExecutionContextState;
 use crate::logicalplan::{LogicalPlan, ScalarValue};
 use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Schema, SchemaRef};
 use arrow::{
     compute::kernels::length::length,
     record_batch::{RecordBatch, RecordBatchReader},
@@ -138,7 +138,7 @@ pub trait Accumulator: Debug {
 pub fn scalar_functions() -> Vec<ScalarFunction> {
     let mut udfs = vec![ScalarFunction::new(
         "length",
-        vec![Field::new("n", DataType::Utf8, true)],
+        vec![DataType::Utf8],
         DataType::UInt32,
         Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))),
     )];
diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs
index fb777f9..ca59087 100644
--- a/rust/datafusion/src/execution/physical_plan/udf.rs
+++ b/rust/datafusion/src/execution/physical_plan/udf.rs
@@ -20,7 +20,7 @@
 use std::fmt;
 
 use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
 
 use crate::error::Result;
 use crate::execution::physical_plan::PhysicalExpr;
@@ -38,7 +38,7 @@ pub struct ScalarFunction {
     /// Function name
     pub name: String,
     /// Function argument meta-data
-    pub args: Vec<Field>,
+    pub arg_types: Vec<DataType>,
     /// Return type
     pub return_type: DataType,
     /// UDF implementation
@@ -61,7 +61,7 @@ impl Debug for ScalarFunction {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("ScalarFunction")
             .field("name", &self.name)
-            .field("args", &self.args)
+            .field("arg_types", &self.arg_types)
             .field("return_type", &self.return_type)
             .field("fun", &"<FUNC>")
             .finish()
@@ -72,13 +72,13 @@ impl ScalarFunction {
     /// Create a new ScalarFunction
     pub fn new(
         name: &str,
-        args: Vec<Field>,
+        arg_types: Vec<DataType>,
         return_type: DataType,
         fun: ScalarUdf,
     ) -> Self {
         Self {
             name: name.to_owned(),
-            args,
+            arg_types,
             return_type,
             fun,
         }
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index 22d49b5..65940ba 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -69,9 +69,8 @@ where
                 match self.scalar_functions.lookup(name) {
                     Some(func_meta) => {
                         for i in 0..expressions.len() {
-                            let field = &func_meta.args[i];
                             let actual_type = expressions[i].get_type(schema)?;
-                            let required_type = field.data_type();
+                            let required_type = &func_meta.arg_types[i];
                             if &actual_type != required_type {
                                 // attempt to coerce using numerical coercion
                                 // todo: also try string coercion.
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 8f366de..f627d05 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -520,10 +520,8 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
 
                             let mut safe_args: Vec<Expr> = vec![];
                             for i in 0..rex_args.len() {
-                                safe_args.push(
-                                    rex_args[i]
-                                        .cast_to(fm.args[i].data_type(), schema)?,
-                                );
+                                safe_args
+                                    .push(rex_args[i].cast_to(&fm.arg_types[i], schema)?);
                             }
 
                             Ok(Expr::ScalarFunction {
@@ -908,7 +906,7 @@ mod tests {
             match name {
                 "sqrt" => Some(Arc::new(ScalarFunction::new(
                     "sqrt",
-                    vec![Field::new("n", DataType::Float64, false)],
+                    vec![DataType::Float64],
                     DataType::Float64,
                     Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string()))),
                 ))),
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 43ac27a..15120d7 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -220,7 +220,7 @@ fn create_ctx() -> Result<ExecutionContext> {
     // register a custom UDF
     ctx.register_udf(ScalarFunction::new(
         "custom_sqrt",
-        vec![Field::new("n", DataType::Float64, true)],
+        vec![DataType::Float64],
         DataType::Float64,
         Arc::new(custom_sqrt),
     ));